Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions packages/data/src/pyearthtools/data/transforms/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,25 +412,37 @@ def weak_cast_to_int(value):


class Flatten(Transform):
"""Flatten a coordinate in a dataset into seperate variables"""
"""Operation to flatten a coordinate in a dataset, putting the data at each value of the coordinate into a separate
data variable."""

def __init__(
self, coordinate: Hashable | list[Hashable] | tuple[Hashable], *extra_coordinates, skip_missing: bool = False
):
"""
Flatten a coordinate in a dataset with each point being made a seperate data var

Flatten a coordinate in an xarray Dataset, putting the data at each value of the coordinate into a separate
data variable.

The output data variables will be named "<old variable name><value of coordinate>". For example, if the input
Dataset has a variable "t" and it is flattened along the coordinate "pressure_level" which has values
[100, 200, 500], then the output Dataset will have variables called t100, t200 and t500.

If more than one coordinate is flattened, the output data variable names will concatenate the values of each
coordinate.

Args:
coordinate (Hashable | list[Hashable] | tuple[Hashable] | None):
Coordinates to flatten, either str or list of candidates.
*extra_coordinates (optional):
Arguments form of `coordinate`.
skip_missing (bool, optional):
Whether to skip data without the dims. Defaults to False
Whether to skip data that does not have any of the listed coordinates. If True, will return such data
unchanged. Defaults to False.

Raises:
ValueError:
If invalid number of coordinates found

"""
super().__init__()
self.record_initialisation()
Expand Down Expand Up @@ -458,7 +470,7 @@ def apply(self, dataset: xr.Dataset) -> xr.Dataset:
)

elif len(discovered_coord) > 1:
transforms = TransformCollection(*[flatten(coord) for coord in discovered_coord])
transforms = TransformCollection(*[Flatten(coord) for coord in discovered_coord])
return transforms(dataset)

discovered_coord = str(discovered_coord[0])
Expand Down
103 changes: 103 additions & 0 deletions packages/data/tests/data/transforms/test_data_coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyearthtools.data.transforms import coordinates
import xarray as xr
import numpy as np
import pytest

SIMPLE_DA1 = xr.DataArray(
[
[
[0.9, 0.0, 5],
[0.7, 1.4, 2.8],
[0.4, 0.5, 2.3],
],
[
[1.9, 1.0, 1.5],
[1.7, 2.4, 1.1],
[1.4, 1.5, 3.3],
],
],
coords=[[10, 20], [0, 1, 2], [5, 6, 7]],
dims=["height", "lat", "lon"],
)

SIMPLE_DA2 = xr.DataArray(
[
[0.9, 0.0, 5],
[0.7, 1.4, 2.8],
[0.4, 0.5, 2.3],
],
coords=[[0, 1, 2], [5, 6, 7]],
dims=["lat", "lon"],
)

SIMPLE_DS1 = xr.Dataset({"Temperature": SIMPLE_DA1})
SIMPLE_DS2 = xr.Dataset({"Humidity": SIMPLE_DA1, "Temperature": SIMPLE_DA1, "WombatsPerKm2": SIMPLE_DA1})

COMPLICATED_DS1 = xr.Dataset({"Temperature": SIMPLE_DA1, "MSLP": SIMPLE_DA2})


def test_Flatten():
f = coordinates.Flatten(["height"])
output = f.apply(SIMPLE_DS2)
variables = list(output.keys())
for vbl in ["Temperature10", "Temperature20", "Humidity10", "Humidity20", "WombatsPerKm210", "WombatsPerKm220"]:
assert vbl in variables


def test_Flatten_2_coords():
f = coordinates.Flatten(["height", "lon"])
output = f.apply(SIMPLE_DS1)
variables = list(output.keys())
# Note that it's hard to predict which coordinate will be processed first.
try:
for vbl in [
"Temperature510",
"Temperature520",
"Temperature610",
"Temperature620",
"Temperature710",
"Temperature720",
]:
assert vbl in variables
except AssertionError:
for vbl in [
"Temperature105",
"Temperature205",
"Temperature106",
"Temperature206",
"Temperature107",
"Temperature207",
]:
assert vbl in variables


def test_Flatten_complicated_dataset():
"""Check that Flatten still works when the coordinate being flattened does not exist for all variables."""
f = coordinates.Flatten(["height"])
output = f.apply(COMPLICATED_DS1)
variables = list(output.keys())
for vbl in ["Temperature10", "Temperature20", "MSLP"]:
assert vbl in variables


def test_Flatten_skip_missing():
f = coordinates.Flatten(["scrupulosity"])
with pytest.raises(ValueError):
f.apply(SIMPLE_DS1)
f2 = coordinates.Flatten(["scrupulosity"], skip_missing=True)
output2 = f2.apply(SIMPLE_DS1)
assert output2 == SIMPLE_DS1, "When skip_missing=True, Datasets without the given coordinate pass unchanged."
13 changes: 6 additions & 7 deletions packages/pipeline/src/pyearthtools/pipeline/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,11 @@ def __init__(
self.allowlist = set(allowlist)

self.blocklist = blocklist
if blocklist:
if blocklist:
self.blocklist = set(blocklist)

self._timerange = pyearthtools.data.TimeRange(start, end, interval)


def __iter__(self) -> Generator[pyearthtools.data.Petdt, None, None]:

# If in allowlist mode, yield only samples from the allow list
Expand All @@ -249,11 +248,11 @@ def __iter__(self) -> Generator[pyearthtools.data.Petdt, None, None]:
# If not filtering, yield everything
else:
for i in self._timerange:
yield i
yield i

def randomise(self, seed: Optional[int] = 42):
"""Randomise this interator"""
return DateRandomise(self, seed=seed)
return DateRandomise(self, seed=seed)


class DateRangeLimit(DateRange):
Expand All @@ -280,6 +279,7 @@ def __init__(self, start: str, interval: Any, num: int):
end = pyearthtools.data.Petdt(start) + (pyearthtools.data.TimeDelta(interval) * num)
super().__init__(start, str(end), interval)


class DateRandomise(Iterator):
"""
Wrap around another `Iterator` and randomly sample
Expand All @@ -306,11 +306,11 @@ def __init__(self, iterator: DateRange, seed: Union[int, None] = 42):

print("Calculated indexes")

if getattr(iterator, 'allowlist', None):
if getattr(iterator, "allowlist", None):
self.valid_times = [t for t in self.valid_times if t in iterator.allowlist]
print(len(self.valid_times))

if getattr(iterator, 'blocklist', None):
if getattr(iterator, "blocklist", None):
self.valid_times = [t for t in self.valid_times if t not in iterator.blocklist]
print(len(self.valid_times))

Expand All @@ -322,7 +322,6 @@ def __iter__(self):
yield key



class Randomise(Iterator):
"""
Wrap around another `Iterator` and randomly sample
Expand Down