Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/parcels/_core/_windowed_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Transparent rolling time-window cache for lazy (dask-backed) field data.

Assumptions / current limits:
* ``time`` is the leading dimension of the field (true for both the SGRID and
UGRID ingestion paths; the structured path transposes to ``(time, ...)``).
* Valid while the requested time indices stay within the resident window
(i.e. all particles share the clock). A sample that requests time indices
spanning more than the retained levels would force reloads.
"""

from __future__ import annotations

import numpy as np
import xarray as xr
from dask import is_dask_collection

# xarray / uxarray ``isel`` keyword arguments that are NOT dimension indexers.
_NON_INDEXER_KWARGS = frozenset({"drop", "missing_dims", "ignore_grid"})


class WindowedArray:
"""Wrap a lazy DataArray so ``isel`` loads/caches/evicts time levels as NumPy."""

def __init__(self, data: xr.DataArray, time_dim: str = "time", max_levels: int | None = None):
if data.dims[0] != time_dim:
raise ValueError(f"WindowedArray expects {time_dim!r} as the leading dimension, got {data.dims}")
self._data = data
self._tdim = time_dim
self._cache: dict[int, np.ndarray] = {} # time index -> NumPy slab (remaining dims)
self._max = max_levels
# diagnostics
self.loads = 0
self.bytes_read = 0
self._slab_bytes = int(np.prod(data.isel({time_dim: 0}).shape)) * data.dtype.itemsize

# -- transparency: forward everything we don't override -------------------
def __getattr__(self, name):
# __getattr__ only fires for misses; reach _data without recursing.
return getattr(object.__getattribute__(self, "_data"), name)

def __repr__(self):
return (
f"WindowedArray(time_dim={self._tdim!r}, cached_levels={sorted(self._cache)}, "
f"loads={self.loads})\n{self._data!r}"
)

# -- window management ----------------------------------------------------
def _read_level(self, lvl: int) -> np.ndarray:
"""Bulk, sequential read of one time level into NumPy (the dask->NumPy step)."""
return np.asarray(self._data.isel({self._tdim: int(lvl)}).values)

def _ensure(self, levels: np.ndarray) -> None:
for lvl in levels:
lvl = int(lvl)
if lvl not in self._cache:
self._cache[lvl] = self._read_level(lvl)
self.loads += 1
self.bytes_read += self._slab_bytes
# retire stale levels (the clock only moves forward across the window)
lo = int(np.min(levels))
for old in [k for k in self._cache if k < lo]:
del self._cache[old]
if self._max is not None and len(self._cache) > self._max:
for old in sorted(self._cache)[: len(self._cache) - self._max]:
del self._cache[old]

# -- intercepted indexing -------------------------------------------------
def isel(self, indexers: dict | None = None, **kwargs):
sel = dict(indexers) if indexers is not None else {}
sel.update({k: v for k, v in kwargs.items() if k not in _NON_INDEXER_KWARGS})

# no time selection -> nothing to window; preserve control kwargs
if self._tdim not in sel:
return self._data.isel(indexers, **kwargs)

t_ind = sel[self._tdim]
t_vals = np.asarray(t_ind.values if isinstance(t_ind, xr.DataArray) else t_ind)
levels = np.unique(t_vals)
self._ensure(levels)

# stack the resident levels into one small NumPy block; remap to local indices
block = np.stack([self._cache[int(lvl)] for lvl in levels]) # (nlevels, *rest)
nda = xr.DataArray(block, dims=self._data.dims) # NumPy-backed, original dim order
local = np.searchsorted(levels, t_vals)
sel[self._tdim] = xr.DataArray(local, dims=getattr(t_ind, "dims", ()))
return nda.isel(sel) # plain vectorised gather in NumPy (no ignore_grid needed)


def maybe_windowed(data: xr.DataArray, max_levels: int | None = None):
"""Wrap dask-backed, field data in a ``WindowedArray``; else pass through.

NumPy-backed fields (already resident) and fields without a leading ``time``
dimension are returned unchanged, so existing eager workflows are unaffected.
Already-wrapped data is returned unchanged.
"""
if isinstance(data, WindowedArray):
return data
if data.dims and data.dims[0] == "time" and is_dask_collection(data.data):
return WindowedArray(data, max_levels=max_levels)
return data
2 changes: 1 addition & 1 deletion src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(

@property
def data(self):
return self.model.data[self.name]
return self.model.field_data(self.name)

@property
def grid(self): # TODO PR: Remove in favour of referencing model grid directly
Expand Down
26 changes: 26 additions & 0 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,32 @@ def add_field(self, field: Field, name: str | None = None):

self.fields[name] = field

def to_windowed_arrays(self, *, max_levels: int | None = None):
"""Wrap dask-backed field data in rolling time-window caches.

Opt-in optimization for forward-marching simulations where all particles
share a single clock. Delegates to each underlying model; dask-backed,
time-leading fields are served through a resident NumPy window (each time
level loaded once and evicted as the clock advances) instead of re-reading
chunks on every kernel step. NumPy-backed (eager) and non-time-leading
fields are left unchanged, and re-invoking is idempotent, so this is safe
to call more than once.

Parameters
----------
max_levels : int, optional
Cap on the number of time levels kept resident per field. ``None``
(default) retains every level the advancing clock still brackets.

Returns
-------
FieldSet
``self``, to allow chaining.
"""
for model in self.models:
model.to_windowed_arrays(max_levels=max_levels)
return self

def add_constant_field(self, name: str, value, mesh: Mesh = "spherical"):
"""Wrapper function to add a Field that is constant in space,
useful e.g. when using constant horizontal diffusivity
Expand Down
38 changes: 38 additions & 0 deletions src/parcels/_core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import xarray as xr

import parcels._sgrid as sgrid
from parcels._core._windowed_array import maybe_windowed
from parcels._core.basegrid import BaseGrid
from parcels._core.field import Field, VectorField
from parcels._core.utils.time import TimeInterval
Expand Down Expand Up @@ -58,6 +59,43 @@ def assert_valid_model_data(self) -> None:
raise e
return

def field_data(self, name: str) -> Any:
"""Return the array backing field ``name``.

Normally this is the ``xr.DataArray`` held in the dataset. After
:meth:`to_windowed_arrays`, dask-backed fields are served through a
cached :class:`~parcels._core._windowed_array.WindowedArray` instead.
"""
windowed = self.__dict__.get("_windowed")
if windowed is not None and name in windowed:
return windowed[name]
return self.data[name]

def to_windowed_arrays(self, *, max_levels: int | None = None) -> Self:
"""Wrap dask-backed field data in rolling time-window caches.

Opt-in optimization for forward-marching simulations where all particles
share a single clock. For each dask-backed, time-leading field, ``isel``
then samples a resident NumPy window (each time level loaded once and
evicted as the clock advances) instead of re-reading chunks and paying the
dask scheduling overhead on every kernel step. NumPy-backed (eager) fields
and non-time-leading fields are left unchanged.

Idempotent: re-invoking reuses the existing wrapper (and its warm cache)
rather than rebuilding it.

Parameters
----------
max_levels : int, optional
Cap on the number of time levels kept resident per field. ``None``
(default) retains every level the advancing clock still brackets.
"""
windowed = self.__dict__.setdefault("_windowed", {})
for name in self.scalar_field_names:
current = windowed.get(name, self.data[name])
windowed[name] = maybe_windowed(current, max_levels=max_levels)
return self

@property
def time_interval(self) -> TimeInterval | None:
try:
Expand Down
102 changes: 102 additions & 0 deletions tests/test_windowed_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Tests for the transparent rolling time-window cache (WindowedArray)."""

import dask.array as da
import numpy as np
import pytest
import xarray as xr

from parcels import FieldSet, ParticleSet
from parcels._core._windowed_array import WindowedArray, maybe_windowed
from parcels._datasets.structured.generated import simple_UV_dataset
from parcels.kernels import AdvectionRK2


def test_windowed_isel_matches_dask_loads_once_and_evicts():
"""WindowedArray.isel must equal dask isel, load each level once, keep <=2 resident."""
ntime, n, npart = 20, 64, 200
rng = np.random.default_rng(0)
base = rng.standard_normal((ntime, 3, n, n))
lazy = xr.DataArray(da.from_array(base, chunks=(1, 3, n, n)), dims=("time", "depth", "lat", "lon"))
win = WindowedArray(lazy)

worst, max_cache = 0.0, 0
for step in range(40):
ti = min(step // 2, ntime - 2) # advancing clock, 2 sub-steps per level
yi, xi = rng.integers(0, n, npart), rng.integers(0, n, npart)
zi = np.zeros(npart, dtype=int)
sel = dict(
time=xr.DataArray(np.r_[np.full(npart, ti), np.full(npart, ti + 1)], dims="p"),
depth=xr.DataArray(np.r_[zi, zi], dims="p"),
lat=xr.DataArray(np.r_[yi, yi], dims="p"),
lon=xr.DataArray(np.r_[xi, xi], dims="p"),
)
got = win.isel(sel).data
ref = lazy.isel(sel).data.compute()
worst = max(worst, float(np.abs(got - ref).max()))
max_cache = max(max_cache, len(win._cache))

assert worst == 0.0 # byte-identical to dask
assert win.loads == ntime # each time level read exactly once
assert max_cache <= 2 # only the bracketing levels resident


def test_to_windowed_arrays_wraps_dask_but_not_numpy():
ds = simple_UV_dataset(mesh="flat")
fs_np = FieldSet.from_sgrid_conventions(ds, mesh="flat")
fs_dk = FieldSet.from_sgrid_conventions(ds.chunk({"time": 1}), mesh="flat")

# construction is never windowing -- it is opt-in via the fieldset method
assert not isinstance(fs_np.U.data, WindowedArray)
assert not isinstance(fs_dk.U.data, WindowedArray)

assert fs_np.to_windowed_arrays() is fs_np # chainable
fs_dk.to_windowed_arrays()

# numpy-backed field is left eager; dask-backed field gets wrapped
assert not isinstance(fs_np.U.data, WindowedArray)
assert isinstance(fs_dk.U.data, WindowedArray)
# transparency: forwarded attributes still behave like the DataArray
assert fs_dk.U.data.dims == fs_np.U.data.dims
assert fs_dk.U.data.shape == fs_np.U.data.shape


def test_to_windowed_arrays_is_idempotent_and_forwards_max_levels():
ds = simple_UV_dataset(mesh="flat")
fs = FieldSet.from_sgrid_conventions(ds.chunk({"time": 1}), mesh="flat")

fs.to_windowed_arrays(max_levels=3)
first = fs.U.data
assert isinstance(first, WindowedArray)
assert first._max == 3

# re-wrapping returns the same object (idempotent, warm cache preserved)
fs.to_windowed_arrays(max_levels=3)
assert fs.U.data is first


def test_maybe_windowed_passthrough_for_non_time_leading():
da_no_time = xr.DataArray(da.zeros((3, 4), chunks=(3, 4)), dims=("lat", "lon"))
assert maybe_windowed(da_no_time) is da_no_time # not wrapped (no leading time dim)


@pytest.mark.parametrize("mesh", ["flat", "spherical"])
def test_dask_advection_matches_numpy(mesh):
"""An identical advection must give identical trajectories whether the field
is numpy-backed or dask-backed (windowed).
"""
ds = simple_UV_dataset(mesh=mesh)
ds["U"].data[:] = 1.0 # steady zonal flow -> in-bounds, deterministic

def run(chunked):
d = ds.chunk({"time": 1}) if chunked else ds
fs = FieldSet.from_sgrid_conventions(d, mesh=mesh)
if chunked:
fs.to_windowed_arrays()
pset = ParticleSet(fs, lon=np.zeros(10), lat=np.linspace(-10, 10, 10))
pset.execute(AdvectionRK2, runtime=7200, dt=np.timedelta64(15, "m"))
return np.array(pset.lon), np.array(pset.lat)

lon_np, lat_np = run(False)
lon_dk, lat_dk = run(True)
np.testing.assert_allclose(lon_dk, lon_np, atol=1e-9)
np.testing.assert_allclose(lat_dk, lat_np, atol=1e-9)