diff --git a/notebooks/tutorial/AutoEncoder_Example.ipynb b/notebooks/tutorial/AutoEncoder_Example.ipynb new file mode 100644 index 00000000..9498c159 --- /dev/null +++ b/notebooks/tutorial/AutoEncoder_Example.ipynb @@ -0,0 +1,2124 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af873d2f-3b70-4f36-8e75-71e813b8e379", + "metadata": {}, + "source": [ + "# Nowcasting Autoencoder Tutorial\n", + "\n", + "Creating an effective autoencoder is a great first step in developing a predictive model. Learning how to create a high-quality latent state is critical, as that latent state can later be used to train a new encoder for a pre-existing predictive model, or a new decoder. An example of this could be training an encoder which uses different or fewer variables than the original model. Another might be training a higher-resolution decoder which is capable of producing higher-resolution predictions from the same latent state.\n", + "\n", + "Autoencoders are also a great concept to learn when understanding neural network architectures.\n", + "\n", + "In this example, we first blend radar and satellite data onto the same grid, then train an autoencoder to perform dimensionality reduction and produce a useful latent state which can be used to reconstruct the original inputs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3585d2a8-87ab-4b2d-bbfd-be84978bc919", + "metadata": {}, + "outputs": [], + "source": [ + "import pyearthtools.data as petdata\n", + "import pyearthtools.pipeline as petpipe\n", + "import site_archive_nci\n", + "\n", + "from pyearthtools.data.time import Petdt\n", + "from pyearthtools.pipeline.operations.xarray.join import GeospatialTimeSeriesMerge\n", + "\n", + "import xarray as xr\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "# Set random seed for reproducibility\n", + "torch.manual_seed(42)\n", + "\n", + "# Autodetect GPU and use if possible\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8d868313-c96d-4cf1-a865-2f222467aba5", + "metadata": {}, + "outputs": [], + "source": [ + "rf3proj = petdata.transforms.projection.Rainfields3ProjAus()\n", + "radar_projector = petdata.transforms.projection.XYtoLonLatRectilinear(rf3proj)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bd7a1c13-31db-4d73-951b-f1d15cd241e8", + "metadata": {}, + "outputs": [], + "source": [ + "# We specify the date and hour for querying data\n", + "doi = '2021-06-09T02'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2f19a69a-b000-478b-81db-3ecc6e49d31c", + "metadata": {}, + "outputs": [], + "source": [ + "himawari = petdata.archive.Himawari('surface_global_irradiance')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ac2f22d9-6b2f-4798-87a6-3a57f93a1083", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: It would be nice if this normalised the data nicely\n", + "satpipe = petpipe.Pipeline(\n", + " himawari\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e4d697bd-5d14-4800-8635-e7704c8c8078", + 
"metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/g/data/kd24/tjl/src/PyEarthTools/packages/data/src/pyearthtools/data/operations/index_routines.py:329: FutureWarning: In a future version of xarray the default value for data_vars will change from data_vars='all' to data_vars=None. This is likely to lead to different results when multiple datasets have matching variables with overlapping values. To opt in to new defaults and get rid of these warnings now use `set_options(use_new_combine_kwarg_defaults=True) or set data_vars explicitly.\n", + " full_ds = xr.open_mfdataset(\n" + ] + }, + { + "data": { + "text/html": [ + "
<xarray.Dataset> Size: 183MB\n", + "Dimensions: (time: 6, latitude: 1726, longitude: 2214)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 48B 2021-06-09T02:00:00 ...\n", + " * latitude (latitude) float32 7kB -44.5 -44.48 ... -10.0\n", + " * longitude (longitude) float32 9kB 112.0 112.0 ... 156.3\n", + "Data variables:\n", + " surface_global_irradiance (time, latitude, longitude) float64 183MB dask.array<chunksize=(1, 1726, 2214), meta=np.ndarray>\n", + "Attributes: (12/50)\n", + " Conventions: CF-1.7\n", + " Metadata_Conventions: Unidata Dataset Discovery v1.0\n", + " acknowledgment: The following acknowledgement is requir...\n", + " cdm_data_type: grid\n", + " comment: Solar radiation data derived from satel...\n", + " contributor_name: Mines ParisTech; Commonwealth of Austra...\n", + " ... ...\n", + " geospatial_lon_resolution: 0.02\n", + " bias_correction_applied_meaning: 0: not applied; 1:applied\n", + " quality_meaning: 0: no_known_issues 1: known_issue \n", + " project: Gridded Solar Observations\n", + " references: Poulsen C., Majewski L. J. (2022) Gridd...\n", + " NCO: netCDF Operators version 4.7.7 (Homepag...
<xarray.Dataset> Size: 245MB\n", + "Dimensions: (time: 1, latitude: 1726, longitude: 2214, n2: 2)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 8B 2021-01-01\n", + " * latitude (latitude) float32 7kB -44.5 -44.48 ... -10.0\n", + " * longitude (longitude) float32 9kB 112.0 112.0 ... 156.3\n", + " x (longitude, latitude) float64 31MB -1.651e+03 ...\n", + " y (longitude, latitude) float64 31MB -4.99e+03 ....\n", + "Dimensions without coordinates: n2\n", + "Data variables:\n", + " surface_global_irradiance (time, latitude, longitude) float64 31MB dask.array<chunksize=(1, 1726, 2214), meta=np.ndarray>\n", + " proj (time) int8 1B 0\n", + " y_bounds (time, longitude, latitude, n2) float64 61MB -...\n", + " x_bounds (time, longitude, latitude, n2) float64 61MB -...\n", + " rain_rate (time, longitude, latitude) float64 31MB nan ....\n", + "Attributes: (12/58)\n", + " Conventions: CF-1.7\n", + " Metadata_Conventions: Unidata Dataset Discovery v1.0\n", + " acknowledgment: The following acknowledgement is requir...\n", + " cdm_data_type: grid\n", + " comment: Solar radiation data derived from satel...\n", + " contributor_name: Mines ParisTech; Commonwealth of Austra...\n", + " ... ...\n", + " quality: 0\n", + " quality_meaning: 0: no_known_issues 1: known_issue \n", + " project: Gridded Solar Observations\n", + " history: Mon Mar 4 01:55:23 2024: ncatted -a re...\n", + " references: Poulsen C., Majewski L. J. (2022) Gridd...\n", + " NCO: netCDF Operators version 4.7.7 (Homepag...