diff --git a/docs/api/pipeline/pipeline_api.md b/docs/api/pipeline/pipeline_api.md index 6de27751..f1ef7d34 100644 --- a/docs/api/pipeline/pipeline_api.md +++ b/docs/api/pipeline/pipeline_api.md @@ -10,6 +10,8 @@ :members: .. autoclass:: pyearthtools.pipeline.Operation :members: +.. autoclass:: pyearthtools.pipeline.ReversedPipeline + :members: .. autoclass:: pyearthtools.pipeline.PipelineException :members: .. autoclass:: pyearthtools.pipeline.PipelineFilterException diff --git a/docs/api/pipeline/pipeline_index.md b/docs/api/pipeline/pipeline_index.md index 5dc215d0..f053df43 100644 --- a/docs/api/pipeline/pipeline_index.md +++ b/docs/api/pipeline/pipeline_index.md @@ -6,9 +6,10 @@ The rest of this page contains reference information for the components of the P | Module | Purpose | API Docs | |----------------------|--------------------------------------|--------------------------------------------------------------------------------------------------------------------| -| `pipeline` | | - [Sampler](pipeline_api.md#pyearthtools.pipeline.Sampler) | +| `pipeline` | | - [Sampler](pipeline_api.md#pyearthtools.pipeline.Sampler) | | | | - [Pipeline](pipeline_api.md#pyearthtools.pipeline.Pipeline) | | | | - [Operation](pipeline_api.md#pyearthtools.pipeline.Operation) | +| | | - [ReversedPipeline](pipeline_api.md#pyearthtools.pipeline.ReversedPipeline) | | | | - [PipelineException](pipeline_api.md#pyearthtools.pipeline.PipelineException) | | | | - [PipelineFilterException](pipeline_api.md#pyearthtools.pipeline.PipelineFilterException) | | | | - [PipelineRuntimeError](pipeline_api.md#pyearthtools.pipeline.PipelineRuntimeError) | diff --git a/notebooks/Gallery.ipynb b/notebooks/Gallery.ipynb index e6b3495d..62b27fed 100644 --- a/notebooks/Gallery.ipynb +++ b/notebooks/Gallery.ipynb @@ -115,7 +115,8 @@ "| Basics | Introduction to what a pipeline is (essential reading) | [Pipeline Basics](./pipeline/Basics.ipynb) | 18 Aug 2025 |\n", "| Operations | Introduction to pipeline operations | [Pipeline Operations](./pipeline/Operations.ipynb) | 18 Aug 2025 |\n", "| Modifications | Introduction to pipeline modifications | [Pipeline Modifications](./pipeline/Modifications.ipynb) | 22 Aug 2025 |\n", - "| Branching | -- | [Pipeline Branching](./pipeline/Branching.ipynb) | 18 Aug 2025 |\n" + "| Branching | -- | [Pipeline Branching](./pipeline/Branching.ipynb) | 18 Aug 2025 |\n", + "| Patterns | Recommended design patterns for pipelines | [Additional Pipeline Syntaxes](./pipeline/Patterns.ipynb) | 21 Oct 2025 |\n" ] } ], @@ -135,7 +136,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.13.7" } }, "nbformat": 4, diff --git a/notebooks/pipeline/Patterns.ipynb b/notebooks/pipeline/Patterns.ipynb new file mode 100644 index 00000000..c543132b --- /dev/null +++ b/notebooks/pipeline/Patterns.ipynb @@ -0,0 +1,6054 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "63287ab4-74ec-4af2-9d70-5ed866ff1798", + "metadata": {}, + "source": [ + "# Additional Pipeline Syntaxes\n", + "\n", + "This notebooks introduces syntaxes to ease creation and manipulation of `Pipeline`objects:\n", + "\n", + "- named pipelines,\n", + "- combination using `|` operation,\n", + "- reversing pipelines." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cbe5ebfb-f820-4b77-ad0f-91d507a5195f", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def repr_ndarray(arr):\n", + " return f\"array(..., shape={arr.shape}, dtype={arr.dtype})\"\n", + "\n", + "np.set_printoptions(override_repr=repr_ndarray)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "160c063a-de2e-4a28-a22d-d156d4bc06f3", + "metadata": {}, + "outputs": [], + "source": [ + "import pyearthtools.data\n", + "import pyearthtools.pipeline" + ] + }, + { + "cell_type": "markdown", + "id": "964e9683-2999-4a8b-8a1f-407dc24d4ffa", + "metadata": {}, + "source": [ + "To illustrate these features, we'll reuse the same pipeline as the one used in the [End-to-end CNN Training Example](../tutorial/CNN-Model-Training.ipynb).\n", + "Here is the original definition of the pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0b273573-99d1-4c8e-834d-7c0183674774", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t name                           None\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t weatherbench.WB2ERA5           {'WB2ERA5': {'level': '[850]', 'license_ok': 'True', 'resolution': "'64x32'", 'variables': "['2m_temperature', 'u', 'v', 'geopotential', 'vorticity']"}}\n",
+       "\t\t sort.Sort                      {'Sort': {'order': "['2m_temperature', 'u_component_of_wind', 'v_component_of_wind', 'vorticity', 'geopotential']", 'strict': 'False'}}\n",
+       "\t\t coordinates.StandardLongitude  {'StandardLongitude': {'longitude_name': "'longitude'", 'type': "'0-360'"}}\n",
+       "\t\t reshape.CoordinateFlatten      {'CoordinateFlatten': {'coordinate': "['level']", 'skip_missing': 'False'}}\n",
+       "\t\t idx_modification.TemporalRetrieval {'TemporalRetrieval': {'concat': 'True', 'delta_unit': 'None', 'merge_function': 'None', 'merge_kwargs': 'None', 'samples': '((0, 1), (6, 1))'}}\n",
+       "\t\t conversion.ToNumpy             {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t reshape.Rearrange              {'Rearrange': {'rearrange': "'c t h w -> t c h w'", 'rearrange_kwargs': 'None', 'reverse_rearrange': 'None', 'skip': 'False'}}\n",
+       "\t\t reshape.Squeeze                {'Squeeze': {'axis': '0'}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "WB2ERA5_477165bc-eb0a-4338-a48f-58fdc18ecec1\n", + "\n", + "weatherbench.WB2ERA5\n", + "\n", + "\n", + "\n", + "Sort_ba32b738-273a-43aa-9233-6ba1e5a967ed\n", + "\n", + "sort.Sort\n", + "\n", + "\n", + "\n", + "WB2ERA5_477165bc-eb0a-4338-a48f-58fdc18ecec1->Sort_ba32b738-273a-43aa-9233-6ba1e5a967ed\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "StandardLongitude_d612dd7e-48a0-426d-971f-d7cddbe5fbeb\n", + "\n", + "coordinates.StandardLongitude\n", + "\n", + "\n", + "\n", + "Sort_ba32b738-273a-43aa-9233-6ba1e5a967ed->StandardLongitude_d612dd7e-48a0-426d-971f-d7cddbe5fbeb\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_a0b91502-2203-48cc-9b91-16771be5029a\n", + "\n", + "reshape.CoordinateFlatten\n", + "\n", + "\n", + "\n", + "StandardLongitude_d612dd7e-48a0-426d-971f-d7cddbe5fbeb->CoordinateFlatten_a0b91502-2203-48cc-9b91-16771be5029a\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_e5df41dc-5e69-4515-806b-737d6862c177\n", + "\n", + "idx_modification.TemporalRetrieval\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_a0b91502-2203-48cc-9b91-16771be5029a->TemporalRetrieval_e5df41dc-5e69-4515-806b-737d6862c177\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_663e9c82-aab3-4aec-ae87-a947704b0014\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_e5df41dc-5e69-4515-806b-737d6862c177->ToNumpy_663e9c82-aab3-4aec-ae87-a947704b0014\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Rearrange_4ccf8752-7b2f-481b-a763-845bfceaa157\n", + "\n", + "reshape.Rearrange\n", + "\n", + "\n", + "\n", + "ToNumpy_663e9c82-aab3-4aec-ae87-a947704b0014->Rearrange_4ccf8752-7b2f-481b-a763-845bfceaa157\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Squeeze_5a8bf3aa-24c0-4c37-87b9-7e06b2a56ae6\n", + "\n", + "reshape.Squeeze\n", + "\n", + "\n", + "\n", + "Rearrange_4ccf8752-7b2f-481b-a763-845bfceaa157->Squeeze_5a8bf3aa-24c0-4c37-87b9-7e06b2a56ae6\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pyearthtools.pipeline.Pipeline(\n", + " pyearthtools.data.download.weatherbench.WB2ERA5(\n", + " variables=[\"2m_temperature\", \"u\", \"v\", \"geopotential\", \"vorticity\"],\n", + " level=[850],\n", + " license_ok=True,\n", + " ),\n", + " pyearthtools.pipeline.operations.xarray.Sort(\n", + " [\"2m_temperature\", \"u_component_of_wind\", \"v_component_of_wind\", \"vorticity\", \"geopotential\"]\n", + " ),\n", + " pyearthtools.data.transforms.coordinates.StandardLongitude(type=\"0-360\"),\n", + " pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten([\"level\"]),\n", + " pyearthtools.pipeline.modifications.TemporalRetrieval(\n", + " concat=True, samples=((0, 1), (6, 1))\n", + " ),\n", + " pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Rearrange(\"c t h w -> t c h w\"),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Squeeze(axis=0),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b60a86ed-d610-410f-8b03-aa535c9ba99a", + "metadata": {}, + "source": [ + "## Named pipelines\n", + "\n", + "When developing a new pipeline, it can be convenient to separate the main stages of a long pipeline into these sub-pipelines, and assemble them into one big pipeline afterwards.\n", + "However, once the pipeline has been assembled, we loose access to the sub-pipelines.\n", + "To solve this, we can add a **name** to each of the sub-pipelines.\n", + "Then, in the final pipeline, we can recover them via the `.named` attribute, which is a dictionary of all the named sub-pipelines contained in a pipeline.\n", + "\n", + "In the following example, we build the same pipeline but split into 3 stages:\n", + "- a named pipeline \"prepare\", to fetch the data and apply few transformation on it,\n", + "- a temporal retrieval step, to generate the tuple of (features, target) samples,\n", + "- a named pipeline \"reshape\", to do the final convertion to numpy and reshaping." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9882f56e-cf55-4938-9bd8-0c6b3240541e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t name                           None\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t weatherbench.WB2ERA5           {'WB2ERA5': {'level': '[850]', 'license_ok': 'True', 'resolution': "'64x32'", 'variables': "['2m_temperature', 'u', 'v', 'geopotential', 'vorticity']"}}\n",
+       "\t\t sort.Sort                      {'Sort': {'order': "['2m_temperature', 'u_component_of_wind', 'v_component_of_wind', 'vorticity', 'geopotential']", 'strict': 'False'}}\n",
+       "\t\t coordinates.StandardLongitude  {'StandardLongitude': {'longitude_name': "'longitude'", 'type': "'0-360'"}}\n",
+       "\t\t reshape.CoordinateFlatten      {'CoordinateFlatten': {'coordinate': "['level']", 'skip_missing': 'False'}}\n",
+       "\t\t idx_modification.TemporalRetrieval {'TemporalRetrieval': {'concat': 'True', 'delta_unit': 'None', 'merge_function': 'None', 'merge_kwargs': 'None', 'samples': '((0, 1), (6, 1))'}}\n",
+       "\t\t conversion.ToNumpy             {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t reshape.Rearrange              {'Rearrange': {'rearrange': "'c t h w -> t c h w'", 'rearrange_kwargs': 'None', 'reverse_rearrange': 'None', 'skip': 'False'}}\n",
+       "\t\t reshape.Squeeze                {'Squeeze': {'axis': '0'}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "WB2ERA5_1db05602-c0b7-419f-b6f3-8f32405151ac\n", + "\n", + "weatherbench.WB2ERA5\n", + "\n", + "\n", + "\n", + "Sort_2d1b6be0-d9ec-4a44-94e1-20aa45479ec3\n", + "\n", + "sort.Sort\n", + "\n", + "\n", + "\n", + "WB2ERA5_1db05602-c0b7-419f-b6f3-8f32405151ac->Sort_2d1b6be0-d9ec-4a44-94e1-20aa45479ec3\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "StandardLongitude_3ad68fc7-c29d-427d-9d4c-990dbdeb272d\n", + "\n", + "coordinates.StandardLongitude\n", + "\n", + "\n", + "\n", + "Sort_2d1b6be0-d9ec-4a44-94e1-20aa45479ec3->StandardLongitude_3ad68fc7-c29d-427d-9d4c-990dbdeb272d\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_6c0a2ef7-2b79-4cb0-a232-7496e31bf682\n", + "\n", + "reshape.CoordinateFlatten\n", + "\n", + "\n", + "\n", + "StandardLongitude_3ad68fc7-c29d-427d-9d4c-990dbdeb272d->CoordinateFlatten_6c0a2ef7-2b79-4cb0-a232-7496e31bf682\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_fe44d7c0-23b9-48bc-8833-de47a7d72c01\n", + "\n", + "idx_modification.TemporalRetrieval\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_6c0a2ef7-2b79-4cb0-a232-7496e31bf682->TemporalRetrieval_fe44d7c0-23b9-48bc-8833-de47a7d72c01\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_3a4fc0b9-60b3-4a45-a5e2-df05fbd91a39\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_fe44d7c0-23b9-48bc-8833-de47a7d72c01->ToNumpy_3a4fc0b9-60b3-4a45-a5e2-df05fbd91a39\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Rearrange_333f3ef5-db26-40cf-bd3d-cf18fd0ae1ec\n", + "\n", + "reshape.Rearrange\n", + "\n", + "\n", + "\n", + "ToNumpy_3a4fc0b9-60b3-4a45-a5e2-df05fbd91a39->Rearrange_333f3ef5-db26-40cf-bd3d-cf18fd0ae1ec\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Squeeze_b7533bed-0271-4494-831f-d406b3fcd9a4\n", + "\n", + "reshape.Squeeze\n", + "\n", + "\n", + "\n", + "Rearrange_333f3ef5-db26-40cf-bd3d-cf18fd0ae1ec->Squeeze_b7533bed-0271-4494-831f-d406b3fcd9a4\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pipeline = pyearthtools.pipeline.Pipeline(\n", + " pyearthtools.pipeline.Pipeline(\n", + " pyearthtools.data.download.weatherbench.WB2ERA5(\n", + " variables=[\"2m_temperature\", \"u\", \"v\", \"geopotential\", \"vorticity\"],\n", + " level=[850],\n", + " license_ok=True,\n", + " ),\n", + " pyearthtools.pipeline.operations.xarray.Sort(\n", + " [\"2m_temperature\", \"u_component_of_wind\", \"v_component_of_wind\", \"vorticity\", \"geopotential\"]\n", + " ),\n", + " pyearthtools.data.transforms.coordinates.StandardLongitude(type=\"0-360\"),\n", + " pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten([\"level\"]),\n", + " name=\"prepare\"\n", + " ),\n", + " pyearthtools.pipeline.modifications.TemporalRetrieval(concat=True, samples=((0, 1), (6, 1))),\n", + " pyearthtools.pipeline.Pipeline(\n", + " pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Rearrange(\"c t h w -> t c h w\"),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Squeeze(axis=0),\n", + " name=\"reshape\"\n", + " ),\n", + ")\n", + "pipeline" + ] + }, + { + "cell_type": "markdown", + "id": "a341cbf7-65a6-4ee1-a263-4b3b2b7f703f", + "metadata": {}, + "source": [ + "We can inspect the `.named` attribute to see which named pipelines are accessible within a pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "496c427e-37cf-45a9-8424-4bc0eb594997", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['prepare', 'reshape'])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.named.keys()" + ] + }, + { + "cell_type": "markdown", + "id": "a0598e50-0f1a-409d-9bab-a131fcaaa9f9", + "metadata": {}, + "source": [ + "Then we can access the named pipeline \"prepare\" as follows." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fcb727e8-274f-49e7-b743-386fe665af49", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t name                           'prepare'\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t weatherbench.WB2ERA5           {'WB2ERA5': {'level': '[850]', 'license_ok': 'True', 'resolution': "'64x32'", 'variables': "['2m_temperature', 'u', 'v', 'geopotential', 'vorticity']"}}\n",
+       "\t\t sort.Sort                      {'Sort': {'order': "['2m_temperature', 'u_component_of_wind', 'v_component_of_wind', 'vorticity', 'geopotential']", 'strict': 'False'}}\n",
+       "\t\t coordinates.StandardLongitude  {'StandardLongitude': {'longitude_name': "'longitude'", 'type': "'0-360'"}}\n",
+       "\t\t reshape.CoordinateFlatten      {'CoordinateFlatten': {'coordinate': "['level']", 'skip_missing': 'False'}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "WB2ERA5_880a48cc-25b0-4e7d-96ee-c7dda4ec248b\n", + "\n", + "weatherbench.WB2ERA5\n", + "\n", + "\n", + "\n", + "Sort_033ec7ad-0d71-4b89-8859-78f0e01d555e\n", + "\n", + "sort.Sort\n", + "\n", + "\n", + "\n", + "WB2ERA5_880a48cc-25b0-4e7d-96ee-c7dda4ec248b->Sort_033ec7ad-0d71-4b89-8859-78f0e01d555e\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "StandardLongitude_6fecb1ce-8bc9-4b59-80ec-9e1d1bdbcaa3\n", + "\n", + "coordinates.StandardLongitude\n", + "\n", + "\n", + "\n", + "Sort_033ec7ad-0d71-4b89-8859-78f0e01d555e->StandardLongitude_6fecb1ce-8bc9-4b59-80ec-9e1d1bdbcaa3\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_fd3022de-60e0-4f64-80c5-88d65a4a01c4\n", + "\n", + "reshape.CoordinateFlatten\n", + "\n", + "\n", + "\n", + "StandardLongitude_6fecb1ce-8bc9-4b59-80ec-9e1d1bdbcaa3->CoordinateFlatten_fd3022de-60e0-4f64-80c5-88d65a4a01c4\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pipeline.named[\"prepare\"]" + ] + }, + { + "cell_type": "markdown", + "id": "bf6e89bb-683d-4790-8e29-b7b1c69532e0", + "metadata": {}, + "source": [ + "And even use it without the rest of the pipeline, as it includes a data source." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "221cd83e-3e38-4763-bccd-54a0e185a378", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 42kB\n",
+       "Dimensions:                 (latitude: 32, longitude: 64, time: 1)\n",
+       "Coordinates:\n",
+       "  * latitude                (latitude) float64 256B -87.19 -81.56 ... 87.19\n",
+       "  * longitude               (longitude) float64 512B 0.0 5.625 ... 348.8 354.4\n",
+       "  * time                    (time) datetime64[ns] 8B 2021-01-01\n",
+       "Data variables:\n",
+       "    2m_temperature          (time, longitude, latitude) float32 8kB dask.array<chunksize=(1, 64, 32), meta=np.ndarray>\n",
+       "    u_component_of_wind850  (time, longitude, latitude) float32 8kB dask.array<chunksize=(1, 64, 32), meta=np.ndarray>\n",
+       "    v_component_of_wind850  (time, longitude, latitude) float32 8kB dask.array<chunksize=(1, 64, 32), meta=np.ndarray>\n",
+       "    vorticity850            (time, longitude, latitude) float32 8kB dask.array<chunksize=(1, 64, 32), meta=np.ndarray>\n",
+       "    geopotential850         (time, longitude, latitude) float32 8kB dask.array<chunksize=(1, 64, 32), meta=np.ndarray>\n",
+       "Attributes:\n",
+       "    level-dtype:  int64
" + ], + "text/plain": [ + " Size: 42kB\n", + "Dimensions: (latitude: 32, longitude: 64, time: 1)\n", + "Coordinates:\n", + " * latitude (latitude) float64 256B -87.19 -81.56 ... 87.19\n", + " * longitude (longitude) float64 512B 0.0 5.625 ... 348.8 354.4\n", + " * time (time) datetime64[ns] 8B 2021-01-01\n", + "Data variables:\n", + " 2m_temperature (time, longitude, latitude) float32 8kB dask.array\n", + " u_component_of_wind850 (time, longitude, latitude) float32 8kB dask.array\n", + " v_component_of_wind850 (time, longitude, latitude) float32 8kB dask.array\n", + " vorticity850 (time, longitude, latitude) float32 8kB dask.array\n", + " geopotential850 (time, longitude, latitude) float32 8kB dask.array\n", + "Attributes:\n", + " level-dtype: int64" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.named[\"prepare\"][\"20210101T00\"]" + ] + }, + { + "cell_type": "markdown", + "id": "1d8275f2-252a-411f-93eb-17541050f726", + "metadata": {}, + "source": [ + "## Pipe operator\n", + "\n", + "The `Pipeline` object also support the `|` operator (logical or) as a way to combine multiple pipelines together.\n", + "This has the same effect as creating a new `Pipeline` object as a combination of 2 pipelines (or a pipeline and a step).\n", + "This has not additional effect and can be used to increase the readability when building a long pipeline.\n", + "\n", + "In the following example, we now create the sub-pipelines and the temporal retrieval step separately, as distinct objects, then combine them using the `|` operator." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "19bcdc11-aeea-4a3d-b702-cad8073ba3f7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t name                           None\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t weatherbench.WB2ERA5           {'WB2ERA5': {'level': '[850]', 'license_ok': 'True', 'resolution': "'64x32'", 'variables': "['2m_temperature', 'u', 'v', 'geopotential', 'vorticity']"}}\n",
+       "\t\t sort.Sort                      {'Sort': {'order': "['2m_temperature', 'u_component_of_wind', 'v_component_of_wind', 'vorticity', 'geopotential']", 'strict': 'False'}}\n",
+       "\t\t coordinates.StandardLongitude  {'StandardLongitude': {'longitude_name': "'longitude'", 'type': "'0-360'"}}\n",
+       "\t\t reshape.CoordinateFlatten      {'CoordinateFlatten': {'coordinate': "['level']", 'skip_missing': 'False'}}\n",
+       "\t\t idx_modification.TemporalRetrieval {'TemporalRetrieval': {'concat': 'True', 'delta_unit': 'None', 'merge_function': 'None', 'merge_kwargs': 'None', 'samples': '((0, 1), (6, 1))'}}\n",
+       "\t\t conversion.ToNumpy             {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t reshape.Rearrange              {'Rearrange': {'rearrange': "'c t h w -> t c h w'", 'rearrange_kwargs': 'None', 'reverse_rearrange': 'None', 'skip': 'False'}}\n",
+       "\t\t reshape.Squeeze                {'Squeeze': {'axis': '0'}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "WB2ERA5_fa1f1141-de86-4dce-95f4-0097fa8a4d2b\n", + "\n", + "weatherbench.WB2ERA5\n", + "\n", + "\n", + "\n", + "Sort_3e930a6f-74f0-4b0d-be6e-5bf12f458581\n", + "\n", + "sort.Sort\n", + "\n", + "\n", + "\n", + "WB2ERA5_fa1f1141-de86-4dce-95f4-0097fa8a4d2b->Sort_3e930a6f-74f0-4b0d-be6e-5bf12f458581\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "StandardLongitude_c912e745-eae3-4c56-bb6c-67f9faa3dee1\n", + "\n", + "coordinates.StandardLongitude\n", + "\n", + "\n", + "\n", + "Sort_3e930a6f-74f0-4b0d-be6e-5bf12f458581->StandardLongitude_c912e745-eae3-4c56-bb6c-67f9faa3dee1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_d148516b-3cbb-43c6-b502-74785f7d257d\n", + "\n", + "reshape.CoordinateFlatten\n", + "\n", + "\n", + "\n", + "StandardLongitude_c912e745-eae3-4c56-bb6c-67f9faa3dee1->CoordinateFlatten_d148516b-3cbb-43c6-b502-74785f7d257d\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_d81026f3-e037-413b-9648-73cff5528321\n", + "\n", + "idx_modification.TemporalRetrieval\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_d148516b-3cbb-43c6-b502-74785f7d257d->TemporalRetrieval_d81026f3-e037-413b-9648-73cff5528321\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_aef30dde-8087-4cb7-b235-7b3cc34a4728\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_d81026f3-e037-413b-9648-73cff5528321->ToNumpy_aef30dde-8087-4cb7-b235-7b3cc34a4728\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Rearrange_c0318ec1-228f-41a9-89b8-221d3e01b014\n", + "\n", + "reshape.Rearrange\n", + "\n", + "\n", + "\n", + "ToNumpy_aef30dde-8087-4cb7-b235-7b3cc34a4728->Rearrange_c0318ec1-228f-41a9-89b8-221d3e01b014\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Squeeze_c4c632a7-34ad-4857-85ac-fc4f7f7a2322\n", + "\n", + "reshape.Squeeze\n", + "\n", + "\n", + "\n", + "Rearrange_c0318ec1-228f-41a9-89b8-221d3e01b014->Squeeze_c4c632a7-34ad-4857-85ac-fc4f7f7a2322\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "prepare = pyearthtools.pipeline.Pipeline(\n", + " pyearthtools.data.download.weatherbench.WB2ERA5(\n", + " variables=[\"2m_temperature\", \"u\", \"v\", \"geopotential\", \"vorticity\"],\n", + " level=[850],\n", + " license_ok=True,\n", + " ),\n", + " pyearthtools.pipeline.operations.xarray.Sort(\n", + " [\"2m_temperature\", \"u_component_of_wind\", \"v_component_of_wind\", \"vorticity\", \"geopotential\"]\n", + " ),\n", + " pyearthtools.data.transforms.coordinates.StandardLongitude(type=\"0-360\"),\n", + " pyearthtools.pipeline.operations.xarray.reshape.CoordinateFlatten([\"level\"]),\n", + " name=\"prepare\"\n", + ")\n", + "\n", + "retrieve = pyearthtools.pipeline.modifications.TemporalRetrieval(concat=True, samples=((0, 1), (6, 1)))\n", + "\n", + "reshape = pyearthtools.pipeline.Pipeline(\n", + " pyearthtools.pipeline.operations.xarray.conversion.ToNumpy(),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Rearrange(\"c t h w -> t c h w\"),\n", + " pyearthtools.pipeline.operations.numpy.reshape.Squeeze(axis=0),\n", + " name=\"reshape\"\n", + ")\n", + "\n", + "pipeline = prepare | retrieve | reshape\n", + "pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "15781661-20ac-4235-aa75-e05184e9d1b4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array(..., shape=(5, 64, 32), dtype=float32),\n", + " array(..., shape=(5, 64, 32), dtype=float32))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline[\"20210101T00\"]" + ] + }, + { + "cell_type": "markdown", + "id": "f97ac27d-1e23-404d-b0ff-a265c6a04530", + "metadata": {}, + "source": [ + "## Reversed pipeline\n", + "\n", + "Pipelines can be reversed, i.e. undoing their effect.\n", + "The reverse of a pipeline can be obtained via the `.reversed` atttribute.\n", + "\n", + "**Important:** Depending on the steps in a pipeline, some might have a proper \"undo\" method but others will just be skipped, i.e. will not apply any change to the sample while undoing the pipeline. In our example, it is the case of the `pyearthtools.data.download.weatherbench.WB2ERA5` and `pyearthtools.pipeline.modifications.TemporalRetrieval` steps.\n", + "\n", + "In the following example, we'll reverse the \"reshape\" sub-pipeline, which is only made of reversible steps." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0397753a-3fbd-453f-8402-7ff1db67745b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t name                           'reshape'\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t conversion.ToNumpy             {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t reshape.Rearrange              {'Rearrange': {'rearrange': "'c t h w -> t c h w'", 'rearrange_kwargs': 'None', 'reverse_rearrange': 'None', 'skip': 'False'}}\n",
+       "\t\t reshape.Squeeze                {'Squeeze': {'axis': '0'}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_0b6249c6-a29c-4706-9d96-ad1ef6a9485d\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "Rearrange_e433afd2-c361-4d56-8ae8-03da25bfc65b\n", + "\n", + "reshape.Rearrange\n", + "\n", + "\n", + "\n", + "ToNumpy_0b6249c6-a29c-4706-9d96-ad1ef6a9485d->Rearrange_e433afd2-c361-4d56-8ae8-03da25bfc65b\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Squeeze_85e44efd-a7bf-4fc2-98b5-7c5c9f642487\n", + "\n", + "reshape.Squeeze\n", + "\n", + "\n", + "\n", + "Rearrange_e433afd2-c361-4d56-8ae8-03da25bfc65b->Squeeze_85e44efd-a7bf-4fc2-98b5-7c5c9f642487\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pipeline.named[\"reshape\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1b202180-6456-4989-99a4-fded7115a6c7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ReversedPipeline\n",
+       "\tInitialisation                 Operation reversing the effect of pipeline\n",
+       "\t\t forward_pipeline               {'Pipeline': {'__args': "(ToNumpy\\n\\tInitialisation                 Convert xarray objects to np.ndarray's\\n\\t\\t reference_dataset              None\\n\\t\\t run_parallel                   False\\n\\t\\t saved_records                  None\\n\\t\\t warn                           True, Rearrange\\n\\tInitialisation                 Operation to rearrange data using einops\\n\\t\\t rearrange                      'c t h w -> t c h w'\\n\\t\\t rearrange_kwargs               None\\n\\t\\t reverse_rearrange              None\\n\\t\\t skip                           False, Squeeze\\n\\tInitialisation                 Operation to Squeeze one-Dimensional axes at 'axis' location\\n\\t\\t axis                           0)", 'exceptions_to_ignore': 'None', 'iterator': 'None', 'name': "'reshape'", 'sampler': 'None'}}
" + ], + "text/plain": [ + "ReversedPipeline\n", + "\tInitialisation Operation reversing the effect of pipeline\n", + "\t\t forward_pipeline {'Pipeline': {'__args': \"(ToNumpy\\n\\tInitialisation Convert xarray objects to np.ndarray's\\n\\t\\t reference_dataset None\\n\\t\\t run_parallel False\\n\\t\\t saved_records None\\n\\t\\t warn True, Rearrange\\n\\tInitialisation Operation to rearrange data using einops\\n\\t\\t rearrange 'c t h w -> t c h w'\\n\\t\\t rearrange_kwargs None\\n\\t\\t reverse_rearrange None\\n\\t\\t skip False, Squeeze\\n\\tInitialisation Operation to Squeeze one-Dimensional axes at 'axis' location\\n\\t\\t axis 0)\", 'exceptions_to_ignore': 'None', 'iterator': 'None', 'name': \"'reshape'\", 'sampler': 'None'}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.named[\"reshape\"].reversed" + ] + }, + { + "cell_type": "markdown", + "id": "ed456cfb-fcff-43a8-ab82-88310af44741", + "metadata": {}, + "source": [ + "To test it, we'll fetch a sample from the \"prepare\" pipeline, reshape it and apply the reverse." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0fbda3f4-8f44-4032-a218-115fe21fd664", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 42kB\n",
+       "Dimensions:                 (time: 1, longitude: 64, latitude: 32)\n",
+       "Coordinates:\n",
+       "  * time                    (time) datetime64[ns] 8B 2021-01-01\n",
+       "  * longitude               (longitude) float64 512B 0.0 5.625 ... 348.8 354.4\n",
+       "  * latitude                (latitude) float64 256B -87.19 -81.56 ... 87.19\n",
+       "Data variables:\n",
+       "    2m_temperature          (time, longitude, latitude) float32 8kB 241.9 ......\n",
+       "    u_component_of_wind850  (time, longitude, latitude) float32 8kB -3.03 ......\n",
+       "    v_component_of_wind850  (time, longitude, latitude) float32 8kB -0.7963 ....\n",
+       "    vorticity850            (time, longitude, latitude) float32 8kB -1.147e-0...\n",
+       "    geopotential850         (time, longitude, latitude) float32 8kB 1.117e+04...\n",
+       "Attributes:\n",
+       "    level-dtype:  int64
" + ], + "text/plain": [ + " Size: 42kB\n", + "Dimensions: (time: 1, longitude: 64, latitude: 32)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 8B 2021-01-01\n", + " * longitude (longitude) float64 512B 0.0 5.625 ... 348.8 354.4\n", + " * latitude (latitude) float64 256B -87.19 -81.56 ... 87.19\n", + "Data variables:\n", + " 2m_temperature (time, longitude, latitude) float32 8kB 241.9 ......\n", + " u_component_of_wind850 (time, longitude, latitude) float32 8kB -3.03 ......\n", + " v_component_of_wind850 (time, longitude, latitude) float32 8kB -0.7963 ....\n", + " vorticity850 (time, longitude, latitude) float32 8kB -1.147e-0...\n", + " geopotential850 (time, longitude, latitude) float32 8kB 1.117e+04...\n", + "Attributes:\n", + " level-dtype: int64" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(prepare | reshape | reshape.reversed)[\"20210101T00\"]" + ] + }, + { + "cell_type": "markdown", + "id": "4f62871a-52d1-44e9-88dc-3efdd453c25a", + "metadata": {}, + "source": [ + "### End-to-end inference pipeline example\n", + "\n", + "Now, let's imagine one of our steps is an inference from a model, returning a numpy array.\n", + "To create an end-to-end pipeline, generating xarray samples in the same space as the original data source, we will reverse the whole preparation pipeline and add it at the end.\n", + "Note that in the reversed pipeline, the `pyearthtools.data.download.weatherbench.WB2ERA5` and `pyearthtools.pipeline.modifications.TemporalRetrieval` steps do not apply any effect on the sample being transformed, being effectively skipped.\n", + "\n", + "For illustration purpose, we will use a simple persistence model, that return the last sample from the tuple generated via `pyearthtools.pipeline.modifications.TemporalRetrieval`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "888d7803-de27-4c38-8fdc-a527f0a84b59", + "metadata": {}, + "outputs": [], + "source": [ + "from pyearthtools.pipeline.step import PipelineStep\n", + "\n", + "class Persistence(PipelineStep):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.record_initialisation()\n", + " def run(self, sample):\n", + " return sample[-1]\n", + "\n", + "persistence = Persistence()" + ] + }, + { + "cell_type": "markdown", + "id": "ed984c4f-8e44-4c67-9e6e-dd569727cbbc", + "metadata": {}, + "source": [ + "We can run create an inference pipeline by adding it at the end of our preprocessing pipeline.\n", + "Unfortunately, this returns a numpy array sample when queried, missing all the information about variable names and coordinates." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "ff3924fc-adb6-435f-8914-7fbbb39dc75a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(..., shape=(5, 64, 32), dtype=float32)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(pipeline | persistence)[\"20210101T00\"]" + ] + }, + { + "cell_type": "markdown", + "id": "c962c00f-46e0-4d68-b4ac-11bc844acd86", + "metadata": {}, + "source": [ + "To get a real end-to-end pipeline, we just need to add the preprocessing pipeline reversed." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "af17fc8b-ad48-4a13-a039-2df7d5ffebb3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t name                           None\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t weatherbench.WB2ERA5           {'WB2ERA5': {'level': '[850]', 'license_ok': 'True', 'resolution': "'64x32'", 'variables': "['2m_temperature', 'u', 'v', 'geopotential', 'vorticity']"}}\n",
+       "\t\t sort.Sort                      {'Sort': {'order': "['2m_temperature', 'u_component_of_wind', 'v_component_of_wind', 'vorticity', 'geopotential']", 'strict': 'False'}}\n",
+       "\t\t coordinates.StandardLongitude  {'StandardLongitude': {'longitude_name': "'longitude'", 'type': "'0-360'"}}\n",
+       "\t\t reshape.CoordinateFlatten      {'CoordinateFlatten': {'coordinate': "['level']", 'skip_missing': 'False'}}\n",
+       "\t\t idx_modification.TemporalRetrieval {'TemporalRetrieval': {'concat': 'True', 'delta_unit': 'None', 'merge_function': 'None', 'merge_kwargs': 'None', 'samples': '((0, 1), (6, 1))'}}\n",
+       "\t\t conversion.ToNumpy             {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t reshape.Rearrange              {'Rearrange': {'rearrange': "'c t h w -> t c h w'", 'rearrange_kwargs': 'None', 'reverse_rearrange': 'None', 'skip': 'False'}}\n",
+       "\t\t reshape.Squeeze                {'Squeeze': {'axis': '0'}}\n",
+       "\t\t __main__.Persistence           {'Persistence': {}}\n",
+       "\t\t controller.ReversedPipeline    {'ReversedPipeline': {'forward_pipeline': {'Pipeline': {'__args': '(WB2ERA5\\n\\tDescription                    WeatherBench2 cloud-optimized ground truth ERA5 dataset\\n\\t\\t link                           \\'https://github.com/google-research/weatherbench2\\'\\n\\n\\n\\tInitialisation                 \\n\\t\\t level                          [850]\\n\\t\\t license_ok                     True\\n\\t\\t resolution                     \\'64x32\\'\\n\\t\\t variables                      [\\'2m_temperature\\', \\'u\\', \\'v\\', \\'geopotential\\', \\'vorticity\\']\\n\\tTransforms                     \\n\\t\\t StandardCoordinateNames        {\\'latitude\\': "[\\'lat\\', \\'Latitude\\', \\'yt_ocean\\', \\'yt\\']", \\'longitude\\': "[\\'lon\\', \\'Longitude\\', \\'xt_ocean\\', \\'xt\\']", \\'replacement_dictionary\\': \\'None\\', \\'time\\': "[\\'Time\\']"}, Sort\\n\\tInitialisation                 Sort Variables of an `xarray` object\\n\\t\\t order                          [\\'2m_temperature\\', \\'u_component_of_wind\\', \\'v_component_of_wind\\', \\'vorticity\\', \\'geopotential\\']\\n\\t\\t strict                         False, StandardLongitude\\n\\tInitialisation                 Standardise format of longitude.\\n\\t\\t longitude_name                 \\'longitude\\'\\n\\t\\t type                           \\'0-360\\', CoordinateFlatten\\n\\tInitialisation                 Flatten a coordinate in a dataset into separate variables.\\n\\t\\t coordinate                     [\\'level\\']\\n\\t\\t skip_missing                   False, TemporalRetrieval\\n\\tInitialisation                 Retrieve a sequence of samples from `SequenceRetrieval`,\\n\\t\\t concat                         True\\n\\t\\t delta_unit                     None\\n\\t\\t merge_function                 None\\n\\t\\t merge_kwargs                   None\\n\\t\\t samples                        ((0, 1), (6, 1)), ToNumpy\\n\\tInitialisation                 Convert xarray objects to np.ndarray\\'s\\n\\t\\t reference_dataset              None\\n\\t\\t run_parallel                   False\\n\\t\\t saved_records                  None\\n\\t\\t warn                           True, Rearrange\\n\\tInitialisation                 Operation to rearrange data using einops\\n\\t\\t rearrange                      \\'c t h w -> t c h w\\'\\n\\t\\t rearrange_kwargs               None\\n\\t\\t reverse_rearrange              None\\n\\t\\t skip                           False, Squeeze\\n\\tInitialisation                 Operation to Squeeze one-Dimensional axes at \\'axis\\' location\\n\\t\\t axis                           0)', 'exceptions_to_ignore': 'None', 'iterator': 'None', 'name': 'None', 'sampler': 'None'}}}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "WB2ERA5_8f74e54e-ff0b-4fc8-af15-bed6d42148ee\n", + "\n", + "weatherbench.WB2ERA5\n", + "\n", + "\n", + "\n", + "Sort_6fc7dc0a-0622-41bc-9dce-d279eb05e3a0\n", + "\n", + "sort.Sort\n", + "\n", + "\n", + "\n", + "WB2ERA5_8f74e54e-ff0b-4fc8-af15-bed6d42148ee->Sort_6fc7dc0a-0622-41bc-9dce-d279eb05e3a0\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "StandardLongitude_025db7fd-0c24-48f2-8b66-c2dfa991b40d\n", + "\n", + "coordinates.StandardLongitude\n", + "\n", + "\n", + "\n", + "Sort_6fc7dc0a-0622-41bc-9dce-d279eb05e3a0->StandardLongitude_025db7fd-0c24-48f2-8b66-c2dfa991b40d\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_e7d0b52e-2594-4c33-b19a-03962c57694c\n", + "\n", + "reshape.CoordinateFlatten\n", + "\n", + "\n", + "\n", + "StandardLongitude_025db7fd-0c24-48f2-8b66-c2dfa991b40d->CoordinateFlatten_e7d0b52e-2594-4c33-b19a-03962c57694c\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_5af3ee8e-9679-4b78-a656-7553aef36853\n", + "\n", + "idx_modification.TemporalRetrieval\n", + "\n", + "\n", + "\n", + "CoordinateFlatten_e7d0b52e-2594-4c33-b19a-03962c57694c->TemporalRetrieval_5af3ee8e-9679-4b78-a656-7553aef36853\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_270f54fc-de8b-46c7-8331-9564b180922d\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "TemporalRetrieval_5af3ee8e-9679-4b78-a656-7553aef36853->ToNumpy_270f54fc-de8b-46c7-8331-9564b180922d\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Rearrange_2bea21c5-f550-4e37-a4eb-a830664eeb03\n", + "\n", + "reshape.Rearrange\n", + "\n", + "\n", + "\n", + "ToNumpy_270f54fc-de8b-46c7-8331-9564b180922d->Rearrange_2bea21c5-f550-4e37-a4eb-a830664eeb03\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Squeeze_7fdaa980-f936-4b22-8db5-02c13a59b44c\n", + "\n", + "reshape.Squeeze\n", + "\n", + "\n", + "\n", + "Rearrange_2bea21c5-f550-4e37-a4eb-a830664eeb03->Squeeze_7fdaa980-f936-4b22-8db5-02c13a59b44c\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Persistence_3a531f51-7cd1-43a3-8f4e-ff3b9652edad\n", + "\n", + "__main__.Persistence\n", + "\n", + "\n", + "\n", + "Squeeze_7fdaa980-f936-4b22-8db5-02c13a59b44c->Persistence_3a531f51-7cd1-43a3-8f4e-ff3b9652edad\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ReversedPipeline_2013b325-e01b-481b-b902-9f4e3c1f0931\n", + "\n", + "controller.ReversedPipeline\n", + "\n", + "\n", + "\n", + "Persistence_3a531f51-7cd1-43a3-8f4e-ff3b9652edad->ReversedPipeline_2013b325-e01b-481b-b902-9f4e3c1f0931\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "end_to_end = pipeline | persistence | pipeline.reversed\n", + "end_to_end" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a5dc759a-800c-4eba-bed5-9417d843c8ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 42kB\n",
+       "Dimensions:              (time: 1, longitude: 64, latitude: 32, level: 1)\n",
+       "Coordinates:\n",
+       "  * time                 (time) datetime64[ns] 8B 2021-01-01\n",
+       "  * longitude            (longitude) float64 512B 0.0 5.625 ... 348.8 354.4\n",
+       "  * latitude             (latitude) float64 256B -87.19 -81.56 ... 81.56 87.19\n",
+       "  * level                (level) float64 8B 850.0\n",
+       "Data variables:\n",
+       "    2m_temperature       (time, longitude, latitude) float32 8kB 241.8 ... 260.0\n",
+       "    u_component_of_wind  (time, longitude, latitude) float32 8kB -2.618 ... -...\n",
+       "    v_component_of_wind  (time, longitude, latitude) float32 8kB -2.405 ... 1...\n",
+       "    vorticity            (time, longitude, latitude) float32 8kB -1.714e-05 ....\n",
+       "    geopotential         (time, longitude, latitude) float32 8kB 1.12e+04 ......
" + ], + "text/plain": [ + " Size: 42kB\n", + "Dimensions: (time: 1, longitude: 64, latitude: 32, level: 1)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 8B 2021-01-01\n", + " * longitude (longitude) float64 512B 0.0 5.625 ... 348.8 354.4\n", + " * latitude (latitude) float64 256B -87.19 -81.56 ... 81.56 87.19\n", + " * level (level) float64 8B 850.0\n", + "Data variables:\n", + " 2m_temperature (time, longitude, latitude) float32 8kB 241.8 ... 260.0\n", + " u_component_of_wind (time, longitude, latitude) float32 8kB -2.618 ... -...\n", + " v_component_of_wind (time, longitude, latitude) float32 8kB -2.405 ... 1...\n", + " vorticity (time, longitude, latitude) float32 8kB -1.714e-05 ....\n", + " geopotential (time, longitude, latitude) float32 8kB 1.12e+04 ......" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "end_to_end[\"20210101T00\"]" + ] + }, + { + "cell_type": "markdown", + "id": "2b38cc92-50d5-4143-bb01-7e8aebf828f1", + "metadata": {}, + "source": [ + "Note that our simple model doesn't handle the time information, so we would need to fix the time information of the sample, for example adding 6 hours to the `time` coordinate if we are forecasting at T+6H." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac3da9c0-11b1-4e36-84a6-33ae39348a89", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/pipeline/src/pyearthtools/pipeline/__init__.py b/packages/pipeline/src/pyearthtools/pipeline/__init__.py index 5f2c85b7..43e54af1 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/__init__.py +++ b/packages/pipeline/src/pyearthtools/pipeline/__init__.py @@ -47,7 +47,7 @@ # from pyearthtools.pipeline.save import save, load -from pyearthtools.pipeline.controller import Pipeline, PipelineIndex +from pyearthtools.pipeline.controller import Pipeline, PipelineIndex, ReversedPipeline from pyearthtools.pipeline._save_pipeline import load_pipeline as load from pyearthtools.pipeline.operation import Operation @@ -83,6 +83,7 @@ __all__ = [ "Sampler", "Pipeline", + "ReversedPipeline", "Operation", "branching", "exceptions", diff --git a/packages/pipeline/src/pyearthtools/pipeline/controller.py b/packages/pipeline/src/pyearthtools/pipeline/controller.py index c3805ecd..fabd933a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/controller.py +++ b/packages/pipeline/src/pyearthtools/pipeline/controller.py @@ -229,6 +229,7 @@ def __init__( iterator: Optional[Union[iterators.Iterator, tuple[iterators.Iterator, ...]]] = None, sampler: Optional[Union[samplers.Sampler, tuple[samplers.Sampler, ...]]] = None, exceptions_to_ignore: Optional[tuple[Union[str, Type[Exception]], ...]] = None, + name: str | None = None, **kwargs, ): """ @@ -311,13 +312,15 @@ def __init__( Can be used to randomly sample, drop out and more exceptions_to_ignore: Which exceptions to ignore when iterating. Defaults to None. + + name: Name of the pipeline, used in nested pipelines """ self.iterator = iterator self.sampler = sampler - + self.name = name + self._named = {} super().__init__(*steps, **kwargs) self.record_initialisation() - self.exceptions_to_ignore = exceptions_to_ignore @property @@ -394,10 +397,24 @@ def steps( # steps_list = [v] elif isinstance(v, Pipeline): steps_list.extend(v.steps) + self._add_named_pipe(v) else: steps_list.append(v) self._steps = tuple(steps_list) # type: ignore + @property + def named(self) -> dict[str, Pipeline]: + """Named sub-pipelines""" + return self._named.copy() + + def _add_named_pipe(self, pipe: Pipeline): + """add or merge a nested pipeline in the dictionary of named pipelines""" + new_pipes = pipe._named if pipe.name is None else {pipe.name: pipe} + for name, nested_pipe in new_pipes.items(): + if name in self._named: + raise KeyError(f"Named pipeline '{name}' already exists.") + self._named[name] = nested_pipe + @property def iterator(self): """Iterator of `Pipeline`""" @@ -772,8 +789,7 @@ def __contains__(self, id: Union[str, Type[Any]]) -> bool: return False def __add__(self, other: Union[_Pipeline, PipelineIndex, PipelineStep]) -> Pipeline: - """ - Combine pipelines + """Combine pipelines Will set `self` steps first then `other`. @@ -788,13 +804,42 @@ def __add__(self, other: Union[_Pipeline, PipelineIndex, PipelineStep]) -> Pipel new_init = dict(init) new_init.update({key: val for key, val in other_init.items() if val is not None}) + new_init["name"] = None - return Pipeline(*args, **new_init) + new_pipe = Pipeline(*args, **new_init) + new_pipe._add_named_pipe(self) + new_pipe._add_named_pipe(other) + return new_pipe assert isinstance(other, (PipelineIndex, PipelineStep)) - init = dict(self.initialisation) + init = {**self.initialisation, "name": None} args = (*init.pop("__args", []), other) - return Pipeline(*args, **init) + new_pipe = Pipeline(*args, **init) + new_pipe._add_named_pipe(self) + return new_pipe + + def __or__(self, other: Union[_Pipeline, PipelineIndex, PipelineStep]) -> Pipeline: + """Combine pipelines + + Same as + operator, alternative syntax. + """ + return self + other + + def __ror__( + self, + other: Union[ + VALID_PIPELINE_TYPES, + _Pipeline, + PipelineIndex, + tuple[Union[VALID_PIPELINE_TYPES, Literal["map", "map_copy"]], ...], + ], + ) -> Pipeline: + """Append a step in front of a pipeline""" + init = {**self.initialisation, "name": None} + args = (other, *init.pop("__args", [])) + new_pipe = Pipeline(*args, **init) + new_pipe._add_named_pipe(self) + return new_pipe def save(self, path: Optional[Union[str, Path]] = None, only_steps: bool = False) -> Union[str, None]: """ @@ -843,3 +888,26 @@ def sample( iterator=iterator, sampler=sampler, ) + + @property + def reversed(self) -> "ReversedPipeline": + return ReversedPipeline(self) + + +class ReversedPipeline(Operation): + """Operation reversing the effect of pipeline + + Applying this operation will undo the provided pipeline, while undoing this + operation will apply the pipeline. + """ + + def __init__(self, forward_pipeline: Pipeline): + super().__init__() + self.forward_pipeline = forward_pipeline + self.record_initialisation() + + def undo_func(self, sample): + return self.forward_pipeline.apply(sample) + + def apply_func(self, sample): + return self.forward_pipeline.undo(sample)