diff --git a/notebooks/tutorial/HadISD/1_HadISD_Download.ipynb b/notebooks/tutorial/HadISD/1_HadISD_Download.ipynb index 31a7c932..ad495a08 100644 --- a/notebooks/tutorial/HadISD/1_HadISD_Download.ipynb +++ b/notebooks/tutorial/HadISD/1_HadISD_Download.ipynb @@ -30,7 +30,8 @@ "from tqdm.auto import tqdm\n", "import tarfile\n", "import gzip\n", - "import shutil" + "import shutil\n", + "from pathlib import Path" ] }, { @@ -45,13 +46,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "2eadaf27", "metadata": {}, "outputs": [], "source": [ - "%run Data_Config.ipynb\n", - "print(f\"Data will be downloaded to: {download_dir}\") " + "# ruff: noqa: F821\n", + "%run Data_Config.ipynb" ] }, { @@ -60,95 +61,93 @@ "metadata": {}, "source": [ "### Download HadISD Data\n", - "The following code will download the HadISD data files. Some files take longer to download than others depending on time of day. To download different WMO datasets, you can change `wmo_id_range` in the `Data_Config.ipynb` notebook .\n", + "The following code will download the HadISD data files. Some files take longer to download than others depending on time of day. To download different WMO datasets, you can change `wmo_id_ranges` in the `Data_Config.ipynb` notebook.\n", "\n", "The full list of available data can be found here:\n", - "https://www.metoffice.gov.uk/hadobs/hadisd/v340_2023f/download.html" + "https://www.metoffice.gov.uk/hadobs/hadisd/v340_2023f/download.html\n", + "\n", + "Station data has been split up into ranges to make downloads more manageable. You may download as much or as little as you like. To get started we recommend just downloading a few station ranges to get an idea of how to use HadISD data with PyEarthTools. 
" ] }, { "cell_type": "code", - "execution_count": null, - "id": "feb8d671", + "execution_count": 3, + "id": "8ddbebda", "metadata": {}, "outputs": [], "source": [ - "# Explain why stations are split into ranges, file size, and how it's not neccesssary to download all stations. " + "wmo_id_ranges = wmo_id_ranges # This has been defined in HadISD_data_config.ipynb" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "11a188d4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading HadISD data for WMO range: ['500000-549999', '722000-722999', '800000-849999']\n" + ] + } + ], "source": [ - "print(f\"Downloading HadISD data for WMO range: {wmo_id_range}\")" + "print(f\"Downloading HadISD data for WMO range: {wmo_id_ranges}\")" ] }, { "cell_type": "code", - "execution_count": null, - "id": "8ddbebda", - "metadata": {}, - "outputs": [], - "source": [ - "wmo_id_range = wmo_id_range # This has been defined in HadISD_data_config.ipynb\n", - "\n", - "wmo_str = f\"WMO_{wmo_id_range}\"\n", - "url = f\"https://www.metoffice.gov.uk/hadobs/hadisd/v340_2023f/data/{wmo_str}.tar.gz\"\n", - "tar_name = f\"{wmo_str}.tar\"\n", - "filename = download_dir / tar_name" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "08ac36fd", "metadata": {}, "outputs": [], "source": [ - "# Get remote file size using HTTP HEAD\n", - "head = requests.head(url, allow_redirects=True)\n", - "remote_size = int(head.headers.get('content-length', 0))\n", - "\n", - "local_size = filename.stat().st_size if filename.exists() else 0\n", - "\n", - "if filename.exists() and local_size == remote_size:\n", - " print(f\"File already fully downloaded: {filename} ({local_size/1024**2:.2f} MB)\")\n", - "else:\n", - " headers = {}\n", - " mode = 'wb'\n", - " initial_pos = 0\n", - " if filename.exists() and local_size < remote_size:\n", - " headers['Range'] = 
f'bytes={local_size}-'\n", - " mode = 'ab'\n", - " initial_pos = local_size\n", - " print(f\"Resuming download for {filename.name} at {local_size/1024**2:.2f} MB...\")\n", - " else:\n", - " print(f\"Starting download for {filename.name}...\")\n", - "\n", - " response = requests.get(url, stream=True, headers=headers)\n", - " total = remote_size\n", - "\n", - " with open(filename, mode) as f, tqdm(\n", - " desc=f\"Downloading {filename.name}\",\n", - " total=total,\n", - " initial=initial_pos,\n", - " unit='B', unit_scale=True, unit_divisor=1024\n", - " ) as bar:\n", - " for chunk in response.iter_content(chunk_size=8192):\n", - " if chunk:\n", - " f.write(chunk)\n", - " bar.update(len(chunk))\n", - "\n", - " final_size = filename.stat().st_size\n", - " if final_size == remote_size:\n", - " print(f\"Download complete: {filename} ({final_size/1024**2:.2f} MB)\")\n", + "def download_wmo_range(wmo_id_range, download_dir):\n", + " wmo_str = f\"WMO_{wmo_id_range}\"\n", + " url = f\"https://www.metoffice.gov.uk/hadobs/hadisd/v340_2023f/data/{wmo_str}.tar.gz\"\n", + " tar_name = f\"{wmo_str}.tar\"\n", + " filename = Path(download_dir) / tar_name\n", + "\n", + " head = requests.head(url, allow_redirects=True)\n", + " remote_size = int(head.headers.get('content-length', 0))\n", + " local_size = filename.stat().st_size if filename.exists() else 0\n", + "\n", + " if filename.exists() and local_size == remote_size:\n", + " print(f\"File already fully downloaded: {filename} ({local_size/1024**2:.2f} MB)\")\n", " else:\n", - " print(f\"Warning: Download incomplete. 
Local size: {final_size}, Remote size: {remote_size}\")\n", + " headers = {}\n", + " mode = 'wb'\n", + " initial_pos = 0\n", + " if filename.exists() and local_size < remote_size:\n", + " headers['Range'] = f'bytes={local_size}-'\n", + " mode = 'ab'\n", + " initial_pos = local_size\n", + " print(f\"Resuming download for {filename.name} at {local_size/1024**2:.2f} MB...\")\n", + " else:\n", + " print(f\"Starting download for {filename.name}...\")\n", + "\n", + " response = requests.get(url, stream=True, headers=headers)\n", + " total = remote_size\n", + " with open(filename, mode) as f, tqdm(\n", + " desc=f\"Downloading {filename.name}\",\n", + " total=total,\n", + " initial=initial_pos,\n", + " unit='B', unit_scale=True, unit_divisor=1024\n", + " ) as bar:\n", + " for chunk in response.iter_content(chunk_size=8192):\n", + " if chunk:\n", + " f.write(chunk)\n", + " bar.update(len(chunk))\n", + "\n", + " final_size = filename.stat().st_size\n", + " if final_size == remote_size:\n", + " print(f\"Download complete: {filename} ({final_size/1024**2:.2f} MB)\")\n", + " else:\n", + " print(f\"Warning: Download incomplete. Local size: {final_size}, Remote size: {remote_size}\")\n", "\n", - "# Possibly also add check to see if netcdf files esist for the downloaded tar file, if so then don't download again" + " return filename, tar_name\n" ] }, { @@ -161,77 +160,239 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "fb79a81c", "metadata": {}, "outputs": [], "source": [ - "extract_dir = download_dir / tar_name.replace('.tar', '')\n", - "extract_dir.mkdir(exist_ok=True)\n", - "\n", - "extracted_files = list(extract_dir.glob('*'))\n", - "if extracted_files:\n", - " print(f\"Extraction directory '{extract_dir}' already contains {len(extracted_files)} files. 
Skipping extraction.\")\n", - "elif filename.exists():\n", - " with tarfile.open(filename, \"r:gz\") as tar:\n", - " tar.extractall(path=extract_dir)\n", + "def extract_wmo_tar(filename, tar_name, download_dir):\n", + " extract_dir = Path(download_dir) / tar_name.replace('.tar', '')\n", + " extract_dir.mkdir(exist_ok=True)\n", " extracted_files = list(extract_dir.glob('*'))\n", " if extracted_files:\n", - " print(f\"Extraction successful. {len(extracted_files)} files found in {extract_dir}.\")\n", - " # Delete the tar file after extraction\n", - " filename.unlink()\n", - " print(f\"Deleted tar file: {filename}\")\n", + " print(f\"Extraction directory '{extract_dir}' already contains {len(extracted_files)} files. Skipping extraction.\")\n", + " elif filename.exists():\n", + " with tarfile.open(filename, \"r:gz\") as tar:\n", + " tar.extractall(path=extract_dir)\n", + " extracted_files = list(extract_dir.glob('*'))\n", + " if extracted_files:\n", + " print(f\"Extraction successful. {len(extracted_files)} files found in {extract_dir}.\")\n", + " filename.unlink()\n", + " print(f\"Deleted tar file: {filename}\")\n", + " else:\n", + " print(f\"Warning: No files extracted to {extract_dir}. Tar file will not be deleted.\")\n", + " raise RuntimeError(\"Extraction failed, tar file not deleted.\")\n", " else:\n", - " print(f\"Warning: No files extracted to {extract_dir}. Tar file will not be deleted.\")\n", - " raise RuntimeError(\"Extraction failed, tar file not deleted.\")\n", - "else:\n", - " print(f\"No tar file found and extraction directory is empty. Nothing to extract.\")\n", - " raise FileNotFoundError(f\"Missing tar file: {filename}\")\n" + " print(\"No tar file found and extraction directory is empty. 
Nothing to extract.\")\n", + " raise FileNotFoundError(f\"Missing tar file: {filename}\")\n", + " return extract_dir" ] }, { "cell_type": "code", - "execution_count": null, - "id": "53161550", + "execution_count": 7, + "id": "4e43dcc4", "metadata": {}, "outputs": [], "source": [ - "# Create subfolder for netcdf\n", - "netcdf_dir = download_dir / \"netcdf\"\n", - "netcdf_dir.mkdir(parents=True, exist_ok=True)" + "# Move extracted .nc files into netcdf_dir after extraction\n", + "def move_netcdf_files(extract_dir, download_dir):\n", + " netcdf_dir = Path(download_dir) / \"netcdf\"\n", + " netcdf_dir.mkdir(parents=True, exist_ok=True)\n", + " num_files = 0\n", + " for gz_path in extract_dir.glob('*.nc.gz'):\n", + " nc_path = gz_path.with_suffix('') # Remove .gz extension\n", + " with gzip.open(gz_path, 'rb') as f_in, open(nc_path, 'wb') as f_out:\n", + " f_out.write(f_in.read())\n", + " gz_path.unlink()\n", + " shutil.move(str(nc_path), netcdf_dir / nc_path.name)\n", + " num_files += 1\n", + " print(f\"{num_files} .nc files have been extracted, cleaned up, and moved to the netcdf directory: {netcdf_dir}\")\n", + "\n", + " # Delete extraction directory\n", + " try:\n", + " shutil.rmtree(extract_dir)\n", + " print(f\"Deleted extraction directory: {extract_dir}\")\n", + " except Exception as e:\n", + " print(f\"Could not delete extraction directory {extract_dir}: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "932e8906", + "metadata": {}, + "source": [ + "### Idempotent Checks" ] }, { "cell_type": "code", - "execution_count": null, - "id": "4e43dcc4", + "execution_count": 8, + "id": "dcb9b902", "metadata": {}, "outputs": [], "source": [ - "# Move extracted .nc files into netcdf_dir after extraction\n", - "num_files = 0\n", - "for gz_path in extract_dir.glob('*.nc.gz'):\n", - " nc_path = gz_path.with_suffix('') # Remove .gz extension\n", - " with gzip.open(gz_path, 'rb') as f_in, open(nc_path, 'wb') as f_out:\n", - " f_out.write(f_in.read())\n", - " 
gz_path.unlink() # Delete the .gz file after extraction\n", - " shutil.move(str(nc_path), netcdf_dir / nc_path.name)\n", - " num_files += 1\n", - "\n", - "print(f\"{num_files} .nc files have been extracted, cleaned up, and moved to the netcdf directory: {netcdf_dir}\")\n", - "\n", - "# Delete the extraction directory after processing\n", - "try:\n", - " shutil.rmtree(extract_dir)\n", - " print(f\"Deleted extraction directory: {extract_dir}\")\n", - "except Exception as e:\n", - " print(f\"Could not delete extraction directory {extract_dir}: {e}\")" + "def netcdf_files_exist_for_range(wmo_id_range, netcdf_dir):\n", + " \"\"\"Check if any .nc files for the given WMO range exist in the netcdf directory.\"\"\"\n", + " start, end = map(int, wmo_id_range.split('-'))\n", + " nc_files = list(Path(netcdf_dir).glob(\"*.nc\"))\n", + " for nc_file in nc_files:\n", + " try:\n", + " # Extract the first 6 digits from the station part of the filename\n", + " station_part = nc_file.stem.split('_')[-1]\n", + " wmo_number = int(station_part.split('-')[0])\n", + " if start <= wmo_number <= end:\n", + " return True\n", + " except Exception as e:\n", + " print(f\"Skipping file {nc_file.name}: {e}\")\n", + " continue\n", + " print(f\"No NetCDF files found for WMO range {wmo_id_range}.\")\n", + " return False" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "36ad9920", + "metadata": {}, + "outputs": [], + "source": [ + "def is_tar_fully_downloaded(wmo_id_range, download_dir):\n", + " \"\"\"Check if the tar file exists and is fully downloaded (size matches remote).\"\"\"\n", + " wmo_str = f\"WMO_{wmo_id_range}\"\n", + " tar_name = f\"{wmo_str}.tar\"\n", + " tar_path = Path(download_dir) / tar_name\n", + " url = f\"https://www.metoffice.gov.uk/hadobs/hadisd/v340_2023f/data/{wmo_str}.tar.gz\"\n", + "\n", + " if not tar_path.exists():\n", + " return False\n", + "\n", + " # Get remote file size\n", + " head = requests.head(url, allow_redirects=True)\n", + " remote_size = 
int(head.headers.get('content-length', 0))\n", + " local_size = tar_path.stat().st_size\n", + "\n", + " return local_size == remote_size" + ] + }, + { + "cell_type": "markdown", + "id": "77f044b9", + "metadata": {}, + "source": [ + "### Loop through each WMO range, download if necessary, extract, and move files" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ffcc5730", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No NetCDF files found for WMO range 500000-549999.\n", + "Starting download for WMO_500000-549999.tar...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "274ec4d5674b4781a8657b2c98db1d66", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading WMO_500000-549999.tar: 0%| | 0.00/411M [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 327GB\n",
+       "Dimensions:                (station: 770, time: 473352, flagged: 19, test: 71,\n",
+       "                            reporting_v: 19, reporting_t: 1116, reporting_2: 2)\n",
+       "Coordinates:\n",
+       "    elevation              (station) float64 6kB 1.676e+03 401.0 ... 117.7 36.0\n",
+       "    latitude               (station) float64 6kB 38.47 42.33 ... 45.93 30.48\n",
+       "    longitude              (station) float64 6kB 102.2 120.7 ... 126.6 -87.19\n",
+       "  * station                (station) object 6kB '526740-99999' ... '722223-13...\n",
+       "  * time                   (time) datetime64[ns] 4MB 1970-01-01 ... 2023-12-3...\n",
+       "Dimensions without coordinates: flagged, test, reporting_v, reporting_t,\n",
+       "                                reporting_2\n",
+       "Data variables: (12/25)\n",
+       "    cloud_base             (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    dewpoints              (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    flagged_obs            (station, time, flagged) float64 55GB dask.array<chunksize=(1, 29585, 3), meta=np.ndarray>\n",
+       "    high_cloud_cover       (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    low_cloud_cover        (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    mid_cloud_cover        (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    ...                     ...\n",
+       "    stnlp                  (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    temperatures           (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    total_cloud_cover      (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    wind_gust              (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    winddirs               (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "    windspeeds             (station, time) float64 3GB dask.array<chunksize=(1, 59169), meta=np.ndarray>\n",
+       "Attributes: (12/39)\n",
+       "    Conventions:                 CF-1.6\n",
+       "    Metadata_Conventions:        Unidata Dataset Discovery v1.0, CF Discrete ...\n",
+       "    acknowledgement:             RJHD was supported by the Joint BEIS/Defra M...\n",
+       "    cdm_data_type:               station\n",
+       "    creator_email:               robert.dunn@metoffice.gov.uk\n",
+       "    creator_name:                Robert Dunn\n",
+       "    ...                          ...\n",
+       "    station_id:                  526740-99999\n",
+       "    station_information:         Where station is a composite the station id ...\n",
+       "    summary:                     Quality-controlled, sub-daily, station datas...\n",
+       "    time_coverage_end:           2002-05-13T09:00Z\n",
+       "    time_coverage_start:         1964-01-01T00:00Z\n",
+       "    title:                       HadISD
" + ], + "text/plain": [ + " Size: 327GB\n", + "Dimensions: (station: 770, time: 473352, flagged: 19, test: 71,\n", + " reporting_v: 19, reporting_t: 1116, reporting_2: 2)\n", + "Coordinates:\n", + " elevation (station) float64 6kB 1.676e+03 401.0 ... 117.7 36.0\n", + " latitude (station) float64 6kB 38.47 42.33 ... 45.93 30.48\n", + " longitude (station) float64 6kB 102.2 120.7 ... 126.6 -87.19\n", + " * station (station) object 6kB '526740-99999' ... '722223-13...\n", + " * time (time) datetime64[ns] 4MB 1970-01-01 ... 2023-12-3...\n", + "Dimensions without coordinates: flagged, test, reporting_v, reporting_t,\n", + " reporting_2\n", + "Data variables: (12/25)\n", + " cloud_base (station, time) float64 3GB dask.array\n", + " dewpoints (station, time) float64 3GB dask.array\n", + " flagged_obs (station, time, flagged) float64 55GB dask.array\n", + " high_cloud_cover (station, time) float64 3GB dask.array\n", + " low_cloud_cover (station, time) float64 3GB dask.array\n", + " mid_cloud_cover (station, time) float64 3GB dask.array\n", + " ... ...\n", + " stnlp (station, time) float64 3GB dask.array\n", + " temperatures (station, time) float64 3GB dask.array\n", + " total_cloud_cover (station, time) float64 3GB dask.array\n", + " wind_gust (station, time) float64 3GB dask.array\n", + " winddirs (station, time) float64 3GB dask.array\n", + " windspeeds (station, time) float64 3GB dask.array\n", + "Attributes: (12/39)\n", + " Conventions: CF-1.6\n", + " Metadata_Conventions: Unidata Dataset Discovery v1.0, CF Discrete ...\n", + " acknowledgement: RJHD was supported by the Joint BEIS/Defra M...\n", + " cdm_data_type: station\n", + " creator_email: robert.dunn@metoffice.gov.uk\n", + " creator_name: Robert Dunn\n", + " ... 
...\n", + " station_id: 526740-99999\n", + " station_information: Where station is a composite the station id ...\n", + " summary: Quality-controlled, sub-daily, station datas...\n", + " time_coverage_end: 2002-05-13T09:00Z\n", + " time_coverage_start: 1964-01-01T00:00Z\n", + " title: HadISD" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "ds_combined = load_combined_dataset(zarr_output_dir)\n", "ds_combined" ] - }, - { - "cell_type": "markdown", - "id": "d448a05f", - "metadata": {}, - "source": [ - "# Data Organization: NetCDF and Zarr stores\n", - "\n", - "To keep your workflow clear and reproducible, we recommend storing both the raw NetCDF files and the processed Zarr data in separate subfolders inside your main WMO directory. For example:\n", - "\n", - "- `HadISD_data/WMO_080000-099999/netcdf/` (raw NetCDF files)\n", - "- `HadISD_data/WMO_080000-099999/zarr/` (processed Zarr stores with harmonized time coordinates)\n", - "\n", - "This makes it obvious which data is raw and which is ready for fast, parallel analysis." 
- ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "pyearthtools", "language": "python", "name": "python3" }, @@ -312,7 +3812,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/notebooks/tutorial/HadISD/3_HadISD_XGBoost_Pipeline.ipynb b/notebooks/tutorial/HadISD/3_HadISD_XGBoost_Pipeline.ipynb index 8d2cf388..fc7dc079 100644 --- a/notebooks/tutorial/HadISD/3_HadISD_XGBoost_Pipeline.ipynb +++ b/notebooks/tutorial/HadISD/3_HadISD_XGBoost_Pipeline.ipynb @@ -19,8 +19,9 @@ "outputs": [], "source": [ "import numpy as np\n", - "import pandas as pd\n", - "from pathlib import Path\n", + "import matplotlib.pyplot as plt\n", + "from xgboost import XGBClassifier\n", + "from sklearn.metrics import classification_report, confusion_matrix\n", "\n", "import pyearthtools.pipeline as petpipe\n", "import pyearthtools.data as petdata\n", @@ -36,10 +37,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ + "# ruff: noqa: F821\n", "%run Pipeline_Config.ipynb" ] }, @@ -53,9 +55,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of stations: 770\n" + ] + } + ], "source": [ "hadisd = HadISDIndex()\n", "all_stations = hadisd.get_all_station_ids()\n", @@ -65,13 +75,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "List of first ten stations: ['501360-99999', '502460-99999', '503530-99999', '504340-99999', '504420-99999', '504680-99999', '505270-99999', '505480-99999', '505570-99999', '505640-99999']\n" + ] + } + ], "source": [ "# Select first n stations\n", 
"first_ten_stations = all_stations_ordered[:10]\n", - "print(f\"List of first ten stations:\", first_ten_stations)" + "print(\"List of first ten stations:\", first_ten_stations)" ] }, { @@ -90,16 +108,721 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Pipeline\n",
+       "\tDescription                    `pyearthtools.pipeline` Data Pipeline\n",
+       "\n",
+       "\n",
+       "\tInitialisation                 \n",
+       "\t\t exceptions_to_ignore           None\n",
+       "\t\t iterator                       None\n",
+       "\t\t sampler                        None\n",
+       "\tSteps                          \n",
+       "\t\t HadisdDataClass.HadISDIndex    {'HadISDIndex': {'station': "['501360-99999', '502460-99999', '503530-99999', '504340-99999', '504420-99999', '504680-99999', '505270-99999', '505480-99999', '505570-99999', '505640-99999']", 'variables': 'None'}}\n",
+       "\t\t values.AddFlaggedObs           {'AddFlaggedObs': {'flagged_labels': "['temperatures', 'dewpoints', 'slp', 'stnlp' ... 'precip12_depth', 'precip15_depth', 'precip18_depth', 'precip24_depth']"}}\n",
+       "\t\t values.SetMissingToNaN         {'SetMissingToNaN': {'varname_val_map': {'total_cloud_cover': '-999.0', 'low_cloud_cover': '-999.0', 'mid_cloud_cover': '-999.0', 'high_cloud_cover': '-999.0', 'winddirs': '-999.0'}}}\n",
+       "\t\t variables.Drop                 {'Drop': {'__args': '()', 'variables': "'flagged_obs'"}}\n",
+       "\t\t coordinates.Drop               {'Drop': {'__args': '()', 'coordinates': "['latitude', 'longitude', 'elevation']", 'ignore_missing': 'False'}}\n",
+       "\t\t __main__.TrainTestSplit        {'TrainTestSplit': {}}\n",
+       "\t\t __main__.FeatureTargetSplit    {'FeatureTargetSplit': {}}\n",
+       "\t\t conversion.ToNumpy             {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t __main__.MedianImputePerStation {'MedianImputePerStation': {}}\n",
+       "\t\t __main__.TransposeAndFlattenX  {'TransposeAndFlattenX': {}}\n",
+       "\t\t conversion.ToNumpy[1]          {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t __main__.SelectAndFlattenY     {'SelectAndFlattenY': {}}\n",
+       "\t\t __main__.FilterNaNSamples      {'FilterNaNSamples': {}}\n",
+       "\t\t __main__.FeatureTargetSplit[1] {'FeatureTargetSplit': {}}\n",
+       "\t\t conversion.ToNumpy[2]          {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t __main__.MedianImputePerStation[1] {'MedianImputePerStation': {}}\n",
+       "\t\t __main__.TransposeAndFlattenX[1] {'TransposeAndFlattenX': {}}\n",
+       "\t\t conversion.ToNumpy[3]          {'ToNumpy': {'reference_dataset': 'None', 'run_parallel': 'False', 'saved_records': 'None', 'warn': 'True'}}\n",
+       "\t\t __main__.SelectAndFlattenY[1]  {'SelectAndFlattenY': {}}\n",
+       "\t\t __main__.FilterNaNSamples[1]   {'FilterNaNSamples': {}}
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Graph

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "HadISDIndex_8fcd7115-0d04-46ad-930b-ca03e14913a3\n", + "\n", + "HadisdDataClass.HadISDIndex\n", + "\n", + "\n", + "\n", + "AddFlaggedObs_9b13e396-f451-47e2-bf2a-cd1fde1498a6\n", + "\n", + "values.AddFlaggedObs\n", + "\n", + "\n", + "\n", + "HadISDIndex_8fcd7115-0d04-46ad-930b-ca03e14913a3->AddFlaggedObs_9b13e396-f451-47e2-bf2a-cd1fde1498a6\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "SetMissingToNaN_24f2948c-1801-416a-9800-63a82673caf7\n", + "\n", + "values.SetMissingToNaN\n", + "\n", + "\n", + "\n", + "AddFlaggedObs_9b13e396-f451-47e2-bf2a-cd1fde1498a6->SetMissingToNaN_24f2948c-1801-416a-9800-63a82673caf7\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Drop_3fcf5737-43ba-4619-b446-dd0ca4cd8f4a\n", + "\n", + "variables.Drop\n", + "\n", + "\n", + "\n", + "SetMissingToNaN_24f2948c-1801-416a-9800-63a82673caf7->Drop_3fcf5737-43ba-4619-b446-dd0ca4cd8f4a\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Drop_3bb14fd5-7ee9-494f-966c-352afdf69348\n", + "\n", + "coordinates.Drop\n", + "\n", + "\n", + "\n", + "Drop_3fcf5737-43ba-4619-b446-dd0ca4cd8f4a->Drop_3bb14fd5-7ee9-494f-966c-352afdf69348\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TrainTestSplit_efda0281-ace4-45f9-97b6-c5a90ab7cef1\n", + "\n", + "__main__.TrainTestSplit\n", + "\n", + "\n", + "\n", + "Drop_3bb14fd5-7ee9-494f-966c-352afdf69348->TrainTestSplit_efda0281-ace4-45f9-97b6-c5a90ab7cef1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "FeatureTargetSplit_18774330-4f23-463d-801a-a66de3b568a1\n", + "\n", + "__main__.FeatureTargetSplit\n", + "\n", + "\n", + "\n", + "TrainTestSplit_efda0281-ace4-45f9-97b6-c5a90ab7cef1->FeatureTargetSplit_18774330-4f23-463d-801a-a66de3b568a1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "FeatureTargetSplit_891d48df-c84a-4c0f-b14f-1a5f1707d584\n", + "\n", + 
"__main__.FeatureTargetSplit\n", + "\n", + "\n", + "\n", + "TrainTestSplit_efda0281-ace4-45f9-97b6-c5a90ab7cef1->FeatureTargetSplit_891d48df-c84a-4c0f-b14f-1a5f1707d584\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_607494ac-d92e-463b-936a-ce83b20b6187\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "FeatureTargetSplit_18774330-4f23-463d-801a-a66de3b568a1->ToNumpy_607494ac-d92e-463b-936a-ce83b20b6187\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_ca59c6d8-f1c6-4b20-9ef2-acc334da1ad5\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "FeatureTargetSplit_18774330-4f23-463d-801a-a66de3b568a1->ToNumpy_ca59c6d8-f1c6-4b20-9ef2-acc334da1ad5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "MedianImputePerStation_bd3a0a4a-f8eb-4f00-9a2c-314c013ac8e0\n", + "\n", + "__main__.MedianImputePerStation\n", + "\n", + "\n", + "\n", + "ToNumpy_607494ac-d92e-463b-936a-ce83b20b6187->MedianImputePerStation_bd3a0a4a-f8eb-4f00-9a2c-314c013ac8e0\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TransposeAndFlattenX_fd90c437-479d-49fe-a223-54429c389000\n", + "\n", + "__main__.TransposeAndFlattenX\n", + "\n", + "\n", + "\n", + "MedianImputePerStation_bd3a0a4a-f8eb-4f00-9a2c-314c013ac8e0->TransposeAndFlattenX_fd90c437-479d-49fe-a223-54429c389000\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "FilterNaNSamples_82e45d93-8672-4efd-bd48-7e2aaf5fb7b4\n", + "\n", + "__main__.FilterNaNSamples\n", + "\n", + "\n", + "\n", + "TransposeAndFlattenX_fd90c437-479d-49fe-a223-54429c389000->FilterNaNSamples_82e45d93-8672-4efd-bd48-7e2aaf5fb7b4\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "SelectAndFlattenY_2ad4a151-252a-47cf-8d3d-e09727602866\n", + "\n", + "__main__.SelectAndFlattenY\n", + "\n", + "\n", + "\n", + "ToNumpy_ca59c6d8-f1c6-4b20-9ef2-acc334da1ad5->SelectAndFlattenY_2ad4a151-252a-47cf-8d3d-e09727602866\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"SelectAndFlattenY_2ad4a151-252a-47cf-8d3d-e09727602866->FilterNaNSamples_82e45d93-8672-4efd-bd48-7e2aaf5fb7b4\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_036f6dbc-cd85-4f8c-bf7d-23e4a3f4d971\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "FeatureTargetSplit_891d48df-c84a-4c0f-b14f-1a5f1707d584->ToNumpy_036f6dbc-cd85-4f8c-bf7d-23e4a3f4d971\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "ToNumpy_e4b5574b-2dff-4fe9-8edc-9dc90d46fa50\n", + "\n", + "conversion.ToNumpy\n", + "\n", + "\n", + "\n", + "FeatureTargetSplit_891d48df-c84a-4c0f-b14f-1a5f1707d584->ToNumpy_e4b5574b-2dff-4fe9-8edc-9dc90d46fa50\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "MedianImputePerStation_876cb65b-1c97-4b0a-855d-607bf960730c\n", + "\n", + "__main__.MedianImputePerStation\n", + "\n", + "\n", + "\n", + "ToNumpy_036f6dbc-cd85-4f8c-bf7d-23e4a3f4d971->MedianImputePerStation_876cb65b-1c97-4b0a-855d-607bf960730c\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "TransposeAndFlattenX_6bcf84c0-b50d-4f66-b725-a8aa32ce6fd0\n", + "\n", + "__main__.TransposeAndFlattenX\n", + "\n", + "\n", + "\n", + "MedianImputePerStation_876cb65b-1c97-4b0a-855d-607bf960730c->TransposeAndFlattenX_6bcf84c0-b50d-4f66-b725-a8aa32ce6fd0\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "FilterNaNSamples_c58d7fad-bd15-437d-a663-b7209d458754\n", + "\n", + "__main__.FilterNaNSamples\n", + "\n", + "\n", + "\n", + "TransposeAndFlattenX_6bcf84c0-b50d-4f66-b725-a8aa32ce6fd0->FilterNaNSamples_c58d7fad-bd15-437d-a663-b7209d458754\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "SelectAndFlattenY_69d7fd56-8e65-4603-9b90-314f0b76a9f5\n", + "\n", + "__main__.SelectAndFlattenY\n", + "\n", + "\n", + "\n", + "ToNumpy_e4b5574b-2dff-4fe9-8edc-9dc90d46fa50->SelectAndFlattenY_69d7fd56-8e65-4603-9b90-314f0b76a9f5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "SelectAndFlattenY_69d7fd56-8e65-4603-9b90-314f0b76a9f5->FilterNaNSamples_c58d7fad-bd15-437d-a663-b7209d458754\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + 
"text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "data_prep_pipe = petpipe.Pipeline(\n", " petdata.archive.hadisd(station = first_ten_stations), # Use all stations\n", - " SqueezeStationCoordinates(),\n", " petdata.transforms.values.AddFlaggedObs(flagged_labels),\n", " petdata.transforms.values.SetMissingToNaN(varname_val_map),\n", " petdata.transforms.variables.Drop(\"flagged_obs\"),\n", + " petdata.transforms.coordinates.Drop(['latitude', 'longitude', 'elevation']),\n", " TrainTestSplit(test_size=0.2, random_state=42, dim='time'), # returns (ds_train, ds_test)\n", " (\n", " # branch 1: train\n", @@ -114,7 +837,7 @@ " ),\n", " # branch y_train\n", " (\n", - " petpipe.operations.xarray.conversion.ToNumpy(),\n", + " petpipe.operations.xarray.conversion.ToNumpy(),\n", " SelectAndFlattenY(test_number=33),\n", " ),\n", " 'map'\n", @@ -148,9 +871,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/joelmiller/Projects/Python/PyEarthTools_Fork/PyEarthTools/packages/data/src/pyearthtools/data/indexes/_indexes.py:487: IndexWarning: Could not find time in dataset to select on. 
Petdt('1969-01-01T00')\n", + " warnings.warn(\n" + ] + } + ], "source": [ "# Select data based on date\n", "(X_train, y_train),(X_test, y_test) = data_prep_pipe[\"1969-01-01T00\"] # Curretnly we have to select a date outside of the date range to select the entire dataset (fix coming soon)" @@ -158,9 +890,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X_train shape: (920121, 22)\n", + "y_train shape: (920121,)\n", + "X_test shape: (232093, 22)\n", + "y_test shape: (232093,)\n" + ] + } + ], "source": [ "# Print the shapes of the resulting arrays\n", "print(f\"X_train shape: {X_train.shape}\")\n", @@ -178,12 +921,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/joelmiller/miniconda3/envs/pyearthtools/lib/python3.13/site-packages/xgboost/training.py:183: UserWarning: [17:41:03] WARNING: /Users/runner/work/xgboost/xgboost/src/learner.cc:738: \n", + "Parameters: { \"use_label_encoder\" } are not used.\n", + "\n", + " bst.update(dtrain, iteration=i, fobj=obj)\n" + ] + } + ], "source": [ - "from xgboost import XGBClassifier\n", - "\n", "# Calculate scale_pos_weight for class imbalance\n", "scale_pos_weight = (len(y_train) - np.sum(y_train)) / np.sum(y_train)\n", "#scale_pos_weight = num_zeros / num_ones \n", @@ -213,25 +965,71 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0.0 0.99 0.97 0.98 223592\n", + " 1.0 0.53 0.76 0.63 8501\n", + "\n", + " accuracy 0.97 232093\n", + " macro avg 0.76 0.87 0.80 232093\n", + "weighted avg 0.97 0.97 0.97 232093\n", + "\n", + "[[217897 5695]\n", 
+ " [ 2013 6488]]\n" + ] + } + ], "source": [ "# compare predictions with true labels\n", - "from sklearn.metrics import classification_report, confusion_matrix\n", "print(classification_report(y_test, y_pred))\n", - "print(confusion_matrix(y_test, y_pred)) \n", - "\n", - "# Plot confusion matrix\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", + "print(confusion_matrix(y_test, y_pred)) " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot confusion matrix using only matplotlib\n", "def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues):\n", " plt.figure(figsize=(8, 6))\n", - " sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,\n", - " xticklabels=classes, yticklabels=classes)\n", + " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", " plt.title(title)\n", + " plt.colorbar()\n", + " tick_marks = np.arange(len(classes))\n", + " plt.xticks(tick_marks, classes)\n", + " plt.yticks(tick_marks, classes)\n", + "\n", + " # Annotate cells with counts\n", + " thresh = cm.max() / 2.\n", + " for i in range(cm.shape[0]):\n", + " for j in range(cm.shape[1]):\n", + " plt.text(j, i, format(cm[i, j], 'd'),\n", + " ha=\"center\", va=\"center\",\n", + " color=\"white\" if cm[i, j] > thresh else \"black\")\n", + "\n", " plt.ylabel('True label')\n", " plt.xlabel('Predicted label')\n", + " plt.tight_layout()\n", " plt.show()\n", "\n", "cm = confusion_matrix(y_test, y_pred)\n", @@ -241,7 +1039,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "pyearthtools", "language": "python", "name": "python3" }, @@ -255,7 +1053,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/notebooks/tutorial/HadISD/Data_Config.ipynb b/notebooks/tutorial/HadISD/Data_Config.ipynb index db28f31b..ad00c678 100644 --- a/notebooks/tutorial/HadISD/Data_Config.ipynb +++ b/notebooks/tutorial/HadISD/Data_Config.ipynb @@ -57,27 +57,42 @@ "metadata": {}, "outputs": [], "source": [ - "# A sample list of WMO number ranges. 
Users can find more at the official HadISD download page.\n", + "# For any station ranges you don't want to download, you can comment them out here\n", "wmo_id_ranges = [\n", - " \"000000-029999\",\n", - " \"080000-099999\",\n", - " \"200000-249999\",\n", - " \"720000-721999\",\n", + " #\"000000-029999\",\n", + " #\"030000-049999\",\n", + " #\"050000-079999\",\n", + " #\"080000-099999\",\n", + " #\"100000-149999\",\n", + " #\"150000-199999\",\n", + " #\"200000-249999\",\n", + " #\"250000-299999\",\n", + " #\"300000-349999\",\n", + " #\"350000-399999\",\n", + " #\"400000-449999\",\n", + " #\"450000-499999\",\n", + " \"500000-549999\",\n", + " #\"550000-599999\",\n", + " #\"600000-649999\",\n", + " #\"650000-699999\",\n", + " #\"700000-709999\",\n", + " #\"710000-714999\",\n", + " #\"715000-719999\",\n", + " #\"720000-721999\",\n", + " \"722000-722999\",\n", + " #\"723000-723999\",\n", + " #\"724000-724999\",\n", + " #\"725000-725999\",\n", + " #\"726000-726999\",\n", + " #\"727000-729999\",\n", + " #\"730000-799999\",\n", + " \"800000-849999\",\n", + " #\"850000-899999\",\n", + " #\"900000-949999\",\n", + " #\"950000-999999\",\n", "]" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "35617ad2", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "# User sets the WMO number range to download\n", - "wmo_id_range = \"080000-099999\" # Change this to the desired WMO range, either from the sample list or from the HadISD page." 
- ] - }, { "cell_type": "markdown", "id": "7aec321a", @@ -97,15 +112,15 @@ "# Set the date range to reindex the time coordinate\n", "DATE_RANGE = (\"1970-01-01T00\", \"2023-12-31T23\")\n", "# Set the input directory to the folder with raw NetCDFs\n", - "input_dir = download_dir / f\"WMO_{wmo_id_range}\" / \"netcdf\"\n", + "input_dir = download_dir / \"netcdf\"\n", "# Set the Zarr output directory to a sibling folder under the same WMO directory\n", - "zarr_output_dir = download_dir / f\"WMO_{wmo_id_range}\" / \"zarr\"" + "zarr_output_dir = download_dir / \"zarr\"" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "pet_tutorial", "language": "python", "name": "python3" }, @@ -119,7 +134,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/notebooks/tutorial/HadISD/HadISD_QC_Exploration.ipynb b/notebooks/tutorial/HadISD/HadISD_QC_Exploration.ipynb index c26523fa..0b3694cf 100644 --- a/notebooks/tutorial/HadISD/HadISD_QC_Exploration.ipynb +++ b/notebooks/tutorial/HadISD/HadISD_QC_Exploration.ipynb @@ -15,10 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "import datetime\n", "import numpy as np\n", - "import pandas as pd\n", - "from pathlib import Path\n", "\n", "import pyearthtools.pipeline as petpipe\n", "import pyearthtools.data as petdata\n", @@ -27,12 +24,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "4a7da841", "metadata": {}, "outputs": [], "source": [ - "# %run HadISD_config.ipynb" + "# ruff: noqa: F821\n", + "%run Pipeline_Config.ipynb" ] }, { @@ -89,17 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "y = data_prep_pipe[\"1969-01-01T07\"]\n", - "y" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d3819b9", - "metadata": {}, - "outputs": [], - "source": [ + "x = data_prep_pipe[\"1969-01-01T07\"]\n", "x" ] }, @@ -110,7 +98,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "qc = y[\"quality_control_flags\"].values\n" + "qc = x[\"quality_control_flags\"].values\n" ] }, { @@ -214,21 +202,13 @@ "# for qc, test 12, time 826, station 0, print the value of the test\n", "print(\"QC value for test 12, time 826, station 0:\", qc[0, 826, 12])" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8c8f1bc9", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "pyearthtools", "language": "python", - "name": "python3" + "name": "pyearthtools" }, "language_info": { "codemirror_mode": { @@ -240,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/notebooks/tutorial/HadISD/Pipeline_Config.ipynb b/notebooks/tutorial/HadISD/Pipeline_Config.ipynb index e0dcc486..4b5298b3 100644 --- a/notebooks/tutorial/HadISD/Pipeline_Config.ipynb +++ b/notebooks/tutorial/HadISD/Pipeline_Config.ipynb @@ -2,12 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "1abab3c3", "metadata": {}, "outputs": [], "source": [ - "# import pyearthtools.pipeline as petpipe" + "import pyearthtools.pipeline as petpipe" ] }, { @@ -15,7 +15,7 @@ "id": "2307aa00", "metadata": {}, "source": [ - "# Lists and Dictionaries Required For Sustom Pipeline Steps" + "# Lists and Dictionaries Required For Custom Pipeline Steps" ] }, { @@ -62,6 +62,16 @@ "source": [ "# Custom operation to remove redundent coordinates\n", "class SqueezeStationCoordinates(petpipe.Operation):\n", + " \"\"\"\n", + " Squeeze singleton dimensions from specified station-based coordinates in an xarray.Dataset.\n", + "\n", + " This operation is useful for removing unnecessary singleton dimensions (e.g., shape (n, 1))\n", + " from coordinates like latitude, longitude, and elevation, ensuring they are 1D and 
indexed\n", + " by 'station'.\n", + "\n", + " Args:\n", + " coords (tuple of str): Names of coordinates to squeeze. Defaults to (\"latitude\", \"longitude\", \"elevation\").\n", + " \"\"\"\n", " def __init__(self, coords=(\"latitude\", \"longitude\", \"elevation\")):\n", " super().__init__()\n", " self.coords = coords\n", @@ -74,7 +84,7 @@ " # Undo function added otherwise pyearthtools will complain\n", " def undo_func(self, ds):\n", " # No undo operation needed for this operation\n", - " return ds" + " return ds\n" ] }, { @@ -314,8 +324,6 @@ "metadata": {}, "outputs": [], "source": [ - "import xarray as xr\n", - "import numpy as np\n", "from pyearthtools.data.transforms.values import AddFlaggedObs\n", "\n", "def test_add_flagged_obs():\n", diff --git a/packages/data/src/pyearthtools/data/indexes/_indexes.py b/packages/data/src/pyearthtools/data/indexes/_indexes.py index d08abd03..b4c358ff 100644 --- a/packages/data/src/pyearthtools/data/indexes/_indexes.py +++ b/packages/data/src/pyearthtools/data/indexes/_indexes.py @@ -66,7 +66,7 @@ LOG = logging.getLogger("pyearthtools.data") - + class Index(CallRedirectMixin, CatalogMixin, metaclass=ABCMeta): """ Base Level Index to define the structure diff --git a/packages/data/src/pyearthtools/data/indexes/utilities/fileload.py b/packages/data/src/pyearthtools/data/indexes/utilities/fileload.py index 2ed6f9b9..eee5cf2f 100644 --- a/packages/data/src/pyearthtools/data/indexes/utilities/fileload.py +++ b/packages/data/src/pyearthtools/data/indexes/utilities/fileload.py @@ -83,7 +83,7 @@ def get_config(mf: bool = False): return xr.open_mfdataset( filter_files(location), decode_timedelta=True, # TODO: should we raise a warning? It seems to be required for almost all our data. 
- compat='override', + compat="override", **get_config(True), ) diff --git a/packages/data/src/pyearthtools/data/transforms/variables.py b/packages/data/src/pyearthtools/data/transforms/variables.py index 93d0117e..fec0253f 100644 --- a/packages/data/src/pyearthtools/data/transforms/variables.py +++ b/packages/data/src/pyearthtools/data/transforms/variables.py @@ -87,11 +87,20 @@ def apply(self, dataset: xr.Dataset) -> xr.Dataset: if self._variables is None: return dataset - var_included = set(dataset.data_vars).difference(set(self._variables)) + # 3/9/2025 - old logic was replaced with a simple drop of the variables + # A new issue will be raised to review how coordinate protection should + # work because people need a way to drop coords when needed. - if not var_included: - return dataset - return dataset[var_included] + # Calculate the difference between the data variables on the dataset + # and the variables requested for drop. This leaves coordinate variables + # unaffected + # var_included = set(dataset.data_vars).difference(set(self._variables)) + + # if not var_included: + # return dataset + # return dataset[var_included] + + return dataset.drop_vars(self._variables) class Select(Transform): diff --git a/packages/data/tests/indexes/test_indexes.py b/packages/data/tests/indexes/test_indexes.py index b8554205..14509f93 100644 --- a/packages/data/tests/indexes/test_indexes.py +++ b/packages/data/tests/indexes/test_indexes.py @@ -7,6 +7,7 @@ from pyearthtools.data.exceptions import DataNotFoundError + def test_Index(monkeypatch): monkeypatch.setattr("pyearthtools.data.indexes.Index.__abstractmethods__", set()) diff --git a/packages/nci_site_archive/tests/test_radar_proj.py b/packages/nci_site_archive/tests/test_radar_proj.py index 6a75f080..545323f2 100644 --- a/packages/nci_site_archive/tests/test_radar_proj.py +++ b/packages/nci_site_archive/tests/test_radar_proj.py @@ -24,13 +24,18 @@ import platform import xarray as xr -from site_archive_nci._Rainfields3 
import ( - ErrorRadarProj, - ProjErrorStatus, - ProjKind, - RadarProj, - WarnRadarProj, -) +try: + from site_archive_nci._Rainfields3 import ( + ErrorRadarProj, + ProjErrorStatus, + ProjKind, + RadarProj, + WarnRadarProj, + ) + +except ImportError: + pytest.skip(allow_module_level=True) + PYPROJ_SAMPLE = pyproj.Proj("+proj=aea +lat_1=-36 +lat_2=-18 +lon_0=132 +units=m") EXPECTED_KEYS = [ diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py index c4146509..b0e83742 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py @@ -46,15 +46,15 @@ def unjoin(self, sample: Any) -> tuple: class LatLonInterpolate(Joiner): - ''' + """ Makes additional assumptions about how interpolation should work and how the data is structured. In this case, interpolation is primarily - expected to occur according to latitude and longitude, performing + expected to occur according to latitude and longitude, performing no broadcasting, and iterating over other dimensions instead. It assumed the dimensions 'latitude', 'longitude', 'time', and 'level' will be present. 'lat' or 'lon' may also be used for convenience. 
- ''' + """ _override_interface = "Serial" @@ -78,15 +78,15 @@ def __init__( self._merge_kwargs = merge_kwargs def raise_if_dimensions_wrong(self, dataset): - ''' + """ Raise exceptions if the supplied dataset does not meet requirements - ''' + """ - if not hasattr(self, 'required_dims'): - if 'lat' in dataset.coords: - self.required_dims = ['lat', 'lon'] + if not hasattr(self, "required_dims"): + if "lat" in dataset.coords: + self.required_dims = ["lat", "lon"] else: - self.required_dims = ['latitude', 'longitude'] + self.required_dims = ["latitude", "longitude"] present_in_coords = [d in dataset.coords for d in self.required_dims] if not all(present_in_coords): @@ -100,12 +100,12 @@ def raise_if_dimensions_wrong(self, dataset): # raise ValueError(f"Cannot perform a GeoMergePancake on the data variables {data_var} without {self.required_dims} as a dimension") def maybe_interp(self, ds): - ''' + """ This method will only interpolate the datasets if the latitudes and longitudes don't already match. This means, for example, you can't use it to interpolate between time steps or vertical levels alone. The primary purpose here is lat/lon interpolation, not general model interpolation or arbitrarily-dimensioned data interpolation. - ''' + """ ds_coords_ok = [ds[coord].equals(self.reference_dataset[coord]) for coord in self.required_dims] @@ -115,7 +115,6 @@ def maybe_interp(self, ds): return ds - def _join_two_datasets(self, sample_a: xr.Dataset, sample_b: xr.Dataset) -> xr.Dataset: """ Used to reduce a sequence of joinable items. Only called by the public interface join method. 
@@ -144,7 +143,7 @@ def join(self, sample: tuple[Union[xr.Dataset, xr.DataArray], ...]) -> xr.Datase return merged def unjoin(self, sample: Any) -> tuple: - raise NotImplementedError("Not Implemented") + raise NotImplementedError("Not Implemented") class GeospatialTimeSeriesMerge(Joiner): diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py index 3496dd79..68fe514f 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/normalisation.py @@ -157,10 +157,10 @@ class Deviation(xarrayNormalisation): """Deviation Normalisation""" def __init__( - self, - mean: FILE | xr.Dataset | xr.DataArray | float, + self, + mean: FILE | xr.Dataset | xr.DataArray | float, deviation: FILE | xr.Dataset | xr.DataArray | float, - debug=False + debug=False, ): """ Each argument take take a Dataset, DataArray, float or file object. 
@@ -173,7 +173,9 @@ def __init__( self.record_initialisation() if debug: - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() if isinstance(mean, xr.Dataset): self.mean = mean @@ -187,7 +189,7 @@ def __init__( if isinstance(deviation, xr.Dataset): self.deviation = deviation elif isinstance(deviation, xr.DataArray): - self.deviation = deviation + self.deviation = deviation elif isinstance(deviation, float): self.deviation = deviation else: diff --git a/packages/tutorial/src/pyearthtools/tutorial/HadisdDataClass.py b/packages/tutorial/src/pyearthtools/tutorial/HadisdDataClass.py index 1fba6cf9..0a5921ba 100644 --- a/packages/tutorial/src/pyearthtools/tutorial/HadisdDataClass.py +++ b/packages/tutorial/src/pyearthtools/tutorial/HadisdDataClass.py @@ -136,62 +136,32 @@ def __init__( self.record_initialisation() - # def get_all_station_ids(self, root_directory: Path | str) -> list[str]: - # """ - # Retrieve all station IDs by scanning the dataset directory. - - # Args: - # root_directory (Path | str): The root directory containing station data. - - # Returns: - # list[str]: A list of all station IDs. - # """ - # root_directory = Path(root_directory) - # station_ids = [] - # for folder in cached_iterdir(root_directory): - # if folder.is_dir(): - # for file in cached_iterdir(folder): - # if file.suffix == ".nc": # Check for NetCDF files - # # Extract the station ID from the filename - # station_id = file.stem.split("_")[-1] # Assuming station ID is the last part of the filename - # station_ids.append(station_id) - # return station_ids - def get_all_station_ids(self, root_directory: Path | str = None) -> list[str]: """ - Retrieve all station IDs by scanning the dataset directory. + Retrieve all station IDs by scanning the Zarr directory. Args: - root_directory (Path | str, optional): The root directory containing station data. - Defaults to HADISD_HOME/netcdf. + root_directory (Path | str, optional): The directory containing Zarr files. 
+ Defaults to HADISD_HOME/zarr. Returns: list[str]: A list of all station IDs. """ - HADISD_HOME = self.ROOT_DIRECTORIES["hadisd"] if root_directory is None: - # Search all WMO folders for netcdf subfolders - wmo_folders = [f for f in Path(HADISD_HOME).iterdir() if f.is_dir() and f.name.startswith("WMO_")] - station_ids = [] - for wmo_folder in wmo_folders: - netcdf_dir = wmo_folder / "netcdf" - if cached_exists(netcdf_dir): - for file in cached_iterdir(netcdf_dir): - if file.suffix == ".nc": - station_id = file.stem.split("_")[-1] - station_ids.append(station_id) - return station_ids + zarr_dir = Path(HADISD_HOME) / "zarr" else: - root_directory = Path(root_directory) - if not cached_exists(root_directory): - raise DataNotFoundError(f"Root directory does not exist: {root_directory}") - station_ids = [] - for file in cached_iterdir(root_directory): - if file.suffix == ".nc": - station_id = file.stem.split("_")[-1] - station_ids.append(station_id) - return station_ids + zarr_dir = Path(root_directory) + + if not cached_exists(zarr_dir): + raise DataNotFoundError(f"Zarr directory does not exist: {zarr_dir}") + + station_ids = [] + for file in cached_iterdir(zarr_dir): + if file.suffix == ".zarr": + station_id = file.stem.split("_")[-1] + station_ids.append(station_id) + return station_ids def filesystem(self, *args, date_range=("1970-01-01T00", "2023-12-31T23"), **kwargs) -> dict[str, Path]: """ @@ -222,66 +192,18 @@ def filesystem(self, *args, date_range=("1970-01-01T00", "2023-12-31T23"), **kwa if not isinstance(station_ids, list) or not all(isinstance(sid, str) for sid in station_ids): raise TypeError(f"Expected station_ids to be a str or list[str], but got: {type(station_ids)}") - # Define the station ranges and corresponding folders - STATION_RANGES = [ - (0, 29999, "WMO_000000-029999"), - (30000, 49999, "WMO_030000-049999"), - (50000, 79999, "WMO_050000-079999"), - (80000, 99999, "WMO_080000-099999"), - (100000, 149999, "WMO_100000-149999"), - (150000, 199999, 
"WMO_150000-199999"), - (200000, 249999, "WMO_200000-249999"), - (250000, 299999, "WMO_250000-299999"), - (300000, 349999, "WMO_300000-349999"), - (350000, 399999, "WMO_350000-399999"), - (400000, 449999, "WMO_400000-449999"), - (450000, 499999, "WMO_450000-499999"), - (500000, 549999, "WMO_500000-549999"), - (550000, 599999, "WMO_550000-599999"), - (600000, 649999, "WMO_600000-649999"), - (650000, 699999, "WMO_650000-699999"), - (700000, 709999, "WMO_700000-709999"), - (710000, 714999, "WMO_710000-714999"), - (715000, 719999, "WMO_715000-719999"), - (720000, 721999, "WMO_720000-721999"), - (722000, 722999, "WMO_722000-722999"), - (723000, 723999, "WMO_723000-723999"), - (724000, 724999, "WMO_724000-724999"), - (725000, 725999, "WMO_725000-725999"), - (726000, 726999, "WMO_726000-726999"), - (727000, 729999, "WMO_727000-729999"), - (730000, 799999, "WMO_730000-799999"), - (800000, 849999, "WMO_800000-849999"), - (850000, 899999, "WMO_850000-899999"), - (900000, 949999, "WMO_900000-949999"), - (950000, 999999, "WMO_950000-999999"), - ] - # Map station IDs to their file paths paths = {} for station_id in station_ids: - wmo_number = station_id[:6] # Extract the first 6 digits of the station ID - station_numeric = int(wmo_number) # Convert the WMO number to an integer - - # Find the parent folder dynamically - parent_folder = None - for start, end, folder in STATION_RANGES: - if start <= station_numeric <= end: - parent_folder = folder - break - - if parent_folder is None: - raise ValueError(f"Station ID {station_id} does not fall within any defined range.") - - # Construct the expected filename - date_range = "19310101-20240101" # Hardcoded for now; adjust if dataset is updated + date_range_str = "19310101-20240101" # Hardcoded for now; adjust if dataset is updated version = "hadisd.3.4.0.2023f" - filename_nc = f"{version}_{date_range}_{station_id}.nc" - filename_zarr = f"{version}_{date_range}_{station_id}.zarr" + filename_zarr = 
f"{version}_{date_range_str}_{station_id}.zarr" + + # filename_nc = f"{version}_{date_range_str}_{station_id}.nc" # Uncomment to test with netcdf + # file_path_nc = Path(HADISD_HOME) / "netcdf" / filename_nc # Uncomment to test with netcdf # Construct the full path - _file_path_nc = Path(HADISD_HOME) / parent_folder / "netcdf" / filename_nc - file_path_zarr = Path(HADISD_HOME) / parent_folder / "zarr" / filename_zarr + file_path_zarr = Path(HADISD_HOME) / "zarr" / filename_zarr # Check if the file exists (comment out if testing with single netcdf) if not file_path_zarr.exists(): @@ -290,6 +212,7 @@ def filesystem(self, *args, date_range=("1970-01-01T00", "2023-12-31T23"), **kwa # Add the file path to the dictionary paths[station_id] = ( file_path_zarr # Change to file_path_zarr to test with zarr files or remove "_zarr" to test with netcdf files + # file_path_nc # Uncomment to test with netcdf files ) return paths diff --git a/packages/utils/src/pyearthtools/utils/data/converter.py b/packages/utils/src/pyearthtools/utils/data/converter.py index a8f2e0c2..e8bac483 100644 --- a/packages/utils/src/pyearthtools/utils/data/converter.py +++ b/packages/utils/src/pyearthtools/utils/data/converter.py @@ -158,9 +158,16 @@ def _distill_dataset(self, dataset: XR_OBJECT) -> dict[DISTILL_KEYS, Any]: try: dims[use_shape.index(size)] = coord except ValueError as e: - raise RuntimeError( - "Cannot record coordinate, currently converter can only handle datasets with variables of the same dimensions." - ) from e + + msg = ( + f"Cannot record coordinate '{coord}', currently the conversion can only handle data variables with " + f"the same dimensionality as the dataset coords {list(dataset.coords)}. " + f"Data variable {variables[0]} with dimensions of {dataset[variables[0]].dims} was used to estimate the shape required. 
" + f"You may need to drop unused coordinates, drop mismatching data variables, or broadcast your data variables onto the " + "coordinates of the dataset yourself as the proper approach is user-defined." + ) + + raise RuntimeError(msg) from e use_shape[use_shape.index(size)] = 1e10 while None in dims: