diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 43aefd555..45310f2cb 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -9,12 +9,18 @@ import uxarray as ux import xarray as xr +import parcels._typing as ptyping from parcels._core.field import Field, VectorField -from parcels._core.model import CONSTANT_FIELD_MODELS, ModelData, StructuredModelData, UnstructuredModelData +from parcels._core.model import ( + CONSTANT_FIELD_MODELS, + ModelData, + StructuredModelData, + UnstructuredModelData, +) from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible -from parcels._typing import Mesh +from parcels._python import NOTSET, NotSetType from parcels.interpolators import ( XConstantField, ) @@ -144,7 +150,7 @@ def add_field(self, field: Field, name: str | None = None): self.fields[name] = field - def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"): + def add_constant_field(self, name: str, value, mesh: ptyping.Mesh = "spherical"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity @@ -201,7 +207,12 @@ def gridset(self) -> list[BaseGrid]: return grids @classmethod - def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): + def from_ugrid_conventions( + cls, + ds: ux.UxDataset, + mesh: str = "spherical", + vector_fields: ptyping.VectorFields | NotSetType = NOTSET, + ): """Create a FieldSet from a Parcels compliant uxarray.UxDataset. This is the primary ingestion method in Parcels for structured grid datasets. 
@@ -215,6 +226,10 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): ---------- ds : uxarray.UxDataset uxarray.UxDataset as obtained from the uxarray package but with appropriate named vertical dimensions + vector_fields : dict[str, tuple[str, ...]], optional + Dictionary mapping vector field names to tuples of two or three component variable names in the dataset. + For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``. + If omitted (default), vector fields are auto-discovered from standard variable names (``U``/``V``/``W``); pass ``{}`` to create no vector fields. Returns ------- @@ -225,12 +240,15 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): ----- See https://ugrid-conventions.github.io/ugrid-conventions/ for more information on the UGRID conventions. """ - model = UnstructuredModelData.from_ugrid_conventions(ds, mesh) + model = UnstructuredModelData.from_ugrid_conventions(ds, mesh, vector_fields) return cls([model]) @classmethod def from_sgrid_conventions( - cls, ds: xr.Dataset, mesh: Mesh | None = None + cls, + ds: xr.Dataset, + mesh: ptyping.Mesh | None = None, + vector_fields: ptyping.VectorFields | NotSetType = NOTSET, ): # TODO: Update mesh to be discovered from the dataset metadata """Create a FieldSet from a dataset using SGRID convention metadata. @@ -245,6 +263,10 @@ def from_sgrid_conventions( mesh : str String indicating the type of mesh coordinates used during velocity interpolation. Options are "spherical" or "flat". + vector_fields : dict[str, tuple[str, ...]], optional + Dictionary mapping vector field names to tuples of two or three component variable names in the dataset. + For example, ``{"UV": ("U", "V"), "UVW": ("U", "V", "W")}``. + If omitted (default), vector fields are auto-discovered from standard variable names (``U``/``V``/``W``); pass ``{}`` to create no vector fields. Returns ------- @@ -259,7 +281,7 @@ def from_sgrid_conventions( See https://sgrid.github.io/sgrid/ for more information on the SGRID conventions. 
""" - model = StructuredModelData.from_sgrid_conventions(ds, mesh) + model = StructuredModelData.from_sgrid_conventions(ds, mesh, vector_fields) return cls([model]) @@ -356,9 +378,3 @@ def _format_calendar_error_message(field: Field | VectorField, reference_datetim ], "W": ["upward_sea_water_velocity", "vertical_sea_water_velocity"], } - - -def _is_agrid(ds: xr.Dataset) -> bool: - # check if U and V are defined on the same dimensions - # if yes, interpret as A grid - return set(ds["U"].dims) == set(ds["V"].dims) diff --git a/src/parcels/_core/model.py b/src/parcels/_core/model.py index 2040ca14a..c44026370 100644 --- a/src/parcels/_core/model.py +++ b/src/parcels/_core/model.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Hashable, Sequence from typing import Any, Self import cf_xarray # noqa: F401 @@ -8,6 +9,7 @@ import xarray as xr import parcels._sgrid as sgrid +import parcels._typing as ptyping from parcels._core.basegrid import BaseGrid from parcels._core.field import Field, VectorField from parcels._core.utils.time import TimeInterval @@ -18,6 +20,7 @@ assert_all_field_dims_have_axis, # noqa: F401, leave import for now until decision is made # TODO v4: Make decision ) from parcels._logger import logger +from parcels._python import NOTSET, NotSetType from parcels._typing import Mesh from parcels.convert import _ds_rename_using_standard_names from parcels.interpolators import ( @@ -37,6 +40,7 @@ class ModelData(ABC): data: Any grid: BaseGrid field_to_interpolator: dict[str, ScalarInterpolator | VectorInterpolator] + vector_field_components: ptyping.VectorFields @abstractmethod def construct_fields(self) -> list[Field | VectorField]: ... 
@@ -79,7 +83,7 @@ def preprocess_sgrid_model_data(ds: xr.Dataset) -> xr.Dataset: class StructuredModelData(ModelData): - def __init__(self, data: xr.Dataset, mesh: Mesh): + def __init__(self, data: xr.Dataset, mesh: Mesh, vector_field_components: ptyping.VectorFields): if not isinstance(data, xr.Dataset): raise ValueError(f"Expected `data` to be an xarray.Dataset . Got {type(data)}") @@ -88,6 +92,7 @@ def __init__(self, data: xr.Dataset, mesh: Mesh): self.data = data self.grid = grid + self.vector_field_components = vector_field_components self.field_to_interpolator = {} self._fields: list[Field | VectorField] | None = None self.assert_valid_model_data() @@ -110,30 +115,25 @@ def construct_fields(self) -> list[Field | VectorField]: single_fields: dict[str, Field] = {} vector_fields: dict[str, VectorField] = {} scalar_field_names = self.scalar_field_names - if "U" in scalar_field_names and "V" in scalar_field_names: - interp_method = XLinear_Velocity() if _is_agrid(self.data) else CGrid_Velocity() - single_fields["U"] = Field("U", self) - single_fields["V"] = Field("V", self) - vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=interp_method) - - if "W" in scalar_field_names: - single_fields["W"] = Field("W", self) - vector_fields["UVW"] = VectorField( - "UVW", - single_fields["U"], - single_fields["V"], - single_fields["W"], - interp_method=interp_method, - ) - fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} - for varname in set(scalar_field_names) - set(fields.keys()): - fields[varname] = Field(str(varname), self) + for varname in set(scalar_field_names): + single_fields[varname] = Field(str(varname), self) + + for vfield_name, components in self.vector_field_components.items(): + interp_method = ( + XLinear_Velocity() if _is_agrid(self.data, u=components[0], v=components[1]) else CGrid_Velocity() + ) + + component_fields = [single_fields[name] for name in components] + 
vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) # type:ignore[misc,arg-type] + fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} return list(fields.values()) @classmethod - def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Self: + def from_sgrid_conventions( + cls, ds: xr.Dataset, mesh: Mesh | None, vector_fields: ptyping.VectorFields | NotSetType + ) -> Self: ds = ds.copy() if mesh is None: mesh = _get_mesh_type_from_sgrid_dataset(ds) @@ -160,7 +160,10 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Sel # ds["lon"] = ds[node_dimensions[0]] # ds["lat"] = ds[node_dimensions[1]] - model = cls(ds, mesh=mesh) + vector_fields = resolve_vector_fields(ds, vector_fields) + assert_valid_vector_fields(ds, vector_fields) + + model = cls(ds, mesh=mesh, vector_field_components=vector_fields) model._fields = model.construct_fields() for f in model._fields: if isinstance(f, Field): @@ -168,6 +171,45 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Sel return model +def resolve_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields | NotSetType) -> ptyping.VectorFields: + if vector_fields is NOTSET: # i.e., the default vectorfield discovery behaviour + return _default_vector_field_components(list(ds.data_vars)) + return vector_fields + + +def assert_valid_vector_fields(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None: + if not isinstance(vector_fields, dict): + raise ValueError(f"vector_fields must be a dictionary. Got {type(vector_fields)=!r}.") + + for vfield_name, components in vector_fields.items(): + if not isinstance(vfield_name, str): + raise ValueError( + f"Invalid `vector_fields` argument. Vector field name in `vector_fields` should be a string. Got field name {vfield_name!r}." + ) + if not (2 <= len(components) <= 3): + raise ValueError( + f"Invalid `vector_fields` argument. 
Vector fields must have either 2 or 3 components. Vector field {vfield_name} has {len(components)} components." + ) + for c in components: + if not isinstance(c, str): + raise ValueError( + f"Invalid `vector_fields` argument. Component names must be strings. Got component name of value {c!r}." + ) + + assert_vector_field_components_in_dataset(ds, vector_fields) + return + + +def assert_vector_field_components_in_dataset(ds: xr.Dataset, vector_fields: ptyping.VectorFields) -> None: + for components in vector_fields.values(): + for c in components: + if c not in ds.data_vars: + raise ValueError( + f"Field component '{c}' not present in the source dataset, but is listed in {vector_fields=!r}. This component cannot be used in this mapping." + ) + return + + CONSTANT_FIELD_MODELS = { mesh: StructuredModelData.from_sgrid_conventions( xr.Dataset( @@ -191,13 +233,14 @@ def from_sgrid_conventions(cls, ds: xr.Dataset, mesh: Mesh | None = None) -> Sel ), ), mesh=mesh, # type:ignore + vector_fields={}, ) for mesh in ["flat", "spherical"] } class UnstructuredModelData(ModelData): - def __init__(self, data: ux.UxDataset, grid: UxGrid): + def __init__(self, data: ux.UxDataset, grid: UxGrid, vector_field_components: ptyping.VectorFields): if not isinstance(data, ux.UxDataset): raise ValueError(f"Expected `data` to be an uxarray.UxDataset . 
Got {type(data)}") @@ -206,6 +249,7 @@ def __init__(self, data: ux.UxDataset, grid: UxGrid): self.data = data self.grid = grid + self.vector_field_components = vector_field_components self.field_to_interpolator = {} self._fields: list[Field | VectorField] | None = None @@ -213,21 +257,17 @@ def construct_fields(self) -> list[Field | VectorField]: single_fields: dict[str, Field] = {} vector_fields: dict[str, VectorField] = {} scalar_field_names = self.scalar_field_names - if "U" in scalar_field_names and "V" in scalar_field_names: - single_fields["U"] = Field("U", self) - single_fields["V"] = Field("V", self) - vector_fields["UV"] = VectorField("UV", single_fields["U"], single_fields["V"], interp_method=Ux_Velocity()) - - if "W" in scalar_field_names: - single_fields["W"] = Field("W", self) - vector_fields["UVW"] = VectorField( - "UVW", single_fields["U"], single_fields["V"], single_fields["W"], interp_method=Ux_Velocity() - ) - fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} - for varname in set(scalar_field_names) - set(single_fields.keys()): - fields[varname] = Field(str(varname), self) + for varname in set(scalar_field_names): + single_fields[varname] = Field(str(varname), self) + for vfield_name, components in self.vector_field_components.items(): + interp_method = Ux_Velocity() + + component_fields = [single_fields[name] for name in components] + vector_fields[vfield_name] = VectorField(vfield_name, *component_fields, interp_method=interp_method) # type:ignore[misc, arg-type] + + fields: dict[str, Field | VectorField] = {**single_fields, **vector_fields} return list(fields.values()) def assert_valid_field_data(self, field_data: ux.UxDataArray) -> None: @@ -239,7 +279,7 @@ def scalar_field_names(self) -> list[str]: return list(self.data.data_vars) @classmethod - def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): + def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: Mesh, vector_fields: 
ptyping.VectorFields | NotSetType): ds_dims = list(ds.dims) if not all(dim in ds_dims for dim in ["time", "zf", "zc"]): raise ValueError( @@ -248,7 +288,11 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"): grid = UxGrid(ds.uxgrid, z=ds.coords["zf"], mesh=mesh) ds = _discover_ux_U_and_V(ds) - model = cls(ds, grid) + + vector_fields = resolve_vector_fields(ds, vector_fields) + assert_valid_vector_fields(ds, vector_fields) + + model = cls(ds, grid, vector_fields) model._fields = model.construct_fields() for f in model._fields: if isinstance(f, Field): @@ -276,6 +320,17 @@ def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: return "spherical" if _is_coordinate_in_degrees(ds_sgrid[fpoint_x]) else "flat" +def _default_vector_field_components(data_vars: Sequence[Hashable]) -> ptyping.VectorFields: + vars = set(data_vars) + ret: ptyping.VectorFields = {} + + if {"U", "V"}.issubset(vars): + ret["UV"] = ("U", "V") + if {"U", "V", "W"}.issubset(vars): + ret["UVW"] = ("U", "V", "W") + return ret + + def _is_coordinate_in_degrees(da: xr.DataArray) -> bool: units = da.attrs.get("units") if units is None: @@ -366,10 +421,10 @@ def _select_uxinterpolator(da: ux.UxDataArray): return None -def _is_agrid(ds: xr.Dataset) -> bool: +def _is_agrid(ds: xr.Dataset, u: str, v: str) -> bool: # check if U and V are defined on the same dimensions # if yes, interpret as A grid - return set(ds["U"].dims) == set(ds["V"].dims) + return set(ds[u].dims) == set(ds[v].dims) def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | None: diff --git a/src/parcels/_python.py b/src/parcels/_python.py index 81db6ade4..4f8bf106e 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -1,4 +1,5 @@ # Generic Python helpers +import enum import inspect from collections.abc import Callable, Mapping from typing import TypeVar @@ -6,6 +7,9 @@ K = TypeVar("K") V = TypeVar("V") +NotSetType = enum.Enum("NotSetType", "VALUE") +NOTSET = 
NotSetType.VALUE + def isinstance_noimport(obj, class_or_tuple): """A version of isinstance that does not require importing the class. diff --git a/src/parcels/_typing.py b/src/parcels/_typing.py index 18e8aa55f..e8993cb05 100644 --- a/src/parcels/_typing.py +++ b/src/parcels/_typing.py @@ -47,6 +47,7 @@ CfAxis = XgcmAxisDirection XgcmAxisPosition = Literal["center", "left", "right", "inner", "outer"] XgcmAxes = Mapping[XgcmAxisDirection, "xgcm.Axis"] +VectorFields = dict[str, tuple[str, str] | tuple[str, str, str]] def _is_xarray_object(obj): # with no imports diff --git a/tests/test_field.py b/tests/test_field.py index f3893fbcc..5ae9ed794 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -12,6 +12,7 @@ from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import _ux_constant_flow_face_centered_2D from parcels._datasets.unstructured.generic import datasets as datasets_unstructured +from parcels._python import NOTSET from parcels.interpolators import ( UxConstantFaceConstantZC, ) @@ -19,7 +20,7 @@ def test_field_init_param_types(): data = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(data, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(data, mesh="flat", vector_fields=NOTSET) with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."): Field(name=123, model=model) @@ -52,7 +53,7 @@ def test_field_init_fail_on_float_time_dim(): ds["time"].attrs, ) - model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat", vector_fields=NOTSET) with pytest.raises( ValueError, match=r"Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?", @@ -64,7 +65,7 @@ def test_field_init_fail_on_float_time_dim(): def test_field_time_interval(): """Test that field.time_interval 
delegates correctly to model.time_interval.""" data = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(data, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(data, mesh="flat", vector_fields=NOTSET) field = Field(name="data_g", model=model) assert field.time_interval.left == np.datetime64("2000-01-01") assert field.time_interval.right == np.datetime64("2001-01-01") @@ -77,7 +78,7 @@ def test_vectorfield_init_different_time_intervals(): def test_field_invalid_interpolator(): ds = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat", vector_fields=NOTSET) field = Field(name="data_g", model=model) def not_a_scalar_interpolator(particle_positions, grid_positions, field): @@ -90,7 +91,7 @@ def not_a_scalar_interpolator(particle_positions, grid_positions, field): def test_vectorfield_invalid_interpolator(): ds = datasets_structured["ds_2d_left"] - model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat") + model = StructuredModelData.from_sgrid_conventions(ds, mesh="flat", vector_fields=NOTSET) fields = {f.name: f for f in model.construct_fields()} U = fields["U_A_grid"] V = fields["V_A_grid"] diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 8cfdd38d5..8e2f65b6f 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -8,6 +8,7 @@ from parcels import Field, ParticleFile, ParticleSet, XGrid, convert from parcels._core.fieldset import FieldSet, _datetime_to_msg +from parcels._core.model import _default_vector_field_components from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.structured.generic import datasets_sgrid from parcels._datasets.unstructured.generic import datasets as datasets_unstructured @@ -96,7 +97,110 @@ def test_fieldset_from_structured_generic_datasets(ds): assert len(fieldset.gridset) == 1 
-def test_fieldset_gridset_multiple_grids(): ... +@pytest.mark.parametrize( + "vector_fields,ctx", + [ + pytest.param( + {"UV": ("U",)}, + pytest.raises(ValueError, match="must have either 2 or 3 components"), + id="single-component", + ), + pytest.param( + {"UV": ("U", "missing")}, + pytest.raises(ValueError, match="not present in the source dataset"), + id="component-not-in-dataset", + ), + pytest.param( + {"UV": ("U", "U", "U", "U")}, + pytest.raises(ValueError, match="must have either 2 or 3 components"), + id="too-many-components", + ), + pytest.param( + None, + pytest.raises(ValueError, match="vector_fields must be a dictionary"), + id="None", + ), + ], +) +def test_fieldset_invalid_vector_fields(vector_fields, ctx): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + with ctx: + FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields=vector_fields) + + +def test_fieldset_structured_vectorfield_default(): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat") + + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" in fset.fields + + +def test_fieldset_structured_vectorfield_custom(): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds = ds.rename({"U": "U_wind", "V": "V_wind"}) + + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) + + assert "U_wind" in fset.fields + assert "V_wind" in fset.fields + assert "UV_wind" in fset.fields + + +def test_fieldset_structured_vectorfield_empty(): + ds = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + + fset = FieldSet.from_sgrid_conventions(ds, mesh="flat", vector_fields={}) + + assert "U" in fset.fields + assert "V" in 
fset.fields + assert "UV" not in fset.fields + + +def test_fieldset_unstructured_vectorfield_default(): + ds = datasets_unstructured["stommel_gyre_delaunay"] + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical") + + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" in fset.fields + + +def test_fieldset_unstructured_vectorfield_custom(): + ds = datasets_unstructured["stommel_gyre_delaunay"] + ds = ds.rename({"U": "U_wind", "V": "V_wind"}) + + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical", vector_fields={"UV_wind": ("U_wind", "V_wind")}) + + assert "U_wind" in fset.fields + assert "V_wind" in fset.fields + assert "UV_wind" in fset.fields + + +def test_fieldset_unstructured_vectorfield_empty(): + ds = datasets_unstructured["stommel_gyre_delaunay"] + + fset = FieldSet.from_ugrid_conventions(ds, mesh="spherical", vector_fields={}) + + assert "U" in fset.fields + assert "V" in fset.fields + assert "UV" not in fset.fields + + +@pytest.mark.parametrize( + "data_vars,expected", + [ + (["U", "V", "land_mask"], {"UV": ("U", "V")}), + (["U", "V", "W", "land_mask"], {"UV": ("U", "V"), "UVW": ("U", "V", "W")}), + (["field1", "field2", "field3"], {}), + ], +) +def test_default_vector_field_components(data_vars, expected): + got = _default_vector_field_components(data_vars) + assert got == expected # TODO restructure: use adding of fieldset notation to test this @@ -208,34 +312,38 @@ def test_fieldset_from_sgrid_conventions(ds_name): assert len(fieldset.fields) > 0 -def test_fieldset_add(): - """Test that two FieldSets can be combined with + (fset1 + fset2).""" - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) - ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"}) +def test_fieldset_add_error_on_duplicate_fields(): + """Test that adding FieldSets with overlapping field names raises a ValueError.""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", 
"grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds2 = ds1.copy() fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") - fset = fset1 + fset2 - - assert len(fset.models) == len(fset1.models) + len(fset2.models) - assert "U1" in fset.fields - assert "V2" in fset.fields + with pytest.raises(ValueError, match="field names in common.*'U'"): + fset1 + fset2 -def test_fieldset_add_overlapping_fields(): - """Test that adding FieldSets with overlapping field names raises a ValueError.""" - ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U"}) - ds2 = datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "U"}) +def test_fieldset_add(): + """Test that two FieldSets can be combined with + (fset1 + fset2).""" + ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename({"U_A_grid": "U", "V_A_grid": "V"}) + ds2 = datasets_structured["ds_2d_left"][["U_A_grid", "V_A_grid", "grid"]].rename( + {"U_A_grid": "U_wind", "V_A_grid": "V_wind"} + ) fset1 = FieldSet.from_sgrid_conventions(ds1, mesh="flat") - fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat") + fset2 = FieldSet.from_sgrid_conventions(ds2, mesh="flat", vector_fields={"UV_wind": ("U_wind", "V_wind")}) - with pytest.raises(ValueError, match="field names in common.*'U'"): - fset1 + fset2 + fset = fset1 + fset2 + + assert len(fset.models) == len(fset1.models) + len(fset2.models) + + fields_before = list(fset1.fields.keys()) + list(fset2.fields.keys()) + assert len(fields_before) == len(fset.fields) + assert set(fields_before) == set(fset.fields.keys()) -def test_fieldset_add_overlapping_context_values(): +def test_fieldset_add_error_on_duplicate_context_values(): """Test that adding FieldSets with overlapping context value names raises a ValueError.""" ds1 = datasets_structured["ds_2d_left"][["U_A_grid", "grid"]].rename({"U_A_grid": "U1"}) ds2 = 
datasets_structured["ds_2d_left"][["V_A_grid", "grid"]].rename({"V_A_grid": "V2"})