diff --git a/notebooks/Project_config.ipynb b/notebooks/Project_config.ipynb deleted file mode 100644 index 565587a0..00000000 --- a/notebooks/Project_config.ipynb +++ /dev/null @@ -1,126 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "import os" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### **Set Up The Project Paths:** User input required" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Replace the PROJECT_HOME and ERA5LOWRES paths with your actual paths\n", - "ERA5LOWRES = Path(\"/path/to/your/data\") # Replace with the path to your ERA5LOWRES data here. Example: Path(\"/path/to/your/ERA5LOWRES\")\n", - "PROJECT_HOME = None # Replace with the path to your PROJECT_HOME here, or set to None to default to current directory. Example: Path(\"/path/to/your/PROJECT_HOME\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### **Path Setting Logic:** Can be ignored by the user" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Ensure PROJECT_HOME is set to a valid Path object\n", - "try:\n", - " if PROJECT_HOME is None:\n", - " PROJECT_HOME = Path.cwd()\n", - " elif not isinstance(PROJECT_HOME, Path):\n", - " raise TypeError(f\"PROJECT_HOME must be a Path object or None. Got: {type(PROJECT_HOME).__name__}\")\n", - "except Exception as e:\n", - " raise ValueError(f\"Error setting PROJECT_HOME: {e}\")\n", - "\n", - "# Validate the PROJECT_HOME path\n", - "try:\n", - " if not PROJECT_HOME.is_dir():\n", - " raise ValueError(f\"The provided PROJECT_HOME path '{PROJECT_HOME}' is not a valid directory. 
Please update it.\")\n", - "except ValueError as e:\n", - " raise ValueError(f\"Error validating PROJECT_HOME: {e}\")\n", - "\n", - "# Ensure ERA5LOWRES is set to a valid Path object\n", - "try:\n", - " if not isinstance(ERA5LOWRES, Path):\n", - " raise TypeError(f\"ERA5LOWRES must be a Path object. Got: {type(ERA5LOWRES).__name__}\")\n", - "except Exception as e:\n", - " raise ValueError(f\"Error setting ERA5LOWRES: {e}\")\n", - "\n", - "# Validate the ERA5LOWRES path\n", - "try:\n", - " if not ERA5LOWRES.is_dir():\n", - " raise ValueError(f\"The provided ERA5LOWRES path '{ERA5LOWRES}' does not exist or is not a directory. Please update it.\")\n", - "except ValueError as e:\n", - " raise ValueError(f\"Error validating ERA5LOWRES: {e}\")\n", - "\n", - "\n", - "# Set the environment variable if ERA5LOWRES is valid\n", - "if ERA5LOWRES:\n", - " os.environ[\"ERA5LOWRES\"] = str(ERA5LOWRES)\n", - "else:\n", - " print(\"ERA5LOWRES is not set. Please check your configuration.\")\n", - "\n", - "# Define paths in PROJECT_HOME for stats and caching\n", - "default_root_dir = PROJECT_HOME / \"cnn_training\" # Folder to save the trained models and other project outputs\n", - "stats_folder = PROJECT_HOME / \"cnn_training/stats\" # Folder to save estimated mean & standard deviation of fields\n", - "cache_folder = PROJECT_HOME / \"cnn_training/cache\" # Folders used to cache dataset processed by the pipeline\n", - "\n", - "# Ensure that the required directories exist or create them\n", - "stats_folder.mkdir(parents=True, exist_ok=True)\n", - "cache_folder.mkdir(parents=True, exist_ok=True)\n", - "default_root_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "# Print paths for debugging\n", - "print(f\"Project Home: {PROJECT_HOME}\")\n", - "print(f\"Stats folder: {stats_folder}\")\n", - "print(f\"Cache folder: {cache_folder}\")\n", - "print(f\"Default root directory: {default_root_dir}\")\n", - "print(f\"ERA5LOWRES: {ERA5LOWRES}\")\n", - "\n", - "# Show contents of the 
ERA5LOWRES directory\n", - "if ERA5LOWRES:\n", - " print(f\"Contents of ERA5LOWRES:\")\n", - " !ls -l {ERA5LOWRES}" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "PET_tutorial", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/packages/data/src/pyearthtools/data/load.py b/packages/data/src/pyearthtools/data/load.py index 42f572ef..421d5f7c 100644 --- a/packages/data/src/pyearthtools/data/load.py +++ b/packages/data/src/pyearthtools/data/load.py @@ -42,11 +42,22 @@ def load(stream: Union[str, Path], **kwargs) -> "pyearthtools.data.Index": (pyearthtools.data.Index): Loaded Index """ + + # Check if the stream is not a Path or String and raise an error + if not isinstance(stream, (str, Path)): + raise TypeError( + f"Stream is not a Path or String {type(stream)} - {stream}." + ) + contents = None + # Check if the stream is a path if os.path.sep in str(stream): try: + # Check if the path is a directory if parse_path(stream).is_dir(): + + # Create a list of config files if found. 
stream = list( [ *Path(stream).glob("catalog.cat"), @@ -54,10 +65,9 @@ def load(stream: Union[str, Path], **kwargs) -> "pyearthtools.data.Index": *Path(stream).glob("*.cat"), *Path(stream).glob("*.edi"), ] - )[ - 0 - ] # Find default save file of index - + )[0] # Find default save file of index + + # Read the selected file into a single string contents = "".join(open(str(parse_path(stream))).readlines()) except FileNotFoundError as e: raise e @@ -66,6 +76,7 @@ def load(stream: Union[str, Path], **kwargs) -> "pyearthtools.data.Index": except IndexError: raise FileNotFoundError(f"No default catalog could be found at {stream!r}.") + # If no file was read, use the stream itself (as a string) as the contents. if contents is None: contents = str(stream) @@ -74,4 +85,5 @@ def load(stream: Union[str, Path], **kwargs) -> "pyearthtools.data.Index": contents = initialisation.update_contents(contents, **kwargs) + # Load and return the contents as yaml. return yaml.load(str(contents), initialisation.Loader) diff --git a/packages/data/tests/test_load.py b/packages/data/tests/test_load.py new file mode 100644 index 00000000..ad2658b0 --- /dev/null +++ b/packages/data/tests/test_load.py @@ -0,0 +1,100 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from unittest.mock import patch, mock_open + +import yaml + +from pyearthtools.data import load +from pyearthtools.data.utils import parse_path +from pathlib import Path + + + +def test_load_invalid_stream_type(): + """Test loading with an invalid type for the stream.""" + + with pytest.raises(TypeError): + load(1234) + +def test_load_file_not_found(): + """Test loading a file from a directory that does not exist.""" + + mock_file_path = "data/nonexistent_file.yaml" + with patch("pyearthtools.data.utils.parse_path", return_value=Path(mock_file_path)): + with pytest.raises(FileNotFoundError): + load(mock_file_path) + +def test_load_empty_directory(): + """Test loading from a directory with no matching catalog files.""" + + mock_empty_dir = "data/empty_dir" + + with patch("pyearthtools.data.utils.parse_path", return_value=Path(mock_empty_dir)), \ + patch("pathlib.Path.glob", return_value=[]): # Mock an empty directory + with pytest.raises(FileNotFoundError): + load(mock_empty_dir) + + +def test_update_contents_called_correctly(): + """Test that initialisation.update_contents is called with the correct arguments.""" + mock_contents = "key: value" + mock_updated_contents = "updated key: value" + + # Mock the `initialisation.update_contents` function + with patch("pyearthtools.utils.initialisation.update_contents", return_value=mock_updated_contents) as mock_update_contents: + # Mock the `yaml.load` function to avoid parsing errors + with patch("yaml.load", return_value={"key": "value"}): + # Call the `load` function with a string stream + result = load(mock_contents, extra_arg="test") + + # Assert that `initialisation.update_contents` was called with the correct arguments + mock_update_contents.assert_called_once_with(mock_contents, extra_arg="test") + + # Assert that the result of `load` matches the mocked `yaml.load` return value + assert result == {"key": "value"} + + +def test_load(): + """Test load function""" + + mock_file_content = "key: value" + + # 
Mock dependencies + mock_open_function = mock_open(read_data=mock_file_content) + mock_parse_path = Path("valid_file.yaml") + mock_yaml_load = {mock_file_content} + + with patch("builtins.open", mock_open_function), \ + patch("os.path.sep", "/"), \ + patch("pyearthtools.data.utils.parse_path", return_value=mock_parse_path), \ + patch("yaml.load", return_value=mock_yaml_load): + + # Call the load.py load function with mocked dependencies. + result = load("valid_file.yaml") + + # Assert the result + assert result == {mock_file_content} + + +def test_load_invalid_yaml(): + """Test loading invalid YAML content.""" + + invalid_yaml_content = "invalid: yaml: content" + with patch("pyearthtools.utils.initialisation.update_contents", return_value=invalid_yaml_content), \ + patch("yaml.load", side_effect=yaml.YAMLError): + with pytest.raises(yaml.YAMLError): + load(invalid_yaml_content) +