diff --git a/docs/examples.md b/docs/examples.md index 57b0c8b..9259217 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -6,9 +6,10 @@ We provide an example notebook showing optimization of models' parameters with `diffWOFOST`. To get familiar with the concepts and implementation, check out [`Introduction`](./index.md) in the documentation. -| Open the notebook | Access the source | View the notebook | -|-------|------------|---------------| -| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)][colab_link] | [![Access the source code](https://img.shields.io/badge/GitHub_Repository-000.svg?logo=github&labelColor=gray&color=blue)][source_link] | [![here](https://img.shields.io/badge/View_Notebook-orange.svg?logo=jupyter&labelColor=gray)](./notebooks/optimization.ipynb) | +| Model | Open the notebook | Access the source | View the notebook | +|---|----|------------|---------------| +| Leaf and Root dynamics| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)][leaf_colab_link] | [![Access the source code](https://img.shields.io/badge/GitHub_Repository-000.svg?logo=github&labelColor=gray&color=blue)][leaf_source_link] | [![here](https://img.shields.io/badge/View_Notebook-orange.svg?logo=jupyter&labelColor=gray)](./notebooks/optimization.ipynb) | +| Phenology | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)][pheno_colab_link] | [![Access the source code](https://img.shields.io/badge/GitHub_Repository-000.svg?logo=github&labelColor=gray&color=blue)][pheno_source_link] | [![here](https://img.shields.io/badge/View_Notebook-orange.svg?logo=jupyter&labelColor=gray)](./notebooks/optimization_phenology.ipynb) | !!! note When calculating gradients, it is important to ensure that the predicted @@ -21,6 +22,7 @@ We provide an example notebook showing optimization of models' parameters with output w.r.t the parameter will be close to zero, which may not provide useful information for optimization. -[colab_link]: https://colab.research.google.com/github/WUR-AI/diffWOFOST/blob/main/docs/notebooks/optimization.ipynb - -[source_link]: https://github.com/WUR-AI/diffWOFOST/blob/main/docs/notebooks/optimization.ipynb +[leaf_colab_link]: https://colab.research.google.com/github/WUR-AI/diffWOFOST/blob/main/docs/notebooks/optimization.ipynb +[leaf_source_link]: https://github.com/WUR-AI/diffWOFOST/blob/main/docs/notebooks/optimization.ipynb +[pheno_colab_link]: https://colab.research.google.com/github/WUR-AI/diffWOFOST/blob/main/docs/notebooks/optimization_phenology.ipynb +[pheno_source_link]: https://github.com/WUR-AI/diffWOFOST/blob/main/docs/notebooks/optimization_phenology.ipynb diff --git a/docs/notebooks/optimization_phenology.ipynb b/docs/notebooks/optimization_phenology.ipynb new file mode 100644 index 0000000..7b40125 --- /dev/null +++ b/docs/notebooks/optimization_phenology.ipynb @@ -0,0 +1,477 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c784f1c2-d477-464d-a35f-f2c64e96fb10", + "metadata": {}, + "source": [ + "
\n", + "

Optimizing parameters in a WOFOST crop model using diffWOFOST

\n", + " \n", + "
\n", + "\n", + "\n", + "This Jupyter notebook demonstrates the optimization of parameters in a\n", + "differentiable model using the `diffwofost` package. The package provides\n", + "differentiable implementations of the WOFOST model and its associated\n", + "sub-models. As `diffwofost` is under active development, this notebook focuses on\n", + "one sub-models: `phenology`. " + ] + }, + { + "cell_type": "markdown", + "id": "41262fbd-270b-4616-91ad-09ee82451604", + "metadata": {}, + "source": [ + "## 1. Phenology\n", + "\n", + "In this section, we will demonstrate how to optimize the parameters `TSUMEM`, `TBASEM`, `TSUM1` and `TSUM2`in\n", + "phenology model using a differentiable version of phenology.\n", + "The optimization will be done using the Adam optimizer from `torch.optim`." + ] + }, + { + "cell_type": "markdown", + "id": "1b6c3f53-6fab-4537-9177-7b16e0a1ccec", + "metadata": {}, + "source": [ + "### 1.1 software requirements\n", + "\n", + "To run this notebook, we need to install the `diffwofost`; the differentiable\n", + "version of WOFOST models. Since the package is constantly under development, make\n", + "sure you have the latest version of `diffwofost` installed in your\n", + "python environment. You can install it using pip:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4049fea-1d05-41f1-bf9d-f030ae83a324", + "metadata": {}, + "outputs": [], + "source": [ + "# install diffwofost\n", + "!pip install diffwofost" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "21731653-3976-4bb9-b83b-b11d78211700", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- import libraries ----\n", + "import copy\n", + "import torch\n", + "import numpy\n", + "from pathlib import Path\n", + "from diffwofost.physical_models.config import Configuration\n", + "from diffwofost.physical_models.crop.phenology import DVS_Phenology\n", + "from diffwofost.physical_models.utils import EngineTestHelper\n", + "from diffwofost.physical_models.utils import prepare_engine_input\n", + "from diffwofost.physical_models.utils import get_test_data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "82a1ef6b-336e-4902-8bd1-2a1ed2020f9d", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- disable a warning: this will be fixed in the future ----\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\", message=\"To copy construct from a tensor.*\")" + ] + }, + { + "cell_type": "markdown", + "id": "47def7fc-f2dd-4aaf-a572-41cc9d1e4679", + "metadata": {}, + "source": [ + "### 1.2. Data\n", + "\n", + "A test dataset of `DVS` (Development stage) will be used to optimize the parameters:\n", + "- `TSUMEM`: Temperature sum from sowing to emergence,\n", + "- `TBASEM`: Base temperature for emergence,\n", + "- `TSUM1`: Temperature sum from emergence to anthesis,\n", + "- `TSUM2`: Temperature sum from anthesis to maturity. \n", + "\n", + "The data is stored in PCSE tests folder, and can be doewnloded from PCSE repsository.\n", + "You can select any of the files related to `phenology` model with a file name that follwos the pattern\n", + "`test_phenology_wofost72_*.yaml`. Each file contains different data depending on the locatin and crop type.\n", + "For example, you can download the file \"test_phenology_wofost72_01.yaml\" as:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0233a048-e5a2-4249-887d-35a37284769c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloaded: test_phenology_wofost72_17.yaml\n" + ] + } + ], + "source": [ + "import urllib.request\n", + "\n", + "filename = \"test_phenology_wofost72_17.yaml\"\n", + "url = f\"https://raw.githubusercontent.com/ajwdewit/pcse/refs/heads/master/tests/test_data/{filename}\"\n", + "\n", + "urllib.request.urlretrieve(url, filename)\n", + "print(f\"Downloaded: {filename}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5a459489-bfcb-4ad6-9102-1b6be5edeb52", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Check the path to the files that are downloaded as explained above ----\n", + "test_data_path = \"test_phenology_wofost72_17.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a39f030b-ca6f-4535-8692-7883476ae7a4", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Here we read the test data and set some variables ----\n", + "test_data = get_test_data(test_data_path)\n", + "\n", + "crop_model_params = [\n", + " \"TSUMEM\",\n", + " \"TBASEM\",\n", + " \"TEFFMX\",\n", + " \"TSUM1\",\n", + " \"TSUM2\",\n", + " \"IDSL\",\n", + " \"DLO\",\n", + " \"DLC\",\n", + " \"DVSI\",\n", + " \"DVSEND\",\n", + " \"DTSMTB\",\n", + " \"VERNSAT\",\n", + " \"VERNBASE\",\n", + " \"VERNDVS\",\n", + "]\n", + "(crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = (\n", + " prepare_engine_input(test_data, crop_model_params)\n", + ")\n", + "\n", + "expected_results = test_data[\"ModelResults\"]\n", + "expected_dvs = torch.tensor([float(item[\"DVS\"]) for item in expected_results], dtype=torch.float32\n", + ") # shape: [time_steps]\n", + "\n", + "# ---- dont change this: in this config file we specified the diffrentiable version of leaf_dynamics ----\n", + "phenology_config = Configuration(\n", + " CROP=DVS_Phenology,\n", + " OUTPUT_VARS=[\"DVR\", \"DVS\", \"TSUM\", \"TSUME\", \"VERN\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "52b19ae2-3fe6-4b3f-95a7-aea07a2c0958", + "metadata": {}, + "source": [ + "### 1.3. Helper classes/functions\n", + "\n", + "The model parameters shoudl stay in a valid range. To ensure this, we will use\n", + "`BoundedParameter` class with (min, max) and initial values for each\n", + "parameter. You might change these values depending on the crop type and\n", + "location. But dont use a very small range, otherwise gradiants will be very\n", + "small and the optimization will be very slow." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e4610238-de0d-42cf-9689-3c074eb2cc0e", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Adjust the values if needed ----\n", + "TSUMEM_MIN, TSUMEM_MAX, TSUMEM_INIT = (0.0, 200, 90)\n", + "TBASEM_MIN, TBASEM_MAX, TBASEM_INIT = (0.0, 10.0, 2.0)\n", + "TSUM1_MIN, TSUM1_MAX, TSUM1_INIT = (0.0, 1000, 800)\n", + "TSUM2_MIN, TSUM2_MAX, TSUM2_INIT = (0.0, 1000, 800)\n", + "\n", + "# ---- Helper for bounded parameters ----\n", + "class BoundedParameter(torch.nn.Module):\n", + " def __init__(self, low, high, init_value):\n", + " super().__init__()\n", + " self.low = low\n", + " self.high = high\n", + "\n", + " # Normalize to [0, 1]\n", + " init_norm = (init_value - low) / (high - low)\n", + "\n", + " # Parameter in raw logit space\n", + " self.raw = torch.nn.Parameter(torch.logit(torch.tensor(init_norm, dtype=torch.float32), eps=1e-6))\n", + "\n", + " def forward(self):\n", + " return self.low + (self.high - self.low) * torch.sigmoid(self.raw)\n" + ] + }, + { + "cell_type": "markdown", + "id": "153e4306-77ed-4278-8797-a04e637e12c8", + "metadata": {}, + "source": [ + "Another helper class is `OptDiffPhenology` which is a subclass of `torch.nn.Module`. \n", + "We use this class to wrap the `EngineTestHelper` function and make it easier to run the model `phenology`." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "36dd6463-4812-41c0-b2bf-d4769df1136f", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Wrap the model with torch.nn.Module----\n", + "class OptDiffPhenology(torch.nn.Module):\n", + " def __init__(self, crop_model_params_provider, weather_data_provider, agro_management_inputs, phenology_config):\n", + " super().__init__()\n", + " self.crop_model_params_provider = crop_model_params_provider\n", + " self.weather_data_provider = weather_data_provider\n", + " self.agro_management_inputs = agro_management_inputs\n", + " self.config = phenology_config\n", + "\n", + " # bounded parameters\n", + " self.TSUMEM = BoundedParameter(TSUMEM_MIN, TSUMEM_MAX, TSUMEM_INIT)\n", + " self.TBASEM = BoundedParameter(TBASEM_MIN, TBASEM_MAX, TBASEM_INIT)\n", + " self.TSUM1 = BoundedParameter(TSUM1_MIN, TSUM1_MAX, TSUM1_INIT)\n", + " self.TSUM2 = BoundedParameter(TSUM2_MIN, TSUM2_MAX, TSUM2_INIT)\n", + "\n", + " def forward(self):\n", + " # currently, copying is needed due to an internal issue in engine\n", + " crop_model_params_provider_ = copy.deepcopy(self.crop_model_params_provider)\n", + "\n", + " TSUMEM_val = self.TSUMEM()\n", + " TBASEM_val = self.TBASEM()\n", + " TSUM1_val = self.TSUM1()\n", + " TSUM2_val = self.TSUM2()\n", + " \n", + " # pass new value of parameters to the model\n", + " crop_model_params_provider_.set_override(\"TSUMEM\", TSUMEM_val, check=False)\n", + " crop_model_params_provider_.set_override(\"TBASEM\", TBASEM_val, check=False)\n", + " crop_model_params_provider_.set_override(\"TSUM1\", TSUM1_val, check=False)\n", + " crop_model_params_provider_.set_override(\"TSUM2\", TSUM2_val, check=False)\n", + "\n", + " engine = EngineTestHelper(\n", + " crop_model_params_provider_,\n", + " self.weather_data_provider,\n", + " self.agro_management_inputs,\n", + " self.config,\n", + " )\n", + " engine.run_till_terminate()\n", + " results = engine.get_output()\n", + " \n", + " return torch.stack([item[\"DVS\"] for item in results]) # shape: [1, time_steps]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2a0754ac-4cf1-4ed7-9059-af80484beb33", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Create model ---- \n", + "opt_model = OptDiffPhenology(\n", + " crop_model_params_provider,\n", + " weather_data_provider,\n", + " agro_management_inputs,\n", + " phenology_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "124a6077-64b7-4816-b42f-538e3f8e0538", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0: duration mismatch (260 vs 279).\n", + "Step 0, Loss 0.1490, TSUMEM 85.0787, TBASEM 1.8448, TSUM1 815.5215, TSUM2 815.5215,\n", + "Step 1: duration mismatch (262 vs 279).\n", + "Step 1, Loss 0.1348, TSUMEM 80.2344, TBASEM 1.6999, TSUM1 830.0543, TSUM2 830.0643,\n", + "Step 2: duration mismatch (263 vs 279).\n", + "Step 2, Loss 0.1197, TSUMEM 77.2860, TBASEM 1.6076, TSUM1 843.6052, TSUM2 843.6012,\n", + "Step 3: duration mismatch (264 vs 279).\n", + "Step 3, Loss 0.1147, TSUMEM 76.5338, TBASEM 1.5720, TSUM1 856.1688, TSUM2 856.1740,\n", + "Step 4: duration mismatch (266 vs 279).\n", + "Step 4, Loss 0.1019, TSUMEM 77.1810, TBASEM 1.5731, TSUM1 867.7785, TSUM2 867.8158,\n", + "Step 5: duration mismatch (267 vs 279).\n", + "Step 5, Loss 0.0881, TSUMEM 78.6763, TBASEM 1.5976, TSUM1 878.4762, TSUM2 878.5369,\n", + "Step 6: duration mismatch (268 vs 279).\n", + "Step 6, Loss 0.0830, TSUMEM 80.7683, TBASEM 1.6402, TSUM1 888.2892, TSUM2 888.3950,\n", + "Step 7: duration mismatch (269 vs 279).\n", + "Step 7, Loss 0.0698, TSUMEM 82.9896, TBASEM 1.6870, TSUM1 897.2725, TSUM2 897.4227,\n", + "Step 8: duration mismatch (270 vs 279).\n", + "Step 8, Loss 0.0568, TSUMEM 84.5758, TBASEM 1.7161, TSUM1 905.4835, TSUM2 905.6589,\n", + "Step 9: duration mismatch (271 vs 279).\n", + "Step 9, Loss 0.0521, TSUMEM 84.9177, TBASEM 1.7125, TSUM1 912.9635, TSUM2 913.1725,\n", + "Step 10: duration mismatch (271 vs 279).\n", + "Step 10, Loss 0.0480, TSUMEM 84.3238, TBASEM 1.6843, TSUM1 919.7631, TSUM2 920.0091,\n", + "Step 11: duration mismatch (273 vs 279).\n", + "Step 11, Loss 0.0381, TSUMEM 83.4182, TBASEM 1.6478, TSUM1 925.9421, TSUM2 926.2325,\n", + "Step 12: duration mismatch (273 vs 279).\n", + "Step 12, Loss 0.0355, TSUMEM 82.3086, TBASEM 1.6063, TSUM1 931.5499, TSUM2 931.8865,\n", + "Step 13: duration mismatch (273 vs 279).\n", + "Step 13, Loss 0.0324, TSUMEM 81.3026, TBASEM 1.5680, TSUM1 936.6345, TSUM2 937.0161,\n", + "Step 14: duration mismatch (275 vs 279).\n", + "Step 14, Loss 0.0245, TSUMEM 80.8495, TBASEM 1.5439, TSUM1 941.2473, TSUM2 941.6774,\n", + "Step 15: duration mismatch (275 vs 279).\n", + "Step 15, Loss 0.0220, TSUMEM 81.1065, TBASEM 1.5381, TSUM1 945.4302, TSUM2 945.9092,\n", + "Step 16: duration mismatch (275 vs 279).\n", + "Step 16, Loss 0.0197, TSUMEM 81.9637, TBASEM 1.5478, TSUM1 949.2226, TSUM2 949.7485,\n", + "Step 17: duration mismatch (276 vs 279).\n", + "Step 17, Loss 0.0103, TSUMEM 83.1409, TBASEM 1.5657, TSUM1 952.6663, TSUM2 953.2308,\n", + "Step 18: duration mismatch (277 vs 279).\n", + "Step 18, Loss 0.0093, TSUMEM 84.1272, TBASEM 1.5787, TSUM1 955.4659, TSUM2 956.3961,\n", + "Step 19: duration mismatch (277 vs 279).\n", + "Step 19, Loss 0.0093, TSUMEM 84.7385, TBASEM 1.5820, TSUM1 957.7150, TSUM2 959.2729,\n", + "Step 20: duration mismatch (277 vs 279).\n", + "Step 20, Loss 0.0093, TSUMEM 85.0120, TBASEM 1.5765, TSUM1 959.5129, TSUM2 961.8885,\n", + "Step 21: duration mismatch (277 vs 279).\n", + "Step 21, Loss 0.0092, TSUMEM 84.9791, TBASEM 1.5633, TSUM1 960.9411, TSUM2 964.2680,\n", + "Step 22: duration mismatch (277 vs 279).\n", + "Step 22, Loss 0.0091, TSUMEM 84.6666, TBASEM 1.5432, TSUM1 962.0599, TSUM2 966.4341,\n", + "Step 23: duration mismatch (278 vs 279).\n", + "Step 23, Loss 0.0090, TSUMEM 84.0982, TBASEM 1.5171, TSUM1 962.9180, TSUM2 968.4114,\n", + "Step 24, Loss 0.0078, TSUMEM 83.4926, TBASEM 1.4905, TSUM1 963.5505, TSUM2 970.0585,\n", + "Step 25, Loss 0.0082, TSUMEM 83.2271, TBASEM 1.4719, TSUM1 963.9872, TSUM2 971.4006,\n", + "Step 26, Loss 0.0086, TSUMEM 83.4078, TBASEM 1.4639, TSUM1 964.2517, TSUM2 972.4788,\n", + "Step 27, Loss 0.0090, TSUMEM 83.9896, TBASEM 1.4651, TSUM1 964.3623, TSUM2 973.3393,\n", + "Step 28, Loss 0.0092, TSUMEM 84.8013, TBASEM 1.4715, TSUM1 964.3331, TSUM2 974.0173,\n", + "Step 29, Loss 0.0093, TSUMEM 85.4506, TBASEM 1.4742, TSUM1 964.1751, TSUM2 974.5405,\n", + "Step 30, Loss 0.0094, TSUMEM 85.9211, TBASEM 1.4726, TSUM1 963.8970, TSUM2 974.9305,\n", + "Step 31, Loss 0.0094, TSUMEM 86.0486, TBASEM 1.4633, TSUM1 963.5046, TSUM2 975.2042,\n", + "Step 32, Loss 0.0093, TSUMEM 85.8631, TBASEM 1.4471, TSUM1 963.0023, TSUM2 975.3755,\n", + "Step 33, Loss 0.0092, TSUMEM 85.6006, TBASEM 1.4293, TSUM1 962.3921, TSUM2 975.4550,\n", + "Step 34, Loss 0.0090, TSUMEM 85.2661, TBASEM 1.4101, TSUM1 961.6753, TSUM2 975.4510,\n", + "Early stopping at step 34\n", + "duration (model 279 vs test 279).\n" + ] + } + ], + "source": [ + "# ---- Early stopping ---- \n", + "best_loss = float(\"inf\")\n", + "patience = 10 # Number of steps to wait for improvement\n", + "patience_counter = 0\n", + "min_delta = 1e-4 \n", + "\n", + "# ---- Optimizer ---- \n", + "optimizer = torch.optim.Adam(opt_model.parameters(), lr=0.1)\n", + "\n", + "# ---- We use relative MAE as loss because there are two outputs with different untis ---- \n", + "denom = torch.mean(torch.abs(expected_dvs)) \n", + "\n", + "# Training loop (example)\n", + "for step in range(101):\n", + " optimizer.zero_grad()\n", + " results = opt_model() \n", + " \n", + " # phenology parameters can change the simulation duration\n", + " min_len = min(len(results), len(expected_dvs))\n", + " if len(results) != len(expected_dvs):\n", + " print(f\"Step {step}: duration mismatch ({len(results)} vs {len(expected_dvs)}).\")\n", + " \n", + " mae = torch.mean(torch.abs(results[:min_len] - expected_dvs[:min_len]))\n", + " loss = mae / denom # example: relative mean absolute error\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " print(\n", + " f\"Step {step}, Loss {loss.item():.4f}, \"\n", + " f\"TSUMEM {opt_model.TSUMEM().item():.4f}, \"\n", + " f\"TBASEM {opt_model.TBASEM().item():.4f}, \"\n", + " f\"TSUM1 {opt_model.TSUM1().item():.4f}, \"\n", + " f\"TSUM2 {opt_model.TSUM2().item():.4f},\"\n", + " )\n", + " \n", + " # Early stopping logic\n", + " if loss.item() < best_loss - min_delta:\n", + " best_loss = loss.item()\n", + " patience_counter = 0\n", + " else:\n", + " patience_counter += 1\n", + " if patience_counter >= patience:\n", + " print(f\"Early stopping at step {step}\")\n", + " print(f\"duration (model {len(results)} vs test {len(expected_dvs)}).\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5b0a6b12-3ca9-4cf3-9dd5-ac2c11bf4fc6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Actual TSUMEM 110.0000 TBASEM 0.0000 Actual TSUM1 950.0000 TSUM2 991.0000\n" + ] + } + ], + "source": [ + "# ---- validate the results using test data ---- \n", + "print(\n", + " f\"Actual TSUMEM {crop_model_params_provider[\"TSUMEM\"].item():.4f}\",\n", + " f\"TBASEM {crop_model_params_provider[\"TBASEM\"].item():.4f}\",\n", + " f\"Actual TSUM1 {crop_model_params_provider[\"TSUM1\"].item():.4f}\", \n", + " f\"TSUM2 {crop_model_params_provider[\"TSUM2\"].item():.4f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1e06239-c037-4433-9da2-feb46b52a8e4", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/physical_models/crop/test_phenology.py b/tests/physical_models/crop/test_phenology.py index 6e16cf0..06b6dee 100644 --- a/tests/physical_models/crop/test_phenology.py +++ b/tests/physical_models/crop/test_phenology.py @@ -664,9 +664,13 @@ def test_gradients_numerical(self, param_name, output_name, config_type): output = model({param_name: param}) loss = output[output_name].sum() grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - rtol = 0.005 - assert torch.all( - torch.abs(numerical_grad - grads.data) / (torch.abs(grads.data) + 1e-8) < rtol + + # here tol is relaxed due to approximations + torch.testing.assert_close( + numerical_grad, + grads, + rtol=1e-2, + atol=1e-2, ) if torch.all(grads == 0): warnings.warn(