diff --git a/notebooks/tutorial/CNN-Model-Training.ipynb b/notebooks/tutorial/CNN-Model-Training.ipynb index 286173a0..b279d56b 100644 --- a/notebooks/tutorial/CNN-Model-Training.ipynb +++ b/notebooks/tutorial/CNN-Model-Training.ipynb @@ -75,9 +75,6 @@ "test_start = \"2017-01-01T00\"\n", "test_end = \"2017-01-12T00\"\n", "\n", - "# number of samples to estimate mean & standard deviation of fields\n", - "n_samples = 200\n", - "\n", "# data loader parameters\n", "batch_size = 1\n", "n_workers = 2\n", @@ -121,7 +118,9 @@ " - Index 0 of the sample contains the input data at T+0hr, using the (0, 1) tuple.\n", " - Index 1 of the sample contains the output data at T+6hrs, using the (6, 1) tuple.\n", "\n", - "The input data is used by the model to make predictions, while the output data represents the true values that the model aims to predict." + "The input data is used by the model to make predictions, while the output data represents the true values that the model aims to predict.\n", + "\n", + "*Note: A subset of ERA5 data is downloaded during the first instantiation of the pipeline, it may take a bit of time.*" ] }, { @@ -130,124 +129,6 @@ "id": "8eb82b2d", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving dataset, it will take at most 3.83 gigabytes of storage space.\n", - "macadamia - 2025-08-25 02:28:15,523 - pyearthtools.data.download.weatherbench - weatherbench - save_local_dataset - L123 - WARNING - Saving dataset, it will take at most 3.83 gigabytes of storage space.\n", - "Saving 2m_temperature variable under cnn_training/download/ee5f0931735d8d146214aa551175dfa34b3e33093aec3b93d20bcc78c56bcd1b/2m_temperature.zarr, it will take at most 766.31 megabytes of storage space.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d5d6a6b09b424e338a0d46ac97fda57c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Writing: 0%| | 0/937 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving 2m_temperature variable finished.\n", - "Saving geopotential variable (level 850) under cnn_training/download/ee5f0931735d8d146214aa551175dfa34b3e33093aec3b93d20bcc78c56bcd1b/geopotential_level-850.zarr, it will take at most 766.31 megabytes of storage space.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9486859921504cf7b0cb65fd94a0291f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Writing: 0%| | 0/2809 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving geopotential variable (level 850) finished.\n", - "Saving u_component_of_wind variable (level 850) under cnn_training/download/ee5f0931735d8d146214aa551175dfa34b3e33093aec3b93d20bcc78c56bcd1b/u_component_of_wind_level-850.zarr, it will take at most 766.31 megabytes of storage space.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3354968bd882408d8997f82eb35e34aa", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Writing: 0%| | 0/2809 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving u_component_of_wind variable (level 850) finished.\n", - "Saving v_component_of_wind variable (level 850) under cnn_training/download/ee5f0931735d8d146214aa551175dfa34b3e33093aec3b93d20bcc78c56bcd1b/v_component_of_wind_level-850.zarr, it will take at most 766.31 megabytes of storage space.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "904b6ad17d794561bec8c1bae842f1be", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Writing: 0%| | 0/2809 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving v_component_of_wind variable (level 850) finished.\n", - "Saving vorticity variable (level 850) under cnn_training/download/ee5f0931735d8d146214aa551175dfa34b3e33093aec3b93d20bcc78c56bcd1b/vorticity_level-850.zarr, it will take at most 766.31 megabytes of storage space.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "de213bc9db374fce9f1fa888cc50cfa5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Writing: 0%| | 0/2809 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saving vorticity variable (level 850) finished.\n" - ] - }, { "data": { "text/html": [ @@ -647,7 +528,7 @@ "\t\t sort.Sort {'Sort': {'order': "['2m_temperature', 'u_component_of_wind', 'v_component_of_wind', 'vorticity', 'geopotential']", 'strict': 'False'}}\n", "\t\t coordinates.StandardLongitude {'StandardLongitude': {'longitude_name': "'longitude'", 'type': "'0-360'"}}\n", "\t\t reshape.CoordinateFlatten {'CoordinateFlatten': {'__args': '()', 'coordinate': "'level'", 'skip_missing': 'False'}}\n", - "\t\t idx_modification.TemporalRetrieval {'TemporalRetrieval': {'concat': 'True', 'delta_unit': 'None', 'merge_function': 'None', 'merge_kwargs': 'None', 'samples': '((0, 1), (6, 1))'}}
" + "\t\t idx_modification.TemporalRetrieval {'TemporalRetrieval': {'concat': 'True', 'delta_unit': 'None', 'merge_function': 'None', 'merge_kwargs': 'None', 'samples': '((0, 1), (6, 1))'}}" ], "text/plain": [ "DateRandomise\n",
"\tInitialisation Wrap around another `Iterator` and randomly sample\n",
"\t\t iterator {'DateRange': {'allowlist': 'None', 'blocklist': 'None', 'end': "'2015-01-12T00'", 'interval': "'6h'", 'start': "'2013-01-01T00'"}}\n",
- "\t\t seed 42"
+ "\t\t seed 42"
],
"text/plain": [
"DateRandomise\n",
@@ -2267,7 +2148,7 @@
"\t\t blocklist None\n",
"\t\t end '2016-01-12T00'\n",
"\t\t interval '6h'\n",
- "\t\t start '2016-01-01T00'"
+ "\t\t start '2016-01-01T00'"
],
"text/plain": [
"DateRange\n",
@@ -2301,7 +2182,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Train Splits: (Petdt('2013-04-11T12'), Petdt('2013-09-10T00'), Petdt('2014-03-19T00'), Petdt('2013-10-18T00'), Petdt('2013-04-16T00'))\n",
+ "Train Splits: (Petdt('2014-02-15T00'), Petdt('2013-08-19T00'), Petdt('2014-11-12T06'), Petdt('2015-01-04T06'), Petdt('2013-01-15T06'))\n",
"Val Splits: (Petdt('2016-01-01T00'), Petdt('2016-01-01T06'), Petdt('2016-01-01T12'), Petdt('2016-01-01T18'), Petdt('2016-01-02T00'))\n"
]
}
@@ -2333,8 +2214,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1min 56s, sys: 2.67 s, total: 1min 59s\n",
- "Wall time: 2min 1s\n"
+ "CPU times: user 1min 52s, sys: 4.18 s, total: 1min 56s\n",
+ "Wall time: 1min 58s\n"
]
}
],
@@ -2806,7 +2687,7 @@
"\t\t reshape.Rearrange {'Rearrange': {'rearrange': "'c t h w -> t c h w'", 'rearrange_kwargs': 'None', 'reverse_rearrange': 'None', 'skip': 'False'}}\n",
"\t\t reshape.Squeeze {'Squeeze': {'axis': '0'}}\n",
"\t\t normalisation.Deviation {'Deviation': {'deviation': "PosixPath('cnn_training/std.npy')", 'expand': 'False', 'mean': "PosixPath('cnn_training/mean.npy')"}}\n",
- "\t\t cache.Cache {'Cache': {'cache': "'/var/home/riomaxim/Synced/work/en_cours/PyEarthTools/notebooks/tutorial/cnn_training/cache'", 'cache_validity': "'warn'", 'pattern': 'None', 'pattern_kwargs': {'extension': "'npy'"}, 'save_kwargs': 'None'}}"
+ "\t\t cache.Cache {'Cache': {'cache': "'/var/home/riomaxim/Synced/work/en_cours/PyEarthTools/notebooks/tutorial/cnn_training/cache'", 'cache_validity': "'warn'", 'pattern': 'None', 'pattern_kwargs': {'extension': "'npy'"}, 'save_kwargs': 'None'}}"
],
"text/plain": [
"