def create_raster_element_kwargs(
    raster_write_kwargs: dict[str, JSONDict | list[JSONDict]] | list[JSONDict],
    element_name: str,
    element_names: set[str],
) -> dict[str, Any] | list[dict[str, Any]]:
    """
    Normalize raster keyword arguments to the kwargs required by `zarr.create_array` for a single raster.

    Parameters
    ----------
    raster_write_kwargs
        Either a list of dictionaries (one per scale of a multiscale raster), a dictionary of options applied
        to every raster, or a dictionary keyed by element name whose values are the options of that element.
        The two dictionary forms may be mixed. This argument is never modified.
    element_name
        Name of the raster element for which the options are resolved.
    element_names
        Names of all raster elements; keys matching these names are treated as per-element entries and are
        therefore excluded from the options that apply globally.

    Returns
    -------
    The options for `element_name`: its own entry when a non-empty one is present, otherwise the remaining
    (global) options, otherwise an empty dictionary. A list input is returned as is.

    Raises
    ------
    ValueError
        If `raster_write_kwargs` is neither a dict nor a list, or is a list containing non-dict items.
    """
    if isinstance(raster_write_kwargs, list):
        if not all(isinstance(x, dict) for x in raster_write_kwargs):
            raise ValueError(
                "If passing raster_write_kwargs as list, it is assumed to be the storage "
                "options for each scale of a multiscale raster as a dictionary."
            )
        return raster_write_kwargs

    if not isinstance(raster_write_kwargs, dict):
        raise ValueError(
            f"Type of raster_write_kwargs should be either dict or list, got {type(raster_write_kwargs)}."
        )

    # A non-empty entry keyed by this element's name always wins.
    if element_kwargs := raster_write_kwargs.get(element_name):
        return element_kwargs

    # Build a filtered copy instead of popping from the caller's dict: the same dict is reused for every
    # element being written, so mutating it would drop the entries of elements that are written later.
    global_kwargs = {key: value for key, value in raster_write_kwargs.items() if key not in element_names}

    # What is left only counts as global options if at least one value is not a dict/list; a mapping made
    # solely of dict/list values is assumed to be per-element entries that do not concern this element.
    if global_kwargs and not all(isinstance(value, (dict, list)) for value in global_kwargs.values()):
        return global_kwargs
    return {}
If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. + raster_write_kwargs + Storage options for raster elements. These options are passed to the zarr storage backend for writing and + can be provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied globally. + 2. Dictionary per raster element + A dictionary where: + - Keys = names of raster elements + - Values = storage options for each element + - For single-scale data: a dictionary + - For multiscale data: a list of dictionaries (one per scale) + 3. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of a multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array raster_compressor A length-1 dictionary whose key is the type of compression to use for images and labels and whose value is the compression level, which must be between 0 and 9 (inclusive). 
For compression, `lz4` and `zstd` are @@ -1193,6 +1214,7 @@ def write( overwrite=False, parsed_formats=parsed, shapes_geometry_encoding=shapes_geometry_encoding, + raster_write_kwargs=raster_write_kwargs, raster_compressor=raster_compressor, ) @@ -1211,6 +1233,7 @@ def _write_element( overwrite: bool, parsed_formats: dict[str, SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, JSONDict | list[JSONDict] | Any] | list[JSONDict] | None = None, raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None, ) -> None: from spatialdata._io.io_zarr import _get_groups_for_element @@ -1244,44 +1267,54 @@ def _write_element( validate_element(element) - if element_type == "images": - write_image( - image=element, - group=element_group, - name=element_name, - element_format=parsed_formats["raster"], - raster_compressor=raster_compressor, - ) - elif element_type == "labels": - write_labels( - labels=element, - group=root_group, - name=element_name, - element_format=parsed_formats["raster"], - raster_compressor=raster_compressor, - ) - elif element_type == "points": - write_points( - points=element, - group=element_group, - element_format=parsed_formats["points"], - ) - elif element_type == "shapes": - write_shapes( - shapes=element, - group=element_group, - element_format=parsed_formats["shapes"], - geometry_encoding=shapes_geometry_encoding, - ) - elif element_type == "tables": - write_table( - table=element, - group=element_type_group, - name=element_name, - element_format=parsed_formats["tables"], - ) - else: - raise ValueError(f"Unknown element type: {element_type}") + element_raster_write_kwargs = None + if element_type in ("images", "labels") and raster_write_kwargs: + from spatialdata._core._utils import create_raster_element_kwargs + + element_names = set(self.images.keys()).union(self.labels.keys()) + element_raster_write_kwargs = 
create_raster_element_kwargs(raster_write_kwargs, element_name, element_names) + + with zarrs_context(): + if element_type == "images": + write_image( + image=element, + group=element_group, + name=element_name, + element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, + raster_compressor=raster_compressor, + ) + elif element_type == "labels": + write_labels( + labels=element, + group=root_group, + name=element_name, + element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, + raster_compressor=raster_compressor, + ) + elif element_type == "points": + write_points( + points=element, + group=element_group, + element_format=parsed_formats["points"], + ) + elif element_type == "shapes": + write_shapes( + shapes=element, + group=element_group, + element_format=parsed_formats["shapes"], + geometry_encoding=shapes_geometry_encoding, + ) + elif element_type == "tables": + write_table( + table=element, + group=element_type_group, + name=element_name, + element_format=parsed_formats["tables"], + ) + else: + raise ValueError(f"Unknown element type: {element_type}") def write_element( self, @@ -1289,6 +1322,7 @@ def write_element( overwrite: bool = False, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, JSONDict | list[JSONDict] | Any] | list[JSONDict] | None = None, raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None, ) -> None: """ @@ -1308,6 +1342,25 @@ def write_element( shapes_geometry_encoding Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. + raster_write_kwargs + Storage options for raster elements. 
These options are passed to the zarr storage backend for writing and + can be provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied globally. + 2. Dictionary per raster element + A dictionary where: + - Keys = names of raster elements + - Values = storage options for each element + - For single-scale data: a dictionary + - For multiscale data: a list of dictionaries (one per scale) + 3. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of a multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array raster_compressor A length-1 dictionary whose key is the type of compression to use for images and labels and whose value is the compression level, which must be between 0 and 9 (inclusive). For compression, `lz4` and `zstd` are @@ -1331,6 +1384,7 @@ def write_element( overwrite=overwrite, sdata_formats=sdata_formats, shapes_geometry_encoding=shapes_geometry_encoding, + raster_write_kwargs=raster_write_kwargs, raster_compressor=raster_compressor, ) return @@ -1367,6 +1421,7 @@ def write_element( overwrite=overwrite, parsed_formats=parsed_formats, shapes_geometry_encoding=shapes_geometry_encoding, + raster_write_kwargs=raster_write_kwargs, raster_compressor=raster_compressor, ) # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting. 
diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 2feb7a77..1cb3fcd4 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -148,13 +148,13 @@ def _prepare_storage_options( return None if isinstance(storage_options, dict): prepared = dict(storage_options) - if "chunks" in prepared: + if "chunks" in prepared and prepared["chunks"] is not None: prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) return prepared prepared_options = [dict(options) for options in storage_options] for options in prepared_options: - if "chunks" in options: + if "chunks" in options and options["chunks"] is not None: options["chunks"] = _normalize_explicit_chunks(options["chunks"]) return prepared_options @@ -284,6 +284,19 @@ def _write_raster( raster_format The format used to write the raster data. storage_options + Storage options for raster elements, which have been extracted from potentially mixed kwargs dict by + `create_raster_element_kwargs`. These options are passed to the zarr storage backend for writing and can be + provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied to the raster, either single or multiscale. + 2. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of the multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array Additional options for writing the raster data, like chunks and compression. 
raster_compressor Compression settings as a len-1 dictionary with a single key-value {compression: compression level} pair @@ -292,6 +305,10 @@ def _write_raster( metadata Additional metadata for the raster element """ + from dataclasses import asdict + + from spatialdata import settings + if raster_type not in ["image", "labels"]: raise ValueError(f"{raster_type} is not a valid raster type. Must be 'image' or 'labels'.") # "name" and "label_metadata" are only used for labels. "name" is written in write_multiscale_ngff() but ignored in @@ -308,6 +325,13 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] + base_options = {k.split("_")[1]: v for k, v in asdict(settings).items() if k in ("raster_chunks", "raster_shards")} + + if isinstance(storage_options, list): + storage_options = [{**base_options, **x} for x in storage_options] + else: + storage_options = {**base_options, **(storage_options or {})} + if isinstance(raster_data, DataArray): _write_raster_dataarray( raster_type, diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 609cd040..afd1a551 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -4,12 +4,14 @@ import re import warnings from collections.abc import Callable, Generator -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext +from importlib.util import find_spec from itertools import islice from typing import Any, TypeVar import numpy as np import pandas as pd +import zarr from anndata import AnnData from dask import array as da from dask import config @@ -354,3 +356,20 @@ def _check_match_length_channels_c_dim( f" with length {c_length}." 
# TODO: get this in scverse-misc and import from there
@contextmanager
def zarrs_context() -> Generator[None, None, None]:
    """
    Run the enclosed block with the `zarrs` codec pipeline when the `zarrs` package can be found.

    When `find_spec("zarrs")` finds nothing, this context manager does nothing. Otherwise the zarr config
    key ``codec_pipeline.path`` is set to ``zarrs.ZarrsCodecPipeline`` for the duration of the block, and
    `UserWarning`s whose message matches ``unsupported by ZarrsCodecPipeline`` are ignored inside it.
    """
    if not find_spec("zarrs"):
        # zarrs is not available: leave the zarr configuration and the warning filters untouched.
        yield
        return

    pipeline_override = zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
    with pipeline_override, warnings.catch_warnings():
        # The warning is there in case zarrs doesn't support the store type you passed in to read_zarr.
        warnings.filterwarnings(
            "ignore",
            message=r".*unsupported by ZarrsCodecPipeline.*",
            category=UserWarning,
        )
        yield
@contextmanager
def temporary_settings(**kwargs):
    """
    Temporarily override attributes of the global `settings` object.

    Each keyword argument is set as an attribute on `settings`, after which `settings.save()` is called.
    On exit, including when the body raises, the attribute values captured before the override are copied
    back and `settings.save()` is called again.
    """
    # `replace` without changes yields a copy holding the current values.
    snapshot = replace(settings)
    try:
        for attribute, value in kwargs.items():
            setattr(settings, attribute, value)
        settings.save()
        yield
    finally:
        vars(settings).update(vars(snapshot))
        settings.save()
@pytest.mark.parametrize("raster_case", RASTER_CASES_MULTISCALE)
def test_write_multiscale_raster_sharding(tmp_path: Path, raster_case: dict) -> None:
    """Check that one global options dict sets the chunks and shards of every scale of a multiscale raster."""
    model = raster_case["model"]
    dims = raster_case["dims"]
    data_shape = raster_case["data_shape"]
    zarr_subpath = raster_case["zarr_subpath"]

    # Three dims means a leading channel axis, which gets size 1 in every chunk/shard shape.
    if len(dims) == 3:
        chunks, write_chunks, write_shards = (1, 100, 200), (1, 50, 100), (1, 100, 200)
    else:
        chunks, write_chunks, write_shards = (100, 200), (50, 100), (100, 200)

    name = "element"
    dask_data = da.from_array(RNG.random(data_shape), chunks=chunks)
    element = model.parse(dask_data, dims=dims, scale_factors=[2])
    sdata = SpatialData(**{zarr_subpath: {name: element}})
    path = tmp_path / "data.zarr"

    write_options = {"chunks": write_chunks, "shards": write_shards}
    sdata.write(path, raster_write_kwargs=write_options)

    group = zarr.open_group(path / zarr_subpath / name, mode="r")
    for level in ("s0", "s1"):
        level_array = group[level]
        assert level_array.chunks == write_chunks
        assert level_array.shards == write_shards
def test_write_raster_elements_sharding_chunking(tmp_path: Path) -> None:
    """Check that `write_element` applies one global options dict to each of several named images."""
    expected_chunks = (1, 50, 100)
    expected_shards = (1, 100, 200)
    image_names = ("image", "other_image")

    image = Image2DModel.parse(da.from_array(RNG.random((1, 500, 600))), dims=("c", "y", "x"))

    sdata = SpatialData()
    path = tmp_path / "data.zarr"

    # Write the empty object first, then add the images in memory and write them as individual elements.
    sdata.write(path)
    for image_name in image_names:
        sdata[image_name] = image

    sdata.write_element(
        element_name=list(image_names),
        raster_write_kwargs={"chunks": expected_chunks, "shards": expected_shards},
    )

    for image_name in image_names:
        arr = zarr.open_group(path / "images" / image_name, mode="r")["s0"]
        assert arr.chunks == expected_chunks
        assert arr.shards == expected_shards