Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def check_validity(self, sample: tuple) -> Union[None, int]: # pragma: no cover
Which sub_sample to be returned.
Return `None` if invalid.
"""
raise NotImplementedError(f"Child class must supply `check_validity` function.")
raise NotImplementedError("Child class must supply `check_validity` function.")

def unify(self, sample: tuple) -> Any:
index = self.check_validity(sample)
Expand Down
4 changes: 2 additions & 2 deletions packages/pipeline/src/pyearthtools/pipeline/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def set_parent_record(
def parent_pipeline(self) -> Pipeline:
"""Get parent pipeline of this `PipelineIndex`, will not include self"""
if self._partial_parent is None:
raise ValueError(f"Parent record has not been set with `set_parent_record`, cannot get parent pipeline")
raise ValueError("Parent record has not been set with `set_parent_record`, cannot get parent pipeline")
return self._partial_parent(*self._steps)

def as_pipeline(self) -> Pipeline:
"""Get `PipelineIndex` as full pipeline, will include self"""
if self._partial_parent is None:
raise ValueError(f"Parent record has not been set with `set_parent_record`, cannot get step as pipeline")
raise ValueError("Parent record has not been set with `set_parent_record`, cannot get step as pipeline")
return self._partial_parent(*self._steps, self)

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion packages/pipeline/src/pyearthtools/pipeline/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from functools import cached_property
from abc import ABCMeta, abstractmethod

from typing import Any, Callable, Generator, Hashable, Iterable, Optional, Union, Set
from typing import Any, Callable, Generator, Hashable, Iterable, Optional, Union
from pathlib import Path

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def filter(self, sample: da.Array):
If sample contains nan's
"""
if not bool(da.array(list(da.isnan(sample))).any()):
raise PipelineFilterException(sample, f"Data contained nan's.")
raise PipelineFilterException(sample, "Data contained nan's.")


class DropAllNan(daskFilter):
Expand All @@ -86,7 +86,7 @@ def filter(self, sample: da.Array):
If sample contains nan's
"""
if not bool(da.array(list(da.isnan(sample))).all()):
raise PipelineFilterException(sample, f"Data contained all nan's.")
raise PipelineFilterException(sample, "Data contained all nan's.")


class DropValue(daskFilter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(
self.shape_attempt = shape_attempt

if isinstance(flatten_dims, int) and flatten_dims < 1:
raise ValueError(f"'flatten_dims' cannot be smaller than 1.")
raise ValueError("'flatten_dims' cannot be smaller than 1.")
self.flatten_dims = flatten_dims

def _prod_shape(self, shape):
Expand All @@ -205,8 +205,8 @@ def _prod_shape(self, shape):

def _configure_shape_attempt(self) -> tuple[Union[str, int], ...]:
if not self._fillshape or not self.shape_attempt:
raise RuntimeError(f"Cannot find shape to unflatten with, try flattening first.")
if not "..." in self.shape_attempt:
raise RuntimeError("Cannot find shape to unflatten with, try flattening first.")
if "..." not in self.shape_attempt:
return self.shape_attempt

shape_attempt = list(self.shape_attempt)
Expand Down Expand Up @@ -235,15 +235,15 @@ def apply(self, data: da.Array) -> da.Array:

def undo(self, data: da.Array) -> da.Array:
if self._unflattenshape is None:
raise RuntimeError(f"Shape not set, therefore cannot undo")
raise RuntimeError("Shape not set, therefore cannot undo")

def _unflatten(data, shape):
while len(data.shape) > len(shape):
shape = (data[-len(shape)], *shape)
return data.reshape(shape)

if self.flatten_dims is None:
raise RuntimeError(f"`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.")
raise RuntimeError("`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.")

data_shape = data.shape
parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else data_shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def filter(self, sample: np.ndarray):
If sample contains nan's
"""
if not bool(np.array(list(np.isnan(sample))).any()):
raise PipelineFilterException(sample, f"Data contained nan's.")
raise PipelineFilterException(sample, "Data contained nan's.")


class DropAllNan(NumpyFilter):
Expand Down Expand Up @@ -86,7 +86,7 @@ def filter(self, sample: np.ndarray):
If sample contains nan's
"""
if not bool(np.array(list(np.isnan(sample))).all()):
raise PipelineFilterException(sample, f"Data contained all nan's.")
raise PipelineFilterException(sample, "Data contained all nan's.")


class DropValue(NumpyFilter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _prod_shape(self, shape):
def _configure_shape_attempt(self) -> tuple[Union[str, int], ...]:
if not self._fillshape or not self.shape_attempt:
raise RuntimeError("Cannot find shape to unflatten with, try flattening first.")
if not "..." in self.shape_attempt:
if "..." not in self.shape_attempt:
return self.shape_attempt

shape_attempt = list(self.shape_attempt)
Expand Down Expand Up @@ -219,15 +219,15 @@ def apply(self, data: np.ndarray) -> np.ndarray:

def undo(self, data: np.ndarray) -> np.ndarray:
if self._unflattenshape is None:
raise RuntimeError(f"Shape not set, therefore cannot undo")
raise RuntimeError("Shape not set, therefore cannot undo")

def _unflatten(data, shape):
# while len(data.shape) > len(shape):
# shape = (data[-len(shape)], *shape)
return data.reshape(shape)

if self.flatten_dims is None:
raise RuntimeError(f"`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.")
raise RuntimeError("`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.")

data_shape = data.shape
# parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from typing import Any, Optional
from typing import Optional

import numpy as np

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def split(self, sample: np.ndarray) -> tuple[np.ndarray]:
def join(self, sample: tuple[np.ndarray]) -> np.ndarray:
"""Join `sample` together, recovering initial shape"""
if self.axis_size is None:
raise RuntimeError(f"`axis_size` not set.")
raise RuntimeError("`axis_size` not set.")

data = np.concatenate(sample, axis=0)
shape = data.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from typing import Any, TypeVar, Union, Optional
from typing import TypeVar, Union, Optional

import xarray as xr

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from typing import TypeVar, Optional
from typing import TypeVar

import pandas as pd
import xarray as xr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from typing import TypeVar, Optional
from typing import TypeVar

import xarray as xr

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _check(self, sample: xr.Dataset):
sample = sample[self.variables]

if not bool(np.array(list(np.isnan(sample).values())).any()):
raise PipelineFilterException(sample, f"Data contained nan's.")
raise PipelineFilterException(sample, "Data contained nan's.")


class DropAllNan(XarrayFilter):
Expand Down Expand Up @@ -109,7 +109,7 @@ def _check(self, sample: xr.Dataset):
sample = sample[self.variables]

if not bool(np.array(list(np.isnan(sample).values())).all()):
raise PipelineFilterException(sample, f"Data contained all nan's.")
raise PipelineFilterException(sample, "Data contained all nan's.")


class DropValue(XarrayFilter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import numpy as np
import xarray as xr
import dask
import pytest


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pyearthtools.pipeline.operations.dask.compute import Compute

import xarray as xr
import dask
import numpy as np


Expand Down
4 changes: 1 addition & 3 deletions packages/pipeline/tests/pipeline/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from __future__ import annotations
from typing import Any

import pytest

Expand All @@ -24,10 +23,9 @@

import pyearthtools.data

from pyearthtools.pipeline import Pipeline, exceptions, modifications
from pyearthtools.pipeline import Pipeline
from pyearthtools.pipeline.modifications.idx_modification import SequenceRetrieval

from pyearthtools.pipeline import Operation
from pyearthtools.data import Index


Expand Down
4 changes: 1 addition & 3 deletions packages/pipeline/tests/pipeline/test_temporal_idx_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


from __future__ import annotations
from typing import Any

import pytest

Expand All @@ -24,10 +23,9 @@

import pyearthtools.data

from pyearthtools.pipeline import Pipeline, exceptions
from pyearthtools.pipeline import Pipeline
from pyearthtools.pipeline.modifications.idx_modification import TemporalRetrieval

from pyearthtools.pipeline import Operation
from pyearthtools.data import Index


Expand Down
2 changes: 0 additions & 2 deletions packages/utils/src/pyearthtools/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
from __future__ import annotations

import ast
import base64
import builtins # Explicitly use builtins.set as 'set' will be shadowed by a function
import json
import os
import site
import sys
Expand Down
2 changes: 2 additions & 0 deletions packages/utils/src/pyearthtools/utils/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@

from pyearthtools.utils.data import converter
from pyearthtools.utils.data.tesselator import Tesselator

__all__ = ["converter", "Tesselator"]
8 changes: 2 additions & 6 deletions packages/utils/src/pyearthtools/utils/data/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from __future__ import annotations

import dask
import dask.array as da
import importlib
import importlib.util
import json
Expand Down Expand Up @@ -326,10 +328,6 @@ def _rebuild_arrays(self, numpy_array: np.ndarray, xarray_distill: dict) -> xr.D
full_coords.pop("Variables", None)

ar = np
try:
import dask.array as da
except (ImportError, ModuleNotFoundError):
ar = da

for i in range(numpy_array.shape[xarray_distill["dims"].index("Variables")]):
data = ar.take(numpy_array, i, axis=xarray_distill["dims"].index("Variables"))
Expand Down Expand Up @@ -436,7 +434,6 @@ def convert_from_xarray(
(dask.array.Array | tuple[dask.array.Array, ...]):
Generated array/s from Dataset/s
"""
import dask.array as da

self._set_records(data, replace=replace)

Expand All @@ -456,7 +453,6 @@ def convert(dataset: xr.DataArray | xr.Dataset) -> da.Array:
raise TypeError(f"Unable to convert data of {type(data)} to `da.array`")

def convert_to_xarray(self, data, pop: bool = True):
import dask.array as da

# if isinstance(data, da.Array):
# data = data.compute()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,12 @@ def stitch(
if (np.array(self.stride) > np.array(self.kernel_size)).any():
raise TesselatorException(
"Stride is larger than kernel size, incoming data cannot be complete.\n"
f"Set stride smaller, or use another tesselator "
"Set stride smaller, or use another tesselator "
)

if self._layout is None:
raise TesselatorException(
f"This tesselator has not be used to patch, therefore it has not recorded the layout, and cannot be used to stitch."
"This tesselator has not be used to patch, therefore it has not recorded the layout, and cannot be used to stitch."
)

if not input_data.shape == self._post_shape and not self.ignore_difference:
Expand All @@ -261,7 +261,7 @@ def stitch(

all_patches = []
for input_patch in input_data:
all_patches.append(_patching.reorder.reorder(input_patch, data_format, "TCHW"))
all_patches.append(_patching.reorder(input_patch, data_format, "TCHW"))

all_patches = np.array(all_patches)

Expand All @@ -274,7 +274,7 @@ def stitch(
full_prediction = _patching.subset.center(full_prediction, self._initial_shape[-2:])
except TesselatorException:
warnings.warn(
f"Could not trim to initial_shape, if 'padding' is None, an incomplete patch set is made. Padding with nans",
"Could not trim to initial_shape, if 'padding' is None, an incomplete patch set is made. Padding with nans",
TesselatorWarning,
)
pad_width = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@

"""
Data processing tools for use by the [Tesselator][pyearthtools.utils.data.Tesselator]

"""
from ._reorder import reorder

DEFAULT_FORMAT_SUBSET: str = "...HW"

DEFAULT_FORMAT_PATCH_ORGANISE: str = "P...HW"
DEFAULT_FORMAT_PATCH: str = "RP...HW"
DEFAULT_FORMAT_PATCH_AFTER: str = "...HW"

from . import patches, reorder, subset
from . import patches, subset # noqa

__all__ = ["patches", "reorder", "subset"]
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
DEFAULT_FORMAT_PATCH_AFTER,
DEFAULT_FORMAT_PATCH_ORGANISE,
)
from pyearthtools.utils.data.tesselator._patching.reorder import reorder
from pyearthtools.utils.data.tesselator._patching.subset import center, cut_center
from pyearthtools.utils.data.tesselator._patching._reorder import reorder
from pyearthtools.utils.data.tesselator._patching.subset import cut_center


def factors(value: int) -> list[list[int, int]]:
Expand Down Expand Up @@ -273,7 +273,7 @@ def find_dim_expand(length, kernel, stride, rounding_flip: bool = False):
)
# padd_width.append(((kernel_size[0] - stride[0])//2,(kernel_size[1] - stride[1])//2))
if (np.array(padd_width) < 0).any():
raise ValueError(f"Padding width cannot be negative, try setting `padding` to None")
raise ValueError("Padding width cannot be negative, try setting `padding` to None")

if padding == "constant":
kwargs["constant_values"] = kwargs.get("constant_values", np.nan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as np

from pyearthtools.utils.data.tesselator._patching import DEFAULT_FORMAT_SUBSET
from pyearthtools.utils.data.tesselator._patching.reorder import move_to_end, reorder
from pyearthtools.utils.data.tesselator._patching._reorder import move_to_end, reorder
from pyearthtools.utils.exceptions import TesselatorException


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def dynamic_import(object_path: str) -> Callable | ModuleType:
pass

if not object_path:
raise ImportError(f"object_path cannot be empty")
raise ImportError("object_path cannot be empty")
try:
return importlib.import_module(object_path)
except ModuleNotFoundError:
object_path_list = object_path.split(".")
return getattr(dynamic_import(".".join(object_path_list[:-1])), object_path_list[-1])
except ValueError as e:
except ValueError:
raise ModuleNotFoundError("End of module definition reached")
Loading