diff --git a/packages/data/src/pyearthtools/data/transforms/normalisation/default.py b/packages/data/src/pyearthtools/data/transforms/normalisation/default.py index 34019014..b4df1784 100644 --- a/packages/data/src/pyearthtools/data/transforms/normalisation/default.py +++ b/packages/data/src/pyearthtools/data/transforms/normalisation/default.py @@ -44,7 +44,7 @@ def open_file(file: str | tuple | dict): data = open_files(file) if isinstance(data, (xr.Dataset, xr.DataArray)): data = get_default_transforms()(data) - data = pyearthtools.data.transforms.coordinates.drop("time", ignore_missing=True)(data) + data = pyearthtools.data.transforms.coordinates.Drop("time", ignore_missing=True)(data) return data @@ -61,7 +61,7 @@ def under_func(*args, **kwargs): return under_func -class normaliser: +class Normaliser: def __init__( self, index: pyearthtools.data.AdvancedTimeIndex, @@ -214,7 +214,7 @@ def get_aggregation( # ) aggregated_data = get_and_print( - lambda: pyearthtools.data.transforms.aggregation.over(method, dims)( + lambda: pyearthtools.data.transforms.aggregation.over(method=method, dimension=dims)( self.index.series( **retrieval_args, transforms=transforms, diff --git a/packages/data/src/pyearthtools/data/transforms/normalisation/normalise.py b/packages/data/src/pyearthtools/data/transforms/normalisation/normalise.py index d0969dd9..91a7f448 100644 --- a/packages/data/src/pyearthtools/data/transforms/normalisation/normalise.py +++ b/packages/data/src/pyearthtools/data/transforms/normalisation/normalise.py @@ -23,18 +23,18 @@ xr.set_options(keep_attrs=True) -from pyearthtools.data.transforms.normalisation.default import normaliser, open_file +from pyearthtools.data.transforms.normalisation.default import Normaliser, open_file from pyearthtools.data.transforms.transform import FunctionTransform, Transform -class Normalise(normaliser): +class Normalise(Normaliser): """ Normalise incoming data. Either call this class, or get attribute for specific normalisation strategy """ - @functools.wraps(normaliser.__init__) + @functools.wraps(Normaliser.__init__) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/packages/data/src/pyearthtools/data/transforms/normalisation/unnormalise.py b/packages/data/src/pyearthtools/data/transforms/normalisation/unnormalise.py index af8b2047..38d4b262 100644 --- a/packages/data/src/pyearthtools/data/transforms/normalisation/unnormalise.py +++ b/packages/data/src/pyearthtools/data/transforms/normalisation/unnormalise.py @@ -21,16 +21,16 @@ import numpy as np import xarray as xr -from pyearthtools.data.transforms.normalisation.default import normaliser, open_file +from pyearthtools.data.transforms.normalisation.default import Normaliser, open_file from pyearthtools.data.transforms.transform import FunctionTransform, Transform xr.set_options(keep_attrs=True) -class Unnormalise(normaliser): +class Unnormalise(Normaliser): """Unnormalise Incoming Data""" - @functools.wraps(normaliser) + @functools.wraps(Normaliser) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/packages/data/src/pyearthtools/data/transforms/transform.py b/packages/data/src/pyearthtools/data/transforms/transform.py index a2348749..7d29f879 100644 --- a/packages/data/src/pyearthtools/data/transforms/transform.py +++ b/packages/data/src/pyearthtools/data/transforms/transform.py @@ -240,8 +240,8 @@ def apply(self, dataset: XR_TYPES | tuple[XR_TYPES] | list[XR_TYPES] | dict[str, def __call__(self, dataset: XR_TYPES | tuple[XR_TYPES] | list[XR_TYPES] | dict[str, XR_TYPES]) -> XR_TYPES | Any: # Do not try to transform empty datasets - if not dataset: - return dataset + if dataset is None: + return None for transform in self._transforms: dataset = transform(dataset) diff --git a/packages/data/tests/data/transform/normalisation/test_default.py b/packages/data/tests/data/transform/normalisation/test_default.py new file mode 100644 index 00000000..b43a7651 --- /dev/null +++ b/packages/data/tests/data/transform/normalisation/test_default.py @@ -0,0 +1,86 @@ +import pyearthtools.data.transforms.normalisation +from pyearthtools.data.transforms.normalisation import default +from pyearthtools.data.time import Petdt +import pyearthtools.data.indexes +import xarray as xr +import numpy as np +import pytest + +sample_da = xr.DataArray(coords={"latitude": [1,2,3,4], + "longitude": [1,2,3], + "time": ["2023-02"] + }, + data=np.ones((4,3,1))) + +sample_ds = xr.Dataset(coords={"latitude": [1,2,3,4], "longitude": [1,2,3], "time": ["2023-02"]}, + data_vars={"temperature": sample_da}) + + +def test_open_file(monkeypatch): + + monkeypatch.setattr(pyearthtools.data.transforms.normalisation.default, + 'open_files', + lambda x: sample_da) + + result = default.open_file("pretend_filename.nc") + assert result is not None + + +def test_Normaliser(monkeypatch): + + monkeypatch.setattr("pyearthtools.data.indexes.AdvancedTimeIndex.__abstractmethods__", set()) + + data_interval = "day" + ati = pyearthtools.data.indexes.AdvancedTimeIndex(data_interval) + monkeypatch.setattr(ati, "get", lambda x: sample_da) + start = Petdt("2023-02") + end = Petdt("2023-03") + + n = default.Normaliser(ati, start, end, "month") + n.check_init_args() + + result = n.get_average("temperature") + assert result == 1 + + r_mean, r_std = n.get_deviation("temperature") + assert r_mean == 1 + assert r_std == 0 + + r_anomaly = n.get_anomaly("temperature") + assert r_anomaly is not None + + # FIXME: Need to update the whole test creation to be a time-aware dataset + # r_range = n.get_range("temperature") + # assert r_range["temperature"]["max"] == 1 + # assert r_range["temperature"]["min"] == 1 + + result = n.none + assert result is not None + +def test_Normaliser_errors(monkeypatch): + + monkeypatch.setattr("pyearthtools.data.indexes.AdvancedTimeIndex.__abstractmethods__", set()) + + data_interval = "day" + ati = pyearthtools.data.indexes.AdvancedTimeIndex(data_interval) + monkeypatch.setattr(ati, "get", lambda x: sample_da) + start = Petdt("2023-02") + end = Petdt("2023-03") + + n = default.Normaliser(ati, start, end, "month") + + with pytest.raises(NotImplementedError): + n.function() + + + not_implemented = [n.log, n.anomaly, n.deviation, n.deviation_spatial, n.range] + for ni in not_implemented: + with pytest.raises(NotImplementedError): + ni() + + + + + + +