From 8d2e45a52f821c6eb59973d0e63bdbc9c3f2c9ff Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 30 Apr 2025 15:14:10 +1200 Subject: [PATCH 01/13] rename Squish to Squeeze --- notebooks/tutorial/Accessing_ERA5_Data.ipynb | 4 +- notebooks/tutorial/CNN_model_training.ipynb | 8 ++-- notebooks/tutorial/Data_Pipelines.ipynb | 4 +- notebooks/tutorial/Downloading_ERA5.ipynb | 2 +- .../tutorial/Interfacing_to_Data_at_NCI.ipynb | 2 +- notebooks/tutorial/MultipleSources.ipynb | 2 +- .../tutorial/Working_with_Climate_Data.ipynb | 2 +- .../pipeline/assets/ComplexPipeline.svg | 2 +- .../pipeline/assets/pipeline_example.svg | 2 +- .../docs/documentation/pipeline/available.md | 4 +- .../docs/documentation/pipeline/creating.md | 38 +++++++++++-------- .../pipeline/details/operation.md | 10 ++--- old_docs/docs/documentation/pipeline/index.md | 13 ++++--- packages/pipeline/assets/pipeline_example.svg | 2 +- .../pipeline/operations/dask/__init__.py | 2 +- .../pipeline/operations/dask/reshape.py | 8 ++-- .../pipeline/operations/numpy/__init__.py | 2 +- .../pipeline/operations/numpy/reshape.py | 10 ++--- 18 files changed, 62 insertions(+), 55 deletions(-) diff --git a/notebooks/tutorial/Accessing_ERA5_Data.ipynb b/notebooks/tutorial/Accessing_ERA5_Data.ipynb index 82186fe1..eaca8e4f 100644 --- a/notebooks/tutorial/Accessing_ERA5_Data.ipynb +++ b/notebooks/tutorial/Accessing_ERA5_Data.ipynb @@ -2332,7 +2332,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pet-tutorials", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -2346,7 +2346,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/notebooks/tutorial/CNN_model_training.ipynb b/notebooks/tutorial/CNN_model_training.ipynb index 8beff0e3..16e2b6f8 100644 --- a/notebooks/tutorial/CNN_model_training.ipynb +++ b/notebooks/tutorial/CNN_model_training.ipynb @@ -782,10 +782,10 @@ " * Input data: (-1, 1) tuple is used to select a single timestamp at -1hrs\n", " * Target data: (6, 1) tuple is used to select a single timestamp at +5hrs ahead. \n", "\n", - "### Aditional Steps:\n", + "### Additional Steps:\n", "1. Export to a NumPy array.\n", "2. Rearrange the axes of the NumPy array.\n", - "3. Remove dimensions of size 1 via a \"squish\" (equivalent to a NumPy \"squeeze\") operation." + "3. Remove dimensions of size 1 via a \"squeeze\" operation." ] }, { @@ -1347,7 +1347,7 @@ " ),\n", " pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),\n", " pyearthtools.pipeline.operations.numpy.reshape.Rearrange(\"c t h w -> t c h w\"),\n", - " pyearthtools.pipeline.operations.numpy.reshape.Squish(axis=0),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Squeeze(axis=0),\n", ")\n", "data_preparation" ] @@ -8350,7 +8350,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/notebooks/tutorial/Data_Pipelines.ipynb b/notebooks/tutorial/Data_Pipelines.ipynb index aa97a35a..599c25f9 100644 --- a/notebooks/tutorial/Data_Pipelines.ipynb +++ b/notebooks/tutorial/Data_Pipelines.ipynb @@ -116,7 +116,7 @@ " # These methods will be explained when we create a pipeline for machine learning. \n", " # pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten('level'),\n", " # pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),\n", - " # pyearthtools.pipeline.operations.numpy.reshape.Squish(1),\n", + " # pyearthtools.pipeline.operations.numpy.reshape.Squeeze(1),\n", ")\n" ] }, @@ -2081,7 +2081,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/notebooks/tutorial/Downloading_ERA5.ipynb b/notebooks/tutorial/Downloading_ERA5.ipynb index 1438c9e1..5fcedd05 100644 --- a/notebooks/tutorial/Downloading_ERA5.ipynb +++ b/notebooks/tutorial/Downloading_ERA5.ipynb @@ -106,7 +106,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/notebooks/tutorial/Interfacing_to_Data_at_NCI.ipynb b/notebooks/tutorial/Interfacing_to_Data_at_NCI.ipynb index d053b530..35f789fc 100644 --- a/notebooks/tutorial/Interfacing_to_Data_at_NCI.ipynb +++ b/notebooks/tutorial/Interfacing_to_Data_at_NCI.ipynb @@ -1246,7 +1246,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/notebooks/tutorial/MultipleSources.ipynb b/notebooks/tutorial/MultipleSources.ipynb index 4a80497f..5e21085b 100644 --- a/notebooks/tutorial/MultipleSources.ipynb +++ b/notebooks/tutorial/MultipleSources.ipynb @@ -4534,7 +4534,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/notebooks/tutorial/Working_with_Climate_Data.ipynb b/notebooks/tutorial/Working_with_Climate_Data.ipynb index b7ac0275..3df0395e 100644 --- a/notebooks/tutorial/Working_with_Climate_Data.ipynb +++ b/notebooks/tutorial/Working_with_Climate_Data.ipynb @@ -5483,7 +5483,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/old_docs/docs/documentation/pipeline/assets/ComplexPipeline.svg b/old_docs/docs/documentation/pipeline/assets/ComplexPipeline.svg index c03b837a..49fac8f1 100644 --- a/old_docs/docs/documentation/pipeline/assets/ComplexPipeline.svg +++ b/old_docs/docs/documentation/pipeline/assets/ComplexPipeline.svg @@ -211,7 +211,7 @@ Squish_2f9471ae-d492-4692-8c44-46bcb75a5526 -reshape.Squish +reshape.Squeeze diff --git a/old_docs/docs/documentation/pipeline/assets/pipeline_example.svg b/old_docs/docs/documentation/pipeline/assets/pipeline_example.svg index c2b4a3b7..8ca41d66 100644 --- a/old_docs/docs/documentation/pipeline/assets/pipeline_example.svg +++ b/old_docs/docs/documentation/pipeline/assets/pipeline_example.svg @@ -91,7 +91,7 @@ Squish_7b413dbf-6cbe-4cf1-a375-17fc36cc9eda -Squish +Squeeze diff --git a/old_docs/docs/documentation/pipeline/available.md b/old_docs/docs/documentation/pipeline/available.md index bbbfc157..01ce30e3 100644 --- a/old_docs/docs/documentation/pipeline/available.md +++ b/old_docs/docs/documentation/pipeline/available.md @@ -47,7 +47,7 @@ Here are the default operations included with `pyearthtools.pipeline`, accessibl | filters | Filter data when iterating | `DropAnyNan`, `DropAllNan`, `DropValue`, `Shape` | | join | Combine tuples of `np.ndarrays` | `Stack`, `VStack`, `HStack`, `Concatenate` | | normalisation | Normalise arrays | `Anomaly`, `Deviation`, `Division`, `Evaluated` | -| reshape | Reshape numpy array | `Rearrange`, `Squish`, `Expand`, `Flatten`, `SwapAxis` | +| reshape | Reshape numpy array | `Rearrange`, `Squeeze`, `Expand`, `Flatten`, `SwapAxis` | | select | Select elements from array | `Select`, `Slice` | | split | Split numpy arrays into tuples | `OnAxis`, `OnSlice`, `VSplit`, `HSplit` | | values | Modify values of arrays | `FillNan`, `MaskValue`, `ForceNormalised` | @@ -67,7 +67,7 @@ Here are the default operations included with `pyearthtools.pipeline`, accessibl | filters | Filter data when iterating | `DropAnyNan`, `DropAllNan`, `DropValue`, `Shape` | | join | Combine tuples of `np.ndarrays` | `Stack`, `VStack`, `HStack`, `Concatenate` | | normalisation | Normalise arrays | `Anomaly`, `Deviation`, `Division`, `Evaluated` | -| reshape | Reshape numpy array | `Rearrange`, `Squish`, `Expand`, `Flatten`, `SwapAxis` | +| reshape | Reshape numpy array | `Rearrange`, `Squeeze`, `Expand`, `Flatten`, `SwapAxis` | | select | Select elements from array | `Select`, `Slice` | | split | Split numpy arrays into tuples | `OnAxis`, `OnSlice`, `VSplit`, `HSplit` | | values | Modify values of arrays | `FillNan`, `MaskValue`, `ForceNormalised` | diff --git a/old_docs/docs/documentation/pipeline/creating.md b/old_docs/docs/documentation/pipeline/creating.md index 55035306..5c56cf51 100644 --- a/old_docs/docs/documentation/pipeline/creating.md +++ b/old_docs/docs/documentation/pipeline/creating.md @@ -48,27 +48,29 @@ pyearthtools.pipeline.Pipeline( ( pyearthtools.data.archive.ERA5(['tcwv', 'skt', 'sp']), pyearthtools.pipeline.operations.Transforms( - apply = pyearthtools.pipeline.operations.transform.AddCoordinates(('latitude', 'longitude'))), + apply=pyearthtools.pipeline.operations.transform.AddCoordinates(('latitude', 'longitude'))), pyearthtools.pipeline.operations.xarray.Sort(('var_latitude', 'var_longitude', 'tcwv', 'skt', 'sp'))), ( - pyearthtools.data.archive.ERA5(['t', 'u', 'v'], level_value = [1,50,150,250,400,600,750,900,1000]), + pyearthtools.data.archive.ERA5(['t', 'u', 'v'], level_value=[1, 50, 150, 250, 400, 600, 750, 900, 1000]), pyearthtools.pipeline.operations.xarray.Sort(('t', 'u', 'v')) ), ( ( - pyearthtools.data.archive.ERA5( - ['mtnlwrf', 'msdwswrf', 'msdwlwrf', 'mtpr', 'mslhf', 'msshf', 'mtnswrf', 'mtdwswrf', 'msnswrf', 'msnlwrf'], - transforms = pyearthtools.data.transforms.derive( - mtupswrf = 'mtnswrf - mtdwswrf', - msupswrf = 'msnswrf - msdwswrf', - msuplwrf = 'msnlwrf - msdwlwrf', - drop = True + pyearthtools.data.archive.ERA5( + ['mtnlwrf', 'msdwswrf', 'msdwlwrf', 'mtpr', 'mslhf', 'msshf', 'mtnswrf', 'mtdwswrf', 'msnswrf', + 'msnlwrf'], + transforms=pyearthtools.data.transforms.derive( + mtupswrf='mtnswrf - mtdwswrf', + msupswrf='msnswrf - msdwswrf', + msuplwrf='msnlwrf - msdwlwrf', + drop=True ) ), - pyearthtools.data.archive.ERA5('!accumulate[period:"6 hours"]:tp>tp_accum') + pyearthtools.data.archive.ERA5('!accumulate[period:"6 hours"]:tp>tp_accum') ), pyearthtools.pipeline.operations.xarray.Merge(), - pyearthtools.pipeline.operations.xarray.Sort(('mslhf', 'msshf', 'msuplwrf', 'msupswrf', 'mtnlwrf', 'mtpr', 'mtupswrf', 'tp_accum')) + pyearthtools.pipeline.operations.xarray.Sort( + ('mslhf', 'msshf', 'msuplwrf', 'msupswrf', 'mtnlwrf', 'mtpr', 'mtupswrf', 'tp_accum')) ), ( pyearthtools.data.archive.ERA5(['mtdwswrf', 'z_surface', 'lsm', 'ci']), @@ -78,15 +80,19 @@ pyearthtools.pipeline.Pipeline( pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten('level', skip_missing=True), pyearthtools.pipeline.operations.xarray.reshape.Dimensions(('time', 'latitude', 'longitude')), pyearthtools.pipeline.operations.Transforms( - apply = pyearthtools.data.transform.coordinates.pad(coordinates = {'latitude': 1, 'longitude': 1}, mode = 'wrap') + pyearthtools.data.transforms.values.fill(coordinates = ['latitude', 'longitude'], direction = 'forward') + pyearthtools.data.transforms.interpolation.like(pipe['2020-01-01T00'], drop_coords = 'time')), + apply=pyearthtools.data.transform.coordinates.pad(coordinates={'latitude': 1, 'longitude': 1}, + mode='wrap') + pyearthtools.data.transforms.values.fill( + coordinates=['latitude', 'longitude'], + direction='forward') + pyearthtools.data.transforms.interpolation.like(pipe['2020-01-01T00'], + drop_coords='time')), pyearthtools.pipeline.operations.xarray.Merge(), - pyearthtools.pipeline.operations.xarray.Sort(order, safe = True), + pyearthtools.pipeline.operations.xarray.Sort(order, safe=True), pyearthtools.pipeline.operations.xarray.conversion.ToDask(), - pyearthtools.pipeline.operations.dask.reshape.Squish(axis=1), - pyearthtools.pipeline.modifications.Cache('temp', pattern_kwargs = dict(extension = 'npy')) - pyearthtools.pipeline.modifications.TemporalRetrieval(((-6, 1), (6,1))), + pyearthtools.pipeline.operations.dask.reshape.Squeeze(axis=1), + pyearthtools.pipeline.modifications.Cache('temp', pattern_kwargs=dict(extension='npy')) +pyearthtools.pipeline.modifications.TemporalRetrieval(((-6, 1), (6, 1))), ) ``` ![Pipeline Graph](./assets/ComplexPipeline.svg) diff --git a/old_docs/docs/documentation/pipeline/details/operation.md b/old_docs/docs/documentation/pipeline/details/operation.md index ea37a116..6139863e 100644 --- a/old_docs/docs/documentation/pipeline/details/operation.md +++ b/old_docs/docs/documentation/pipeline/details/operation.md @@ -60,7 +60,7 @@ inverse = example_operation.T ## Example Implementation -Here is the implementation of `numpy.reshape.Squish`, to flatten a one element axis of an array +Here is the implementation of `numpy.reshape.Squeeze`, to flatten a one element axis of an array ```python from typing import Union, Optional, Any @@ -68,16 +68,16 @@ import numpy as np from pyearthtools.pipeline.operation import Operation -class Squish(Operation): +class Squeeze(Operation): """ - Operation to Squish one Dimensional axis at 'axis' location + Operation to Squeeze one Dimensional axis at 'axis' location """ _override_interface = ["Delayed", "Serial"] # Which parallel interfaces to use in order of priority. - _interface_kwargs = {"Delayed": {"name": "Squish"}} + _interface_kwargs = {"Delayed": {"name": "Squeeze"}} def __init__(self, axis: Union[tuple[int, ...], int]) -> None: - """Squish Dimension of Data + """Squeeze Dimension of Data Args: axis (Union[tuple[int, ...], int]): diff --git a/old_docs/docs/documentation/pipeline/index.md b/old_docs/docs/documentation/pipeline/index.md index eae7f475..e6b9956f 100644 --- a/old_docs/docs/documentation/pipeline/index.md +++ b/old_docs/docs/documentation/pipeline/index.md @@ -15,20 +15,21 @@ Here is an example pipeline, for use with PanguWeather. import pyearthtools.data import pyearthtools.pipeline - data_preperation = pyearthtools.pipeline.Pipeline( ( pyearthtools.data.archive.ERA5(['msl', '10u', '10v', '2t']), - pyearthtools.data.archive.ERA5(['z', 'q', 't', 'u', 'v'], level_value = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]) + pyearthtools.data.archive.ERA5(['z', 'q', 't', 'u', 'v'], + level_value=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]) ), pyearthtools.pipeline.operations.xarray.Merge(), pyearthtools.pipeline.operations.xarray.Sort(['msl', '10u', '10v', '2t', 'z', 'q', 't', 'u', 'v']), pyearthtools.pipeline.operations.Transforms( - apply = pyearthtools.data.transforms.coordinates.standard_longitude(type = '0-360') + pyearthtools.data.transforms.coordinates.ReIndex(level = 'reversed') - ), - pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten(coordinate = 'level'), + apply=pyearthtools.data.transforms.coordinates.standard_longitude( + type='0-360') + pyearthtools.data.transforms.coordinates.ReIndex(level='reversed') + ), + pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten(coordinate='level'), pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(), - pyearthtools.pipeline.operations.numpy.reshape.Squish(axis = 1), + pyearthtools.pipeline.operations.numpy.reshape.Squeeze(axis=1), ) ``` diff --git a/packages/pipeline/assets/pipeline_example.svg b/packages/pipeline/assets/pipeline_example.svg index c2b4a3b7..8ca41d66 100644 --- a/packages/pipeline/assets/pipeline_example.svg +++ b/packages/pipeline/assets/pipeline_example.svg @@ -91,7 +91,7 @@ Squish_7b413dbf-6cbe-4cf1-a375-17fc36cc9eda -Squish +Squeeze diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/__init__.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/__init__.py index 7ac0c908..7795eded 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/__init__.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/__init__.py @@ -24,7 +24,7 @@ | filters | Filter data when iterating | `DropAnyNan`, `DropAllNan`, `DropValue`, `Shape` | | join | Combine tuples of `np.ndarrays` | `Stack`, `VStack`, `HStack`, `Concatenate` | | normalisation | Normalise arrays | `Anomaly`, `Deviation`, `Division`, `Evaluated` | -| reshape | Reshape numpy array | `Rearrange`, `Squish`, `Expand`, `Flatten`, `SwapAxis` | +| reshape | Reshape numpy array | `Rearrange`, `Squeeze`, `Expand`, `Flatten`, `SwapAxis` | | select | Select elements from array | `Select`, `Slice` | | split | Split numpy arrays into tuples | `OnAxis`, `OnSlice`, `VSplit`, `HSplit` | | values | Modify values of arrays | `FillNan`, `MaskValue`, `ForceNormalised` | diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/reshape.py index 58165f1e..e5f0fd6e 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/reshape.py @@ -110,16 +110,16 @@ def undo_func(self, data: Union[np.ndarray, da.Array]): return self._rearrange(data, pattern) -class Squish(DaskOperation): +class Squeeze(DaskOperation): """ - Operation to Squish one Dimensional axis at 'axis' location + Operation to Squeeze one Dimensional axis at 'axis' location """ _override_interface = ["Serial"] - _numpy_counterpart = "reshape.Squish" + _numpy_counterpart = "reshape.Squeeze" def __init__(self, axis: Union[tuple[int, ...], int]) -> None: - """Squish Dimension of Data + """Squeeze Dimension of Data Args: axis (Union[tuple[int, ...], int]): diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/__init__.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/__init__.py index 0f06e0cb..f240e6a2 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/__init__.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/__init__.py @@ -22,7 +22,7 @@ | filters | Filter data when iterating | `DropAnyNan`, `DropAllNan`, `DropValue`, `Shape` | | join | Combine tuples of `np.ndarrays` | `Stack`, `VStack`, `HStack`, `Concatenate` | | normalisation | Normalise arrays | `Anomaly`, `Deviation`, `Division`, `Evaluated` | -| reshape | Reshape numpy array | `Rearrange`, `Squish`, `Expand`, `Flatten`, `SwapAxis` | +| reshape | Reshape numpy array | `Rearrange`, `Squeeze`, `Expand`, `Flatten`, `SwapAxis` | | select | Select elements from array | `Select`, `Slice` | | split | Split numpy arrays into tuples | `OnAxis`, `OnSlice`, `VSplit`, `HSplit` | | values | Modify values of arrays | `FillNan`, `MaskValue`, `ForceNormalised` | diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index 794b388e..328e5de0 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -94,21 +94,21 @@ def undo_func(self, data: np.ndarray): return self._rearrange(data, pattern) -class Squish(Operation): +class Squeeze(Operation): """ - Operation to Squish one Dimensional axis at 'axis' location + Operation to Squeeze one-Dimensional axes at 'axis' location """ _override_interface = ["Delayed", "Serial"] - _interface_kwargs = {"Delayed": {"name": "Squish"}} + _interface_kwargs = {"Delayed": {"name": "Squeeze"}} def __init__(self, axis: Union[tuple[int, ...], int]) -> None: - """Squish Dimension of Data + """Squeeze Dimension of Data, removing dimensions of length 1. Args: axis (Union[tuple[int, ...], int]): - Axis to squish at + Axis to squeeze at """ super().__init__( split_tuples=True, From d64da90f7000d42188536d430b5ece0202862b32 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Mon, 5 May 2025 13:54:00 +1200 Subject: [PATCH 02/13] remove early return line that was pre-empting the skip functionality. --- .../src/pyearthtools/pipeline/operations/numpy/reshape.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index 328e5de0..f855d872 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -69,7 +69,6 @@ def __init__( self.skip = skip def _rearrange(self, data: np.ndarray, pattern: str, catch=True): - return einops.rearrange(data, pattern, **self.rearrange_kwargs) try: return einops.rearrange(data, pattern, **self.rearrange_kwargs) From 3998f1a9b22b496ae9c028783fe36a79b33c5b85 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 7 May 2025 13:01:26 +1200 Subject: [PATCH 03/13] tests for Rearrange and Squeeze --- .../operations/numpy/test_numpy_reshape.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_reshape.py diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py new file mode 100644 index 00000000..a5ff9ab4 --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -0,0 +1,84 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import reshape + +import numpy as np +import pytest + +def test_Rearrange(): + r = reshape.Rearrange('h l w -> h w l') + h_dim = 2 + l_dim = 10 + w_dim = 20 + random_array = np.random.randn(h_dim, l_dim, w_dim) + output = r.apply_func(random_array) + undo_output = r.undo_func(output) + + assert output.shape == (h_dim, w_dim, l_dim), "Check dimensions rearranged correctly." + assert np.all(undo_output.shape == random_array.shape), "Check undo successfully reverses." + +def test_Rearrange_explicit_reverse(): + """The undo can be detected automatically or given explicitly. This version tests what happens when it is + given explicitly.""" + r = reshape.Rearrange('h l w -> l w h', reverse_rearrange='l w h -> h l w') + h_dim = 1 + l_dim = 12 + w_dim = 6 + random_array = np.random.randn(h_dim, l_dim, w_dim) + output = r.apply_func(random_array) + undo_output = r.undo_func(output) + + assert np.all(undo_output == random_array), "Check explicit undo successfully reverses." + +def test_Rearrange_skip(): + """Check that the operation can be skipped, if the skip flag is True.""" + r = reshape.Rearrange('h l w -> l w h', skip=True) + h_dim = 1 + l_dim = 12 + wrong_shape_array = np.random.randn(h_dim, l_dim) + output = r.apply_func(wrong_shape_array) + + assert np.all(output == wrong_shape_array), "Check skip can leave array unchanged." + +def test_Rearrange_not_skip(): + """Check that the operation is not skipped, if the skip flag is not set to True.""" + r = reshape.Rearrange('h l w -> l w h') + h_dim = 1 + l_dim = 12 + wrong_shape_array = np.random.randn(h_dim, l_dim) + with pytest.raises(Exception): + r.apply_func(wrong_shape_array) + + +def test_Squeeze(): + s = reshape.Squeeze(axis=(2, 3)) + random_array = np.random.randn(8, 8, 1, 1, 2, 1) + assert s.apply_func(random_array).shape == (8, 8, 2, 1), "Squeeze only the correct axes." + +def test_undo_Squeeze(): + s = reshape.Squeeze(axis=(2, 3)) + random_array = np.random.randn(8, 8, 1, 1, 2, 1) + output = s.apply_func(random_array) + undo_output = s.undo_func(output) + assert random_array.shape == undo_output.shape, "Check Squeeze can correctly undo itself." + +def test_Squeeze_error(): + """Check we get an error if we try to squeeze an axis not of length 1.""" + s = reshape.Squeeze(axis=(1, 3)) # Note axis 1, below, is not of length 1. + random_array = np.random.randn(8, 8, 1, 1, 2, 1) + with pytest.raises(Exception): + s.apply_func(random_array) + + From db489e696f5e967d4757d326b5a6e4ba9e98c77d Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Thu, 8 May 2025 09:19:38 +1200 Subject: [PATCH 04/13] bug detected in reshape --- .../operations/numpy/test_numpy_reshape.py | 58 +++++++++++++++---- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index a5ff9ab4..c3817e70 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -53,7 +53,7 @@ def test_Rearrange_skip(): assert np.all(output == wrong_shape_array), "Check skip can leave array unchanged." def test_Rearrange_not_skip(): - """Check that the operation is not skipped, if the skip flag is not set to True.""" + """Check that the operation can raise an error, if the skip flag is not set to True.""" r = reshape.Rearrange('h l w -> l w h') h_dim = 1 l_dim = 12 @@ -61,24 +61,58 @@ def test_Rearrange_not_skip(): with pytest.raises(Exception): r.apply_func(wrong_shape_array) - def test_Squeeze(): - s = reshape.Squeeze(axis=(2, 3)) - random_array = np.random.randn(8, 8, 1, 1, 2, 1) - assert s.apply_func(random_array).shape == (8, 8, 2, 1), "Squeeze only the correct axes." - -def test_undo_Squeeze(): s = reshape.Squeeze(axis=(2, 3)) random_array = np.random.randn(8, 8, 1, 1, 2, 1) output = s.apply_func(random_array) undo_output = s.undo_func(output) + assert output.shape == (8, 8, 2, 1), "Squeeze only the correct axes." assert random_array.shape == undo_output.shape, "Check Squeeze can correctly undo itself." + with pytest.raises(Exception): + s.apply_func(output) # Output doesn't have the correct axes of length 1, so we get an error. -def test_Squeeze_error(): - """Check we get an error if we try to squeeze an axis not of length 1.""" - s = reshape.Squeeze(axis=(1, 3)) # Note axis 1, below, is not of length 1. - random_array = np.random.randn(8, 8, 1, 1, 2, 1) + +def test_Expand(): + e = reshape.Expand(axis=(0, 2)) + random_array = np.random.randn(4, 3, 5) + output = e.apply_func(random_array) + undo_output = e.undo_func(output) + assert output.shape == (1, 4, 1, 3, 5), "Expand the correct axes." + assert undo_output.shape == random_array.shape, "Expand can undo itself." with pytest.raises(Exception): - s.apply_func(random_array) + e.undo_func(random_array) + +def test_Squeeze_reverses_Expand(): + e = reshape.Expand(axis=(0, 2)) + s = reshape.Squeeze(axis=(0, 2)) + random_array = np.random.randn(4, 3, 5) + expand_output = e.apply_func(random_array) + squeeze_output = s.apply_func(expand_output) + assert squeeze_output.shape == random_array.shape, "Squeeze reverses Expand." + + +def test_Flattener(): + f = reshape.Flattener() + random_array = np.random.randn(4, 3, 5) + output = f.apply(random_array) + assert len(output.shape) == 1, "Flattener produces a 1D array." + +def test_Flatten(): + f1 = reshape.Flatten(flatten_dims=2) + random_array = np.random.randn(4, 3, 5) + output = f1.apply_func(random_array) + undo_output = f1.undo_func(output) + assert output.shape == (4, 3*5), "Flatten acts on the last few dimensions." + assert np.all(undo_output == random_array), "Flatten can undo itself." + + f2 = reshape.Flatten(flatten_dims=1) + random_array = np.random.randn(4, 3, 5) + output = f2.apply_func(random_array) + assert np.all(output == random_array), "Flatten 1 dimension does nothing." + + f3 = reshape.Flatten() + random_array3 = np.random.randn(6, 7, 5, 2) + output = f3.apply_func(random_array3) + assert f3.undo_func(output).shape == (6, 7, 5, 2), "Undo Flatten all dimensions." From efa96d9973d123a927ea34b2aabd494af9ed4bbe Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Thu, 8 May 2025 10:54:07 +1200 Subject: [PATCH 05/13] bug detected in reshape --- packages/pipeline/tests/operations/numpy/test_numpy_reshape.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index c3817e70..fda12d20 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -95,6 +95,7 @@ def test_Flattener(): f = reshape.Flattener() random_array = np.random.randn(4, 3, 5) output = f.apply(random_array) + undo_output = f.undo_func(output) assert len(output.shape) == 1, "Flattener produces a 1D array." def test_Flatten(): From 7b5f74baf73b3e7d863ee38e0f5d335017840d05 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Thu, 8 May 2025 16:36:54 +1200 Subject: [PATCH 06/13] fixed bug in Flattener.undo --- .../src/pyearthtools/pipeline/operations/numpy/reshape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index f855d872..10b9d60e 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -230,7 +230,7 @@ def _unflatten(data, shape): raise RuntimeError(f"`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.") data_shape = data.shape - parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else data_shape + parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else [] attempts = [ (*parsed_shape, *self._unflattenshape), ] From acfabf247ee798daa4ed8d088e725b5922bc988c Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 21 May 2025 13:16:49 +1200 Subject: [PATCH 07/13] fix ability to undo trivial flatten --- .../src/pyearthtools/pipeline/operations/numpy/reshape.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index 10b9d60e..67f42a49 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -223,14 +223,15 @@ def undo(self, data: np.ndarray) -> np.ndarray: def _unflatten(data, shape): while len(data.shape) > len(shape): - shape = (data[-len(shape)], *shape) + shape = (data.shape[-len(shape)], *shape) return data.reshape(shape) if self.flatten_dims is None: raise RuntimeError(f"`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.") data_shape = data.shape - parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else [] + # parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else [] + parsed_shape = data_shape[: -1] if len(data_shape) > 1 else [] attempts = [ (*parsed_shape, *self._unflattenshape), ] From ae4bf05d8804ddabbbf75323d27a908d1f0f8192 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 21 May 2025 14:50:23 +1200 Subject: [PATCH 08/13] test swapaxis --- .../operations/numpy/test_numpy_reshape.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index fda12d20..5e38b1b0 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -95,8 +95,17 @@ def test_Flattener(): f = reshape.Flattener() random_array = np.random.randn(4, 3, 5) output = f.apply(random_array) - undo_output = f.undo_func(output) + undo_output = f.undo(output) assert len(output.shape) == 1, "Flattener produces a 1D array." + assert np.all(undo_output == random_array), "Flattener can undo itself." + +def test_Flattener_1_dim(): + f2 = reshape.Flattener(flatten_dims=1) + random_array = np.random.randn(4, 3, 5) + output = f2.apply(random_array) + undo_output = f2.undo(output) # Check that the undo still works. + assert np.all(output == random_array), "Flatten 1 dimension does nothing." + assert np.all(undo_output == random_array), "Undo Flatten 1 dimension." def test_Flatten(): f1 = reshape.Flatten(flatten_dims=2) @@ -106,14 +115,26 @@ def test_Flatten(): assert output.shape == (4, 3*5), "Flatten acts on the last few dimensions." assert np.all(undo_output == random_array), "Flatten can undo itself." +def test_Flatten_1_dim(): f2 = reshape.Flatten(flatten_dims=1) random_array = np.random.randn(4, 3, 5) output = f2.apply_func(random_array) + undo_output = f2.undo_func(output) # Check that the undo still works. assert np.all(output == random_array), "Flatten 1 dimension does nothing." + assert np.all(undo_output == random_array), "Undo Flatten 1 dimension." +def test_Flatten_all_dims(): f3 = reshape.Flatten() random_array3 = np.random.randn(6, 7, 5, 2) output = f3.apply_func(random_array3) + assert output.shape == (6*7*5*2,) assert f3.undo_func(output).shape == (6, 7, 5, 2), "Undo Flatten all dimensions." +def test_SwapAxis(): + s = reshape.SwapAxis(1, 3) + random_array = np.random.randn(5, 7, 8, 2) + output = s.apply_func(random_array) + assert output.shape == (5, 2, 8, 7), "Swap axes 1 and 3" + undo_output = s.undo_func(output) + assert np.all(undo_output == random_array), "Undo axis swap." From bab31ada7699264dda2252f4a231f3c6868b9821 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 21 May 2025 15:02:48 +1200 Subject: [PATCH 09/13] test swapaxis --- .../pipeline/tests/operations/numpy/test_numpy_reshape.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index 5e38b1b0..f6faa739 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -130,6 +130,11 @@ def test_Flatten_all_dims(): assert output.shape == (6*7*5*2,) assert f3.undo_func(output).shape == (6, 7, 5, 2), "Undo Flatten all dimensions." +def test_Flatten_with_shape_attempt(): + f = reshape.Flatten(shape_attempt = (2, 1, 1, 1)) + undo_data = np.zeros((2)) + assert f.undo_func(undo_data).shape == (2, 1, 1, 1) + def test_SwapAxis(): s = reshape.SwapAxis(1, 3) random_array = np.random.randn(5, 7, 8, 2) From e5d5aeaa44dc6e33e7878f8e26dfdafd2da4e9eb Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 21 May 2025 15:15:54 +1200 Subject: [PATCH 10/13] test shape_attempt --- .../tests/operations/numpy/test_numpy_reshape.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index f6faa739..ab541ac9 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -131,7 +131,16 @@ def test_Flatten_all_dims(): assert f3.undo_func(output).shape == (6, 7, 5, 2), "Undo Flatten all dimensions." def test_Flatten_with_shape_attempt(): + incoming_data = np.zeros((8, 1, 3, 3)) f = reshape.Flatten(shape_attempt = (2, 1, 1, 1)) + f.apply_func(incoming_data) + undo_data = np.zeros((2)) + assert f.undo_func(undo_data).shape == (2, 1, 1, 1) + +def test_Flatten_with_shape_attempt_with_ellipses(): + incoming_data = np.zeros((8, 1, 3, 3)) + f = reshape.Flatten(shape_attempt = (2, '...', 1, 1)) + f.apply_func(incoming_data) undo_data = np.zeros((2)) assert f.undo_func(undo_data).shape == (2, 1, 1, 1) From 4a600a66d48666b01f7472975664784fb7af4ce6 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Wed, 21 May 2025 15:16:14 +1200 Subject: [PATCH 11/13] tests complete, final changes --- .../src/pyearthtools/pipeline/operations/numpy/reshape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index 67f42a49..a82017a0 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -222,8 +222,8 @@ def undo(self, data: np.ndarray) -> np.ndarray: raise RuntimeError(f"Shape not set, therefore cannot undo") def _unflatten(data, shape): - while len(data.shape) > len(shape): - shape = (data.shape[-len(shape)], *shape) + # while len(data.shape) > len(shape): + # shape = (data[-len(shape)], *shape) return data.reshape(shape) if self.flatten_dims is None: From ce268a0b32fea037373603a7db8fc4ca9cfcb473 Mon Sep 17 00:00:00 2001 From: Gemma Mason Date: Thu, 22 May 2025 09:27:40 +1200 Subject: [PATCH 12/13] remove extraneous brackets --- .../pipeline/tests/operations/numpy/test_numpy_reshape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index ab541ac9..e02201bd 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -134,14 +134,14 @@ def test_Flatten_with_shape_attempt(): incoming_data = np.zeros((8, 1, 3, 3)) f = reshape.Flatten(shape_attempt = (2, 1, 1, 1)) f.apply_func(incoming_data) - undo_data = np.zeros((2)) + undo_data = np.zeros(2) assert f.undo_func(undo_data).shape == (2, 1, 1, 1) def test_Flatten_with_shape_attempt_with_ellipses(): incoming_data = np.zeros((8, 1, 3, 3)) f = reshape.Flatten(shape_attempt = (2, '...', 1, 1)) f.apply_func(incoming_data) - undo_data = np.zeros((2)) + undo_data = np.zeros(2) assert f.undo_func(undo_data).shape == (2, 1, 1, 1) def test_SwapAxis(): From 991bb97a15a6ac237dbb6678e000a57dfd4693b2 Mon Sep 17 00:00:00 2001 From: Tennessee Leeuwenburg Date: Thu, 22 May 2025 20:16:40 +1000 Subject: [PATCH 13/13] Apply automated code reformatting --- .../pipeline/operations/numpy/reshape.py | 2 +- .../operations/numpy/test_numpy_reshape.py | 36 ++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py index a82017a0..4eed6832 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/reshape.py @@ -231,7 +231,7 @@ def _unflatten(data, shape): data_shape = data.shape # parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else [] - parsed_shape = data_shape[: -1] if len(data_shape) > 1 else [] + parsed_shape = data_shape[:-1] if len(data_shape) > 1 else [] attempts = [ (*parsed_shape, *self._unflattenshape), ] diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py index e02201bd..d788ec67 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_reshape.py @@ -17,8 +17,9 @@ import numpy as np import pytest + def test_Rearrange(): - r = reshape.Rearrange('h l w -> h w l') + r = reshape.Rearrange("h l w -> h w l") h_dim = 2 l_dim = 10 w_dim = 20 @@ -29,10 +30,11 @@ def test_Rearrange(): assert output.shape == (h_dim, w_dim, l_dim), "Check dimensions rearranged correctly." assert np.all(undo_output.shape == random_array.shape), "Check undo successfully reverses." + def test_Rearrange_explicit_reverse(): """The undo can be detected automatically or given explicitly. This version tests what happens when it is given explicitly.""" - r = reshape.Rearrange('h l w -> l w h', reverse_rearrange='l w h -> h l w') + r = reshape.Rearrange("h l w -> l w h", reverse_rearrange="l w h -> h l w") h_dim = 1 l_dim = 12 w_dim = 6 @@ -42,9 +44,10 @@ def test_Rearrange_explicit_reverse(): assert np.all(undo_output == random_array), "Check explicit undo successfully reverses." + def test_Rearrange_skip(): """Check that the operation can be skipped, if the skip flag is True.""" - r = reshape.Rearrange('h l w -> l w h', skip=True) + r = reshape.Rearrange("h l w -> l w h", skip=True) h_dim = 1 l_dim = 12 wrong_shape_array = np.random.randn(h_dim, l_dim) @@ -52,15 +55,17 @@ def test_Rearrange_skip(): assert np.all(output == wrong_shape_array), "Check skip can leave array unchanged." + def test_Rearrange_not_skip(): """Check that the operation can raise an error, if the skip flag is not set to True.""" - r = reshape.Rearrange('h l w -> l w h') + r = reshape.Rearrange("h l w -> l w h") h_dim = 1 l_dim = 12 wrong_shape_array = np.random.randn(h_dim, l_dim) with pytest.raises(Exception): r.apply_func(wrong_shape_array) + def test_Squeeze(): s = reshape.Squeeze(axis=(2, 3)) random_array = np.random.randn(8, 8, 1, 1, 2, 1) @@ -69,7 +74,7 @@ def test_Squeeze(): assert output.shape == (8, 8, 2, 1), "Squeeze only the correct axes." assert random_array.shape == undo_output.shape, "Check Squeeze can correctly undo itself." with pytest.raises(Exception): - s.apply_func(output) # Output doesn't have the correct axes of length 1, so we get an error. + s.apply_func(output) # Output doesn't have the correct axes of length 1, so we get an error. def test_Expand(): @@ -82,6 +87,7 @@ def test_Expand(): with pytest.raises(Exception): e.undo_func(random_array) + def test_Squeeze_reverses_Expand(): e = reshape.Expand(axis=(0, 2)) s = reshape.Squeeze(axis=(0, 2)) @@ -99,51 +105,58 @@ def test_Flattener(): assert len(output.shape) == 1, "Flattener produces a 1D array." assert np.all(undo_output == random_array), "Flattener can undo itself." + def test_Flattener_1_dim(): f2 = reshape.Flattener(flatten_dims=1) random_array = np.random.randn(4, 3, 5) output = f2.apply(random_array) - undo_output = f2.undo(output) # Check that the undo still works. + undo_output = f2.undo(output) # Check that the undo still works. assert np.all(output == random_array), "Flatten 1 dimension does nothing." assert np.all(undo_output == random_array), "Undo Flatten 1 dimension." + def test_Flatten(): f1 = reshape.Flatten(flatten_dims=2) random_array = np.random.randn(4, 3, 5) output = f1.apply_func(random_array) undo_output = f1.undo_func(output) - assert output.shape == (4, 3*5), "Flatten acts on the last few dimensions." + assert output.shape == (4, 3 * 5), "Flatten acts on the last few dimensions." assert np.all(undo_output == random_array), "Flatten can undo itself." + def test_Flatten_1_dim(): f2 = reshape.Flatten(flatten_dims=1) random_array = np.random.randn(4, 3, 5) output = f2.apply_func(random_array) - undo_output = f2.undo_func(output) # Check that the undo still works. + undo_output = f2.undo_func(output) # Check that the undo still works. assert np.all(output == random_array), "Flatten 1 dimension does nothing." assert np.all(undo_output == random_array), "Undo Flatten 1 dimension." + def test_Flatten_all_dims(): f3 = reshape.Flatten() random_array3 = np.random.randn(6, 7, 5, 2) output = f3.apply_func(random_array3) - assert output.shape == (6*7*5*2,) + assert output.shape == (6 * 7 * 5 * 2,) assert f3.undo_func(output).shape == (6, 7, 5, 2), "Undo Flatten all dimensions." + def test_Flatten_with_shape_attempt(): incoming_data = np.zeros((8, 1, 3, 3)) - f = reshape.Flatten(shape_attempt = (2, 1, 1, 1)) + f = reshape.Flatten(shape_attempt=(2, 1, 1, 1)) f.apply_func(incoming_data) undo_data = np.zeros(2) assert f.undo_func(undo_data).shape == (2, 1, 1, 1) + def test_Flatten_with_shape_attempt_with_ellipses(): incoming_data = np.zeros((8, 1, 3, 3)) - f = reshape.Flatten(shape_attempt = (2, '...', 1, 1)) + f = reshape.Flatten(shape_attempt=(2, "...", 1, 1)) f.apply_func(incoming_data) undo_data = np.zeros(2) assert f.undo_func(undo_data).shape == (2, 1, 1, 1) + def test_SwapAxis(): s = reshape.SwapAxis(1, 3) random_array = np.random.randn(5, 7, 8, 2) @@ -151,4 +164,3 @@ def test_SwapAxis(): assert output.shape == (5, 2, 8, 7), "Swap axes 1 and 3" undo_output = s.undo_func(output) assert np.all(undo_output == random_array), "Undo axis swap." -