diff --git a/docs/getting_started/crs_management.ipynb b/docs/getting_started/crs_management.ipynb index 7e0ff7c7..65b583be 100644 --- a/docs/getting_started/crs_management.ipynb +++ b/docs/getting_started/crs_management.ipynb @@ -29,6 +29,56 @@ "Operations on xarray objects can cause data loss. Due to this, rioxarray writes and expects the spatial reference information to exist in the coordinates." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conventions\n", + "\n", + "rioxarray supports multiple conventions for storing geospatial metadata. The convention system provides a flexible way to read and write CRS and transform information.\n", + "\n", + "### Supported Conventions\n", + "\n", + "- **CF (Climate and Forecasts)**: The default convention, using `grid_mapping` coordinates with attributes like `spatial_ref`, `crs_wkt`, and `GeoTransform`. This is the standard for netCDF files in the geospatial community.\n", + "\n", + "### How Conventions Work\n", + "\n", + "- **Reading**: If a convention is set globally, that convention is tried **first** for better performance. If not found, other conventions are tried as fallback. This allows you to optimize reads when you know the data format.\n", + "- **Writing**: Uses the global `convention` setting (default: CF) or a per-method `convention` parameter.\n", + "\n", + "### Setting the Convention\n", + "\n", + "You can set the convention globally using `set_options()`:\n", + "\n", + "```python\n", + "from rioxarray import set_options\n", + "from rioxarray.enum import Convention\n", + "\n", + "# Set globally - reads will try CF first, writes will use CF\n", + "set_options(convention=Convention.CF)\n", + "\n", + "# Or use as a context manager\n", + "with set_options(convention=Convention.CF):\n", + " # CF convention is tried first when reading\n", + " crs = data.rio.crs\n", + " # CF convention is used for writing\n", + " data.rio.write_crs(\"EPSG:4326\", inplace=True)\n", + "```\n", + "\n", + "Or specify the convention per-method (for writing only):\n", + "\n", + "```python\n", + "from rioxarray.enum import Convention\n", + "\n", + "data.rio.write_crs(\"EPSG:4326\", convention=Convention.CF, inplace=True)\n", + "```\n", + "\n", + "#### API Documentation\n", + "\n", + "- [rioxarray.set_options](../rioxarray.rst#rioxarray.set_options)\n", + "- [rioxarray.enum.Convention](../rioxarray.rst#rioxarray.enum.Convention)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -524,12 +574,12 @@ "## Setting the CRS\n", "\n", "Use the `rio.write_crs` method to set the CRS on your `xarray.Dataset` or `xarray.DataArray`.\n", - "This modifies the `xarray.Dataset` or `xarray.DataArray` and sets the CRS in a CF compliant manner.\n", + "This modifies the `xarray.Dataset` or `xarray.DataArray` and sets the CRS using the configured convention (default: CF).\n", "\n", "- [rio.write_crs()](../rioxarray.rst#rioxarray.rioxarray.XRasterBase.write_crs)\n", "- [rio.crs](../rioxarray.rst#rioxarray.rioxarray.XRasterBase.crs)\n", "\n", - "**Note:** It is recommended to use `rio.write_crs()` if you want the CRS to persist on the Dataset/DataArray and to write the CRS CF compliant metadata. Calling only `rio.set_crs()` CRS storage method is lossy and will not modify the Dataset/DataArray metadata." + "**Note:** It is recommended to use `rio.write_crs()` if you want the CRS to persist on the Dataset/DataArray and to write convention-compliant metadata. Calling only `rio.set_crs()` is lossy and will not modify the Dataset/DataArray metadata." ] }, { diff --git a/docs/history.rst b/docs/history.rst index 343ad776..53313e49 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -3,6 +3,8 @@ History Latest ------ +- ENH: Add `convention` option to `set_options()` for future multi-convention support (pull #899) +- REF: Extract CF convention logic to `_convention/cf.py` module (pull #899) 0.21.0 diff --git a/docs/rioxarray.rst b/docs/rioxarray.rst index b32f8658..8d07a038 100644 --- a/docs/rioxarray.rst +++ b/docs/rioxarray.rst @@ -27,6 +27,15 @@ rioxarray.show_versions .. autofunction:: rioxarray.show_versions +rioxarray.enum module +--------------------- + +.. automodule:: rioxarray.enum + :members: + :undoc-members: + :show-inheritance: + + rioxarray `rio` accessors -------------------------- diff --git a/rioxarray/_convention/__init__.py b/rioxarray/_convention/__init__.py new file mode 100644 index 00000000..7818ff7b --- /dev/null +++ b/rioxarray/_convention/__init__.py @@ -0,0 +1,214 @@ +"""Convention modules for rioxarray. + +This module defines the common interface for convention implementations +and provides helpers for selecting conventions. +""" + +from typing import Dict, Optional, Protocol, Tuple, Union + +import rasterio.crs +import xarray +from affine import Affine + +from rioxarray._convention.cf import CFConvention +from rioxarray._options import CONVENTION, get_option +from rioxarray.crs import crs_from_user_input +from rioxarray.enum import Convention + + +class ConventionProtocol(Protocol): + """Protocol defining the interface for convention modules.""" + + @staticmethod + def read_crs( + obj: Union[xarray.Dataset, xarray.DataArray], **kwargs + ) -> Optional[rasterio.crs.CRS]: + """Read CRS from the object using this convention.""" + + @staticmethod + def read_transform( + obj: Union[xarray.Dataset, xarray.DataArray], **kwargs + ) -> Optional[Affine]: + """Read transform from the object using this convention.""" + + @staticmethod + def read_spatial_dimensions( + obj: Union[xarray.Dataset, xarray.DataArray], + ) -> Optional[Tuple[str, str]]: + """Read spatial dimensions (y_dim, x_dim) from the object using this convention.""" + + @staticmethod + def write_crs( + obj: Union[xarray.Dataset, xarray.DataArray], + crs: rasterio.crs.CRS, + **kwargs, + ) -> Union[xarray.Dataset, xarray.DataArray]: + """Write CRS to the object using this convention.""" + + @staticmethod + def write_transform( + obj: Union[xarray.Dataset, xarray.DataArray], + transform: Affine, + **kwargs, + ) -> Union[xarray.Dataset, xarray.DataArray]: + """Write transform to the object using this convention.""" + + +# Convention classes mapped by Convention enum +_CONVENTION_MODULES: Dict[Convention, ConventionProtocol] = { + Convention.CF: CFConvention # type: ignore[dict-item] +} + + +def _get_convention(convention: Convention | None) -> ConventionProtocol: + """ + Get the convention module for writing. + + Parameters + ---------- + convention : Convention enum value or None + The convention to use. If None, uses the global default. + + Returns + ------- + ConventionProtocol + The module implementing the convention + """ + if convention is None: + convention = get_option(CONVENTION) or Convention.CF + convention = Convention(convention) + return _CONVENTION_MODULES[convention] + + +def read_crs_auto( + obj: Union[xarray.Dataset, xarray.DataArray], + **kwargs, +) -> Optional[rasterio.crs.CRS]: + """ + Auto-detect and read CRS by trying convention readers. + + If a convention is set globally via set_options(), that convention + is tried first for better performance. Then other conventions are + tried as fallback. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read CRS from + **kwargs + Convention-specific parameters (e.g., grid_mapping for CF) + + Returns + ------- + rasterio.crs.CRS or None + CRS object, or None if not found in any convention + """ + # Try the configured convention first (if set) + configured_convention = get_option(CONVENTION) + if configured_convention is not None: + result = _CONVENTION_MODULES[configured_convention].read_crs(obj, **kwargs) + if result is not None: + return result + + # Try all other conventions + for conv_enum, convention in _CONVENTION_MODULES.items(): + if conv_enum == configured_convention: + continue # Already tried this one + result = convention.read_crs(obj, **kwargs) + if result is not None: + return result + + # Legacy fallback: look in attrs for 'crs' (not part of any convention) + try: + return crs_from_user_input(obj.attrs["crs"]) + except KeyError: + pass + + return None + + +def read_transform_auto( + obj: Union[xarray.Dataset, xarray.DataArray], + **kwargs, +) -> Optional[Affine]: + """ + Auto-detect and read transform by trying convention readers. + + If a convention is set globally via set_options(), that convention + is tried first for better performance. Then other conventions are + tried as fallback. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read transform from + **kwargs + Convention-specific parameters (e.g., grid_mapping for CF) + + Returns + ------- + affine.Affine or None + Transform object, or None if not found in any convention + """ + # Try the configured convention first (if set) + configured_convention = get_option(CONVENTION) + if configured_convention is not None: + result = _CONVENTION_MODULES[configured_convention].read_transform( + obj, **kwargs + ) + if result is not None: + return result + + # Try all other conventions + for conv_enum, convention in _CONVENTION_MODULES.items(): + if conv_enum == configured_convention: + continue # Already tried this one + result = convention.read_transform(obj, **kwargs) + if result is not None: + return result + + # Legacy fallback: look in attrs for 'transform' (not part of any convention) + try: + return Affine(*obj.attrs["transform"][:6]) + except KeyError: + pass + + return None + + +def read_spatial_dimensions_auto( + obj: Union[xarray.Dataset, xarray.DataArray], +) -> Optional[Tuple[str, str]]: + """ + Auto-detect and read spatial dimensions by trying convention readers. + + If a convention is set globally via set_options(), that convention + is tried first for better performance. Then other conventions are + tried as fallback. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read spatial dimensions from + + Returns + ------- + tuple of (y_dim, x_dim) or None + Tuple of dimension names, or None if not found in any convention + """ + # Try the configured convention first (if set) + configured_convention = get_option(CONVENTION) + if configured_convention is not None: + result = _CONVENTION_MODULES[configured_convention].read_spatial_dimensions(obj) + if result is not None: + return result + + # Try all other conventions + for conv_enum, convention in _CONVENTION_MODULES.items(): + if conv_enum == configured_convention: + continue # Already tried this one + result = convention.read_spatial_dimensions(obj) + if result is not None: + return result + + return None diff --git a/rioxarray/_convention/cf.py b/rioxarray/_convention/cf.py new file mode 100644 index 00000000..37ee278f --- /dev/null +++ b/rioxarray/_convention/cf.py @@ -0,0 +1,359 @@ +""" +CF (Climate and Forecasts) convention support for rioxarray. + +This module provides functions for reading and writing geospatial metadata according to +the CF conventions: https://github.com/cf-convention/cf-conventions +""" +from typing import Optional, Tuple, Union + +import numpy +import pyproj +import rasterio.crs +import xarray +from affine import Affine + +from rioxarray._options import EXPORT_GRID_MAPPING, get_option +from rioxarray._spatial_utils import ( + DEFAULT_GRID_MAP, + _get_spatial_dims, + _has_spatial_dims, +) +from rioxarray.crs import crs_from_user_input +from rioxarray.exceptions import MissingSpatialDimensionError + + +def _find_grid_mapping( + obj: Union[xarray.Dataset, xarray.DataArray], + *, + grid_mapping: Optional[str] = None, +) -> Optional[str]: + """ + Find the grid_mapping coordinate name. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to search for grid_mapping + grid_mapping : str, optional + Explicit grid_mapping name to use + + Returns + ------- + str or None + The grid_mapping name, or None if not found + """ + if grid_mapping is not None: + return grid_mapping + + # Try to find grid_mapping attribute on data variables + if hasattr(obj, "data_vars"): + for data_var in obj.data_vars.values(): + if "grid_mapping" in data_var.attrs: + return data_var.attrs["grid_mapping"] + if "grid_mapping" in data_var.encoding: + return data_var.encoding["grid_mapping"] + + if hasattr(obj, "attrs") and "grid_mapping" in obj.attrs: + return obj.attrs["grid_mapping"] + + if hasattr(obj, "encoding") and "grid_mapping" in obj.encoding: + return obj.encoding["grid_mapping"] + + return None + + +def read_crs( + obj: Union[xarray.Dataset, xarray.DataArray], *, grid_mapping: Optional[str] = None +) -> Optional[rasterio.crs.CRS]: + """ + Read CRS from CF conventions. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read CRS from + grid_mapping : str, optional + Name of the grid_mapping coordinate variable + + Returns + ------- + rasterio.crs.CRS or None + CRS object, or None if not found + """ + grid_mapping = _find_grid_mapping(obj, grid_mapping=grid_mapping) + + if grid_mapping is not None: + try: + grid_mapping_coord = obj.coords[grid_mapping] + + # Look in wkt attributes first for performance + for crs_attr in ("spatial_ref", "crs_wkt"): + try: + return crs_from_user_input(grid_mapping_coord.attrs[crs_attr]) + except KeyError: + pass + + # Look in grid_mapping CF attributes + try: + return pyproj.CRS.from_cf(grid_mapping_coord.attrs) + except (KeyError, pyproj.exceptions.CRSError): + pass + except KeyError: + # grid_mapping coordinate doesn't exist + pass + + return None + + +def read_transform( + obj: Union[xarray.Dataset, xarray.DataArray], *, grid_mapping: Optional[str] = None +) -> Optional[Affine]: + """ + Read transform from CF conventions (GeoTransform attribute). + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read transform from + grid_mapping : str, optional + Name of the grid_mapping coordinate variable + + Returns + ------- + affine.Affine or None + Transform object, or None if not found + """ + grid_mapping = _find_grid_mapping(obj, grid_mapping=grid_mapping) + + if grid_mapping is not None: + try: + transform = numpy.fromstring( + obj.coords[grid_mapping].attrs["GeoTransform"], sep=" " + ) + # Calling .tolist() to assure the arguments are Python float and JSON serializable + return Affine.from_gdal(*transform.tolist()) + except KeyError: + pass + + return None + + +def read_spatial_dimensions( + obj: Union[xarray.Dataset, xarray.DataArray], +) -> Optional[Tuple[str, str]]: + """ + Read spatial dimensions from CF conventions. + + This function detects spatial dimensions based on: + 1. Standard dimension names ('x'/'y', 'longitude'/'latitude') + 2. CF coordinate attributes ('axis', 'standard_name') + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to read spatial dimensions from + + Returns + ------- + tuple of (y_dim, x_dim) or None + Tuple of dimension names, or None if not found + """ + x_dim = None + y_dim = None + + # Check standard dimension names + if "x" in obj.dims and "y" in obj.dims: + return "y", "x" + if "longitude" in obj.dims and "latitude" in obj.dims: + return "latitude", "longitude" + + # Look for coordinates with CF attributes + for coord in obj.coords: + # Make sure to only look in 1D coordinates + # that has the same dimension name as the coordinate + if obj.coords[coord].dims != (coord,): + continue + if (obj.coords[coord].attrs.get("axis", "").upper() == "X") or ( + obj.coords[coord].attrs.get("standard_name", "").lower() + in ("longitude", "projection_x_coordinate") + ): + x_dim = coord + elif (obj.coords[coord].attrs.get("axis", "").upper() == "Y") or ( + obj.coords[coord].attrs.get("standard_name", "").lower() + in ("latitude", "projection_y_coordinate") + ): + y_dim = coord + + if x_dim is not None and y_dim is not None: + return str(y_dim), str(x_dim) + + return None + + +def write_crs( + obj: Union[xarray.Dataset, xarray.DataArray], + *, + crs: rasterio.crs.CRS, + **kwargs, +) -> Union[xarray.Dataset, xarray.DataArray]: + """ + Write CRS using CF conventions. + + This also writes the grid_mapping attribute to encoding for CF compliance. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to write CRS to + crs : rasterio.crs.CRS + CRS to write + **kwargs + grid_mapping_name : str + Name of the grid_mapping coordinate (required for CF) + + Returns + ------- + xarray.Dataset or xarray.DataArray + Object with CRS written + """ + grid_mapping_name = kwargs.get("grid_mapping_name") + if grid_mapping_name is None: + # Get grid_mapping from encoding/attrs or use default + grid_mapping_name = _find_grid_mapping(obj) or DEFAULT_GRID_MAP + + # Get original transform before modifying (pass grid_mapping_name to find it) + transform = read_transform(obj, grid_mapping=grid_mapping_name) + + # Remove old grid mapping coordinate if exists + try: + del obj.coords[grid_mapping_name] + except KeyError: + pass + + # Add grid mapping coordinate + obj.coords[grid_mapping_name] = xarray.Variable((), 0) + grid_map_attrs = {} + if get_option(EXPORT_GRID_MAPPING): + try: + grid_map_attrs = pyproj.CRS.from_user_input(crs).to_cf() + except KeyError: + pass + + # spatial_ref is for compatibility with GDAL + crs_wkt = crs.to_wkt() + grid_map_attrs["spatial_ref"] = crs_wkt + grid_map_attrs["crs_wkt"] = crs_wkt + if transform is not None: + grid_map_attrs["GeoTransform"] = " ".join( + [str(item) for item in transform.to_gdal()] + ) + obj.coords[grid_mapping_name].attrs = grid_map_attrs + + # Write grid_mapping to encoding (CF specific) + obj = _write_grid_mapping(obj, grid_mapping_name=grid_mapping_name) + + return obj + + +def _write_grid_mapping( + obj: Union[xarray.Dataset, xarray.DataArray], + *, + grid_mapping_name: str, +) -> Union[xarray.Dataset, xarray.DataArray]: + """ + Write the CF grid_mapping attribute to the encoding. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to write grid_mapping to + grid_mapping_name : str + Name of the grid_mapping coordinate + + Returns + ------- + xarray.Dataset or xarray.DataArray + Object with grid_mapping written to encoding + """ + if hasattr(obj, "data_vars"): + for var in obj.data_vars: + if not _has_spatial_dims(obj, var=var): + continue + try: + x_dim, y_dim = _get_spatial_dims(obj, var=var) + except MissingSpatialDimensionError: + continue + # remove grid_mapping from attributes if it exists + # and update the grid_mapping in encoding + new_attrs = dict(obj[var].attrs) + new_attrs.pop("grid_mapping", None) + obj[var].attrs = new_attrs + obj[var].encoding["grid_mapping"] = grid_mapping_name + obj[var].rio.set_spatial_dims(x_dim=x_dim, y_dim=y_dim, inplace=True) + + # remove grid_mapping from attributes if it exists + # and update the grid_mapping in encoding + new_attrs = dict(obj.attrs) + new_attrs.pop("grid_mapping", None) + obj.attrs = new_attrs + obj.encoding["grid_mapping"] = grid_mapping_name + + return obj + + +def write_transform( + obj: Union[xarray.Dataset, xarray.DataArray], + *, + transform: Affine, + **kwargs, +) -> Union[xarray.Dataset, xarray.DataArray]: + """ + Write transform using CF conventions (GeoTransform attribute). + + This also writes the grid_mapping attribute to encoding for CF compliance. + + Parameters + ---------- + obj : xarray.Dataset or xarray.DataArray + Object to write transform to + transform : affine.Affine + Transform to write + **kwargs + grid_mapping_name : str + Name of the grid_mapping coordinate (required for CF) + + Returns + ------- + xarray.Dataset or xarray.DataArray + Object with transform written + """ + grid_mapping_name = kwargs.get("grid_mapping_name") + if grid_mapping_name is None: + # Get grid_mapping from encoding/attrs or use default + grid_mapping_name = _find_grid_mapping(obj) or DEFAULT_GRID_MAP + + try: + grid_map_attrs = obj.coords[grid_mapping_name].attrs.copy() + except KeyError: + obj.coords[grid_mapping_name] = xarray.Variable((), 0) + grid_map_attrs = obj.coords[grid_mapping_name].attrs.copy() + + grid_map_attrs["GeoTransform"] = " ".join( + [str(item) for item in transform.to_gdal()] + ) + obj.coords[grid_mapping_name].attrs = grid_map_attrs + + # Write grid_mapping to encoding (CF specific) + obj = _write_grid_mapping(obj, grid_mapping_name=grid_mapping_name) + + return obj + + +class CFConvention: + """CF convention class implementing ConventionProtocol.""" + + read_crs = staticmethod(read_crs) + read_transform = staticmethod(read_transform) + read_spatial_dimensions = staticmethod(read_spatial_dimensions) + write_crs = staticmethod(write_crs) + write_transform = staticmethod(write_transform) diff --git a/rioxarray/_options.py b/rioxarray/_options.py index 1dc55ffa..95fdb57d 100644 --- a/rioxarray/_options.py +++ b/rioxarray/_options.py @@ -6,19 +6,30 @@ This file was adopted from: https://github.com/pydata/xarray # noqa Source file: https://github.com/pydata/xarray/blob/2ab0666c1fcc493b1e0ebc7db14500c427f8804e/xarray/core/options.py # noqa """ -from typing import Any +from typing import Any, Optional + +from rioxarray.enum import Convention EXPORT_GRID_MAPPING = "export_grid_mapping" SKIP_MISSING_SPATIAL_DIMS = "skip_missing_spatial_dims" +CONVENTION = "convention" -OPTIONS = { +OPTIONS: dict[str, Any] = { EXPORT_GRID_MAPPING: True, SKIP_MISSING_SPATIAL_DIMS: False, + CONVENTION: None, } OPTION_NAMES = set(OPTIONS) + +def _validate_convention(value: Optional[Convention]) -> bool: + """Validate the convention option.""" + return value is None or isinstance(value, Convention) + + VALIDATORS = { EXPORT_GRID_MAPPING: lambda choice: isinstance(choice, bool), + CONVENTION: _validate_convention, } @@ -46,6 +57,7 @@ class set_options: # pylint: disable=invalid-name .. versionadded:: 0.3.0 .. versionadded:: 0.7.0 skip_missing_spatial_dims + .. versionadded:: 0.22.0 convention Parameters ---------- @@ -60,6 +72,10 @@ class set_options: # pylint: disable=invalid-name If True, it will not perform spatial operations on variables within a :class:`xarray.Dataset` if the spatial dimensions are not found. + convention: Convention, default=None + The convention to use for reading and writing geospatial metadata. + If None, CF convention is used as the default. + See :class:`rioxarray.enum.Convention` for available options. Usage as a context manager:: diff --git a/rioxarray/_spatial_utils.py b/rioxarray/_spatial_utils.py index 05c92d26..c651f012 100644 --- a/rioxarray/_spatial_utils.py +++ b/rioxarray/_spatial_utils.py @@ -365,8 +365,11 @@ def _add_attrs_proj( new_data_array.rio.set_attrs(new_attrs, inplace=True) # make sure projection added - new_data_array.rio.write_grid_mapping(src_data_array.rio.grid_mapping, inplace=True) - new_data_array.rio.write_crs(src_data_array.rio.crs, inplace=True) + new_data_array.rio.write_crs( + src_data_array.rio.crs, + grid_mapping_name=src_data_array.rio.grid_mapping, + inplace=True, + ) new_data_array.rio.write_coordinate_system(inplace=True) new_data_array.rio.write_transform(inplace=True) # make sure encoding added diff --git a/rioxarray/enum.py b/rioxarray/enum.py new file mode 100644 index 00000000..33cf7fd0 --- /dev/null +++ b/rioxarray/enum.py @@ -0,0 +1,41 @@ +"""Enums for rioxarray.""" +from enum import Enum + + +class Convention(Enum): + """ + Supported geospatial metadata conventions. + + rioxarray supports conventions for storing geospatial metadata. + Currently supported: + + - CF: Climate and Forecasts convention using grid_mapping coordinates + + The convention can be set globally using set_options() or per-method + using the convention parameter. + + Examples + -------- + Set global convention: + + >>> import rioxarray + >>> from rioxarray.enum import Convention + >>> rioxarray.set_options(convention=Convention.CF) + + Use specific convention for a method: + + >>> from rioxarray.enum import Convention + >>> data.rio.write_crs("EPSG:4326", convention=Convention.CF) + + See Also + -------- + rioxarray.set_options : Set global options including convention + + References + ---------- + .. [1] CF Conventions: https://github.com/cf-convention/cf-conventions + """ + + #: Climate and Forecasts convention (default) + #: https://github.com/cf-convention/cf-conventions + CF = "CF" diff --git a/rioxarray/rioxarray.py b/rioxarray/rioxarray.py index 17d61e14..f1d6e05c 100644 --- a/rioxarray/rioxarray.py +++ b/rioxarray/rioxarray.py @@ -11,7 +11,6 @@ from typing import Any, Literal, Optional, Union import numpy -import pyproj import rasterio.warp import rasterio.windows import xarray @@ -22,7 +21,13 @@ from rasterio.crs import CRS from rasterio.rpc import RPC -from rioxarray._options import EXPORT_GRID_MAPPING, get_option +from rioxarray._convention import ( + _get_convention, + cf, + read_crs_auto, + read_spatial_dimensions_auto, + read_transform_auto, +) from rioxarray._spatial_utils import ( # noqa: F401, pylint: disable=unused-import DEFAULT_GRID_MAP, _affine_has_rotation, @@ -35,6 +40,7 @@ affine_to_coords, ) from rioxarray.crs import crs_from_user_input +from rioxarray.enum import Convention from rioxarray.exceptions import ( DimensionError, DimensionMissingCoordinateError, @@ -57,30 +63,11 @@ def __init__(self, xarray_obj: Union[xarray.DataArray, xarray.Dataset]): self._x_dim: Optional[Hashable] = None self._y_dim: Optional[Hashable] = None - # Determine the spatial dimensions of the `xarray.DataArray` - if "x" in self._obj.dims and "y" in self._obj.dims: - self._x_dim = "x" - self._y_dim = "y" - elif "longitude" in self._obj.dims and "latitude" in self._obj.dims: - self._x_dim = "longitude" - self._y_dim = "latitude" - else: - # look for coordinates with CF attributes - for coord in self._obj.coords: - # make sure to only look in 1D coordinates - # that has the same dimension name as the coordinate - if self._obj.coords[coord].dims != (coord,): - continue - if (self._obj.coords[coord].attrs.get("axis", "").upper() == "X") or ( - self._obj.coords[coord].attrs.get("standard_name", "").lower() - in ("longitude", "projection_x_coordinate") - ): - self._x_dim = coord - elif (self._obj.coords[coord].attrs.get("axis", "").upper() == "Y") or ( - self._obj.coords[coord].attrs.get("standard_name", "").lower() - in ("latitude", "projection_y_coordinate") - ): - self._y_dim = coord + + # Auto-detect spatial dimensions from any supported convention + spatial_dims = read_spatial_dimensions_auto(self._obj) + if spatial_dims is not None: + self._y_dim, self._x_dim = spatial_dims # properties self._count: Optional[int] = None @@ -98,32 +85,16 @@ def crs(self) -> Optional[rasterio.crs.CRS]: if self._crs is not None: return None if self._crs is False else self._crs - # look in wkt attributes to avoid using - # pyproj CRS if possible for performance - for crs_attr in ("spatial_ref", "crs_wkt"): - try: - self._set_crs( - self._obj.coords[self.grid_mapping].attrs[crs_attr], - inplace=True, - ) - return self._crs - except KeyError: - pass + # Auto-detect CRS from any supported convention + parsed_crs = read_crs_auto(self._obj, grid_mapping=self.grid_mapping) - # look in grid_mapping - try: - self._set_crs( - pyproj.CRS.from_cf(self._obj.coords[self.grid_mapping].attrs), - inplace=True, - ) - except (KeyError, pyproj.exceptions.CRSError): - try: - # look in attrs for 'crs' - self._set_crs(self._obj.attrs["crs"], inplace=True) - except KeyError: - self._crs = False - return None - return self._crs + if parsed_crs is not None: + self._set_crs(parsed_crs, inplace=True) + return self._crs + + # No CRS found + self._crs = False + return None def _get_obj(self, inplace: bool) -> Union[xarray.Dataset, xarray.DataArray]: """ @@ -235,14 +206,16 @@ def grid_mapping(self) -> str: return grid_mapping def write_grid_mapping( - self, grid_mapping_name: str = DEFAULT_GRID_MAP, inplace: bool = False + self, + grid_mapping_name: str = DEFAULT_GRID_MAP, + inplace: bool = False, ) -> Union[xarray.Dataset, xarray.DataArray]: """ - Write the CF grid_mapping attribute to the encoding. + Write the grid_mapping attribute to the encoding. Parameters ---------- - grid_mapping_name: str, optional + grid_mapping_name: str Name of the grid_mapping coordinate. inplace: bool, optional If True, it will write to the existing dataset. Default is False. @@ -250,42 +223,26 @@ def write_grid_mapping( Returns ------- :obj:`xarray.Dataset` | :obj:`xarray.DataArray`: - Modified dataset with CF compliant CRS information. + Modified dataset with grid_mapping written. + + See Also + -------- + :meth:`rioxarray.rioxarray.XRasterBase.write_crs` """ data_obj = self._get_obj(inplace=inplace) - if hasattr(data_obj, "data_vars"): - for var in data_obj.data_vars: - try: - x_dim, y_dim = _get_spatial_dims(data_obj, var=var) - except MissingSpatialDimensionError: - continue - # remove grid_mapping from attributes if it exists - # and update the grid_mapping in encoding - new_attrs = dict(data_obj[var].attrs) - new_attrs.pop("grid_mapping", None) - data_obj[var].rio.update_encoding( - {"grid_mapping": grid_mapping_name}, inplace=True - ).rio.set_attrs(new_attrs, inplace=True).rio.set_spatial_dims( - x_dim=x_dim, y_dim=y_dim, inplace=True - ) - # remove grid_mapping from attributes if it exists - # and update the grid_mapping in encoding - new_attrs = dict(data_obj.attrs) - new_attrs.pop("grid_mapping", None) - return data_obj.rio.update_encoding( - {"grid_mapping": grid_mapping_name}, inplace=True - ).rio.set_attrs(new_attrs, inplace=True) + return cf._write_grid_mapping(data_obj, grid_mapping_name=grid_mapping_name) def write_crs( self, input_crs: Optional[Any] = None, grid_mapping_name: Optional[str] = None, + convention: Optional[Convention] = None, inplace: bool = False, ) -> Union[xarray.Dataset, xarray.DataArray]: """ - Write the CRS to the dataset in a CF compliant manner. + Write the CRS to the dataset using the specified convention. - .. warning:: The grid_mapping attribute is written to the encoding. + .. warning:: When using CF convention, the grid_mapping attribute is written to the encoding. Parameters ---------- @@ -293,14 +250,17 @@ def write_crs( Anything accepted by `rasterio.crs.CRS.from_user_input`. grid_mapping_name: str, optional Name of the grid_mapping coordinate to store the CRS information in. - Default is the grid_mapping name of the dataset. + Only used with CF convention. Default is the grid_mapping name of the dataset. + convention: Convention, optional + Convention to use for writing CRS. If None, uses the global default + from set_options(). Currently only CF convention is supported. inplace: bool, optional If True, it will write to the existing dataset. Default is False. Returns ------- :obj:`xarray.Dataset` | :obj:`xarray.DataArray`: - Modified dataset with CF compliant CRS information. + Modified dataset with CRS information. Examples -------- @@ -317,44 +277,21 @@ def write_crs( else: data_obj = self._get_obj(inplace=inplace) - # get original transform - transform = self._cached_transform() - # remove old grid maping coordinate if exists - grid_mapping_name = ( - self.grid_mapping if grid_mapping_name is None else grid_mapping_name - ) - try: - del data_obj.coords[grid_mapping_name] - except KeyError: - pass - if data_obj.rio.crs is None: raise MissingCRS( "CRS not found. Please set the CRS with 'rio.write_crs()'." ) - # add grid mapping coordinate - data_obj.coords[grid_mapping_name] = xarray.Variable((), 0) - grid_map_attrs = {} - if get_option(EXPORT_GRID_MAPPING): - try: - grid_map_attrs = pyproj.CRS.from_user_input(data_obj.rio.crs).to_cf() - except KeyError: - pass - # spatial_ref is for compatibility with GDAL - crs_wkt = data_obj.rio.crs.to_wkt() - grid_map_attrs["spatial_ref"] = crs_wkt - grid_map_attrs["crs_wkt"] = crs_wkt - if transform is not None: - grid_map_attrs["GeoTransform"] = " ".join( - [str(item) for item in transform.to_gdal()] - ) - data_obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True) - # remove old crs if exists + # Remove legacy crs attr (not part of any convention) data_obj.attrs.pop("crs", None) - return data_obj.rio.write_grid_mapping( - grid_mapping_name=grid_mapping_name, inplace=True + # Use the convention module to write CRS + # Pass user input grid_mapping_name (may be None, convention handles default) + convention_module = _get_convention(convention) + return convention_module.write_crs( + data_obj, + crs=data_obj.rio.crs, + grid_mapping_name=grid_mapping_name, ) def estimate_utm_crs(self, datum_name: str = "WGS 84") -> rasterio.crs.CRS: @@ -401,36 +338,23 @@ def estimate_utm_crs(self, datum_name: str = "WGS 84") -> rasterio.crs.CRS: def _cached_transform(self) -> Optional[Affine]: """ - Get the transform from: - 1. The GeoTransform metatada property in the grid mapping - 2. The transform attribute. + Get the transform by auto-detecting from any supported convention. """ - try: - # look in grid_mapping - transform = numpy.fromstring( - self._obj.coords[self.grid_mapping].attrs["GeoTransform"], sep=" " - ) - # Calling .tolist() to assure the arguments are Python float and JSON serializable - return Affine.from_gdal(*transform.tolist()) - - except KeyError: - try: - return Affine(*self._obj.attrs["transform"][:6]) - except KeyError: - pass - return None + return read_transform_auto(self._obj, grid_mapping=self.grid_mapping) def write_transform( self, transform: Optional[Affine] = None, grid_mapping_name: Optional[str] = None, + convention: Optional[Convention] = None, inplace: bool = False, ) -> Union[xarray.Dataset, xarray.DataArray]: """ .. versionadded:: 0.0.30 - Write the GeoTransform to the dataset where GDAL can read it in. + Write the transform to the dataset using the specified convention. + For CF convention, this writes the GeoTransform to the dataset where GDAL can read it in. https://gdal.org/drivers/raster/netcdf.html#georeference Parameters @@ -439,33 +363,31 @@ def write_transform( The transform of the dataset. If not provided, it will be calculated. grid_mapping_name: str, optional Name of the grid_mapping coordinate to store the transform information in. - Default is the grid_mapping name of the dataset. + Only used with CF convention. Default is the grid_mapping name of the dataset. + convention: Convention, optional + Convention to use for writing transform. If None, uses the global default + from set_options(). Currently only CF convention is supported. inplace: bool, optional If True, it will write to the existing dataset. Default is False. Returns ------- :obj:`xarray.Dataset` | :obj:`xarray.DataArray`: - Modified dataset with Geo Transform written. + Modified dataset with transform written. """ transform = transform or self.transform(recalc=True) data_obj = self._get_obj(inplace=inplace) - # delete the old attribute to prevent confusion + + # Remove legacy transform attr (not part of any convention) data_obj.attrs.pop("transform", None) - grid_mapping_name = ( - self.grid_mapping if grid_mapping_name is None else grid_mapping_name - ) - try: - grid_map_attrs = data_obj.coords[grid_mapping_name].attrs.copy() - except KeyError: - data_obj.coords[grid_mapping_name] = xarray.Variable((), 0) - grid_map_attrs = data_obj.coords[grid_mapping_name].attrs.copy() - grid_map_attrs["GeoTransform"] = " ".join( - [str(item) for item in transform.to_gdal()] - ) - data_obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True) - return data_obj.rio.write_grid_mapping( - grid_mapping_name=grid_mapping_name, inplace=True + + # Use the convention module to write transform + # Pass user input grid_mapping_name (may be None, convention handles default) + convention_module = _get_convention(convention) + return convention_module.write_transform( + data_obj, + transform=transform, + grid_mapping_name=grid_mapping_name, ) def transform(self, recalc: bool = False) -> Affine: diff --git a/test/unit/test_convention_cf.py b/test/unit/test_convention_cf.py new file mode 100644 index 00000000..7d9eec60 --- /dev/null +++ b/test/unit/test_convention_cf.py @@ -0,0 +1,203 @@ +"""Unit tests for the CF convention module.""" +import numpy +import xarray +from affine import Affine +from rasterio.crs import CRS + +from rioxarray._convention import cf, read_crs_auto, read_transform_auto + + +def test_read_crs__from_grid_mapping_spatial_ref(): + """Test reading CRS from grid_mapping coordinate's spatial_ref attribute.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + data.coords["spatial_ref"] = xarray.Variable((), 0) + data.coords["spatial_ref"].attrs["spatial_ref"] = "EPSG:4326" + + crs = cf.read_crs(data, grid_mapping="spatial_ref") + assert crs is not None + assert crs == CRS.from_epsg(4326) + + +def test_read_crs__from_grid_mapping_crs_wkt(): + """Test reading CRS from grid_mapping coordinate's crs_wkt attribute.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + data.coords["spatial_ref"] = xarray.Variable((), 0) + data.coords["spatial_ref"].attrs["crs_wkt"] = CRS.from_epsg(4326).to_wkt() + + crs = cf.read_crs(data, grid_mapping="spatial_ref") + assert crs is not None + assert crs == CRS.from_epsg(4326) + + +def test_read_crs__from_legacy_attrs(): + """Test reading CRS from object's attrs (legacy, not CF convention). + + The 'crs' attribute is not part of CF convention but is supported + for backwards compatibility via the auto-detect method. + """ + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + data.attrs["crs"] = "EPSG:4326" + + # CF convention should NOT find this + crs = cf.read_crs(data) + assert crs is None + + # Auto-detect should find it + crs = read_crs_auto(data) + assert crs is not None + assert crs == CRS.from_epsg(4326) + + +def test_read_crs__from_legacy_attrs_with_missing_grid_mapping(): + """Test reading CRS from attrs when grid_mapping doesn't exist. + + This tests a common case where rioxarray's grid_mapping property returns + "spatial_ref" as a default, but the coordinate doesn't actually exist. + The CRS should still be found via auto-detect. + """ + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + data.attrs["crs"] = "EPSG:4326" + + # CF convention should NOT find this + crs = cf.read_crs(data, grid_mapping="spatial_ref") + assert crs is None + + # Auto-detect should find it + crs = read_crs_auto(data, grid_mapping="spatial_ref") + assert crs is not None + assert crs == CRS.from_epsg(4326) + + +def test_read_crs__not_found(): + """Test that None is returned when no CRS is found.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + + crs = cf.read_crs(data) + assert crs is None + + +def test_read_transform__from_geotransform(): + """Test reading transform from GeoTransform attribute.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + data.coords["spatial_ref"] = xarray.Variable((), 0) + # GeoTransform format: [c, a, b, f, d, e] (GDAL format) + data.coords["spatial_ref"].attrs["GeoTransform"] = "0.0 1.0 0.0 10.0 0.0 -1.0" + + transform = cf.read_transform(data, grid_mapping="spatial_ref") + assert transform is not None + assert transform == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + + +def test_read_transform__from_legacy_attrs(): + """Test reading transform from object's attrs (legacy, not CF convention). + + The 'transform' attribute is not part of CF convention but is supported + for backwards compatibility via the auto-detect method. + """ + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + # Transform stored as list in attrs + data.attrs["transform"] = [1.0, 0.0, 0.0, 0.0, -1.0, 10.0] + + # CF convention should NOT find this + transform = cf.read_transform(data) + assert transform is None + + # Auto-detect should find it + transform = read_transform_auto(data) + assert transform is not None + assert transform == Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + + +def test_read_transform__not_found(): + """Test that None is returned when no transform is found.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + + transform = cf.read_transform(data) + assert transform is None + + +def test_read_spatial_dimensions__xy(): + """Test detecting x/y dimension names.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + + dims = cf.read_spatial_dimensions(data) + assert dims == ("y", "x") + + +def test_read_spatial_dimensions__lonlat(): + """Test detecting longitude/latitude dimension names.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["latitude", "longitude"]) + + dims = cf.read_spatial_dimensions(data) + assert dims == ("latitude", "longitude") + + +def test_read_spatial_dimensions__cf_axis(): + """Test detecting dimensions from CF axis attributes.""" + data = xarray.DataArray( + numpy.random.rand(10, 10), + dims=["row", "col"], + coords={ + "row": ("row", numpy.arange(10)), + "col": ("col", numpy.arange(10)), + }, + ) + data.coords["row"].attrs["axis"] = "Y" + data.coords["col"].attrs["axis"] = "X" + + dims = cf.read_spatial_dimensions(data) + assert dims == ("row", "col") + + +def test_read_spatial_dimensions__cf_standard_name(): + """Test detecting dimensions from CF standard_name attributes.""" + data = xarray.DataArray( + numpy.random.rand(10, 10), + dims=["lat", "lon"], + coords={ + "lat": ("lat", numpy.arange(10)), + "lon": ("lon", numpy.arange(10)), + }, + ) + data.coords["lat"].attrs["standard_name"] = "latitude" + data.coords["lon"].attrs["standard_name"] = "longitude" + + dims = cf.read_spatial_dimensions(data) + assert dims == ("lat", "lon") + + +def test_read_spatial_dimensions__not_found(): + """Test that None is returned when spatial dimensions are not found.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["a", "b"]) + + dims = cf.read_spatial_dimensions(data) + assert dims is None + + +def test_write_crs(): + """Test writing CRS to a DataArray.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + crs = CRS.from_epsg(4326) + + result = cf.write_crs(data, crs=crs, grid_mapping_name="spatial_ref") + + assert "spatial_ref" in result.coords + assert result.coords["spatial_ref"].attrs["spatial_ref"] == crs.to_wkt() + assert result.coords["spatial_ref"].attrs["crs_wkt"] == crs.to_wkt() + + +def test_write_transform(): + """Test writing transform to a DataArray.""" + data = xarray.DataArray(numpy.random.rand(10, 10), dims=["y", "x"]) + transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + + result = cf.write_transform( + data, transform=transform, grid_mapping_name="spatial_ref" + ) + + assert "spatial_ref" in result.coords + assert "GeoTransform" in result.coords["spatial_ref"].attrs + assert ( + result.coords["spatial_ref"].attrs["GeoTransform"] + == "0.0 1.0 0.0 10.0 0.0 -1.0" + ) diff --git a/test/unit/test_options.py b/test/unit/test_options.py index 4b7388f5..704b4ae7 100644 --- a/test/unit/test_options.py +++ b/test/unit/test_options.py @@ -1,7 +1,8 @@ import pytest from rioxarray import set_options -from rioxarray._options import EXPORT_GRID_MAPPING, get_option +from rioxarray._options import CONVENTION, EXPORT_GRID_MAPPING, get_option +from rioxarray.enum import Convention def test_set_options__contextmanager(): @@ -37,3 +38,35 @@ def test_set_options__invalid_value(): ): with set_options(export_grid_mapping=12345): pass + + +def test_set_options__convention_default(): + """Test that convention defaults to None.""" + assert get_option(CONVENTION) is None + + +def test_set_options__convention_cf(): + """Test setting convention to CF.""" + assert get_option(CONVENTION) is None + with set_options(convention=Convention.CF): + assert get_option(CONVENTION) is Convention.CF + assert get_option(CONVENTION) is None + + +def test_set_options__convention_none(): + """Test setting convention back to None.""" + with set_options(convention=Convention.CF): + assert get_option(CONVENTION) is Convention.CF + with set_options(convention=None): + assert get_option(CONVENTION) is None + assert get_option(CONVENTION) is Convention.CF + + +def test_set_options__convention_invalid(): + """Test that invalid convention values raise error.""" + with pytest.raises( + ValueError, + match="option 'convention' gave an invalid value: 'invalid'.", + ): + with set_options(convention="invalid"): + pass