diff --git a/pyproject.toml b/pyproject.toml index 1dd9dfc..5d6f7da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,11 @@ classifiers = [ ] dependencies = [ + "arraybridge>=0.2.9", "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", + "imageio>=2.37.0", "zarr>=2.18.0,<3.0", # Required for ZarrStorageBackend "ome-zarr>=0.11.0", # Required for OME-ZARR HCS compliance ] @@ -197,4 +199,4 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"__init__.py" = ["F401"] # unused imports \ No newline at end of file +"__init__.py" = ["F401"] # unused imports diff --git a/src/polystore/__init__.py b/src/polystore/__init__.py index 5c38d68..123c449 100644 --- a/src/polystore/__init__.py +++ b/src/polystore/__init__.py @@ -26,10 +26,10 @@ get_backend, ) from .constants import Backend, MemoryType, TransportMode -from .disk import DiskStorageBackend +from .disk import DiskBackend, DiskStorageBackend from .filemanager import FileManager from .formats import FileFormat, DEFAULT_IMAGE_EXTENSIONS -from .memory import MemoryStorageBackend +from .memory import MemoryBackend, MemoryStorageBackend from .metadata_writer import ( AtomicMetadataWriter, MetadataWriteError, @@ -76,7 +76,9 @@ "register_cleanup_callback", "STORAGE_BACKENDS", "DiskStorageBackend", + "DiskBackend", "MemoryStorageBackend", + "MemoryBackend", "FileManager", "file_lock", "atomic_write_json", diff --git a/src/polystore/backend_registry.py b/src/polystore/backend_registry.py index ad8ac52..eb4cb21 100644 --- a/src/polystore/backend_registry.py +++ b/src/polystore/backend_registry.py @@ -74,7 +74,7 @@ def create_storage_registry() -> Dict[str, DataSink]: # Backends that require context-specific initialization (e.g., plate_root) # These are registered lazily when needed, not at startup - SKIP_BACKENDS = {'virtual_workspace', 'omero_local'} + SKIP_BACKENDS = {'virtual_workspace', 'omero_local', 'bioformats'} registry = {} for backend_type in STORAGE_BACKENDS.keys(): @@ -157,4 +157,3 @@ def cleanup_all_backends() -> None: _backend_instances.clear() logger.info("All backend instances cleaned up") - diff --git a/src/polystore/base.py b/src/polystore/base.py index 2b033fc..e18849e 100644 --- a/src/polystore/base.py +++ b/src/polystore/base.py @@ -546,15 +546,16 @@ def reset_memory_backend() -> None: # Clear files from existing memory backend while preserving directories memory_backend = storage_registry[Backend.MEMORY.value] - # DEBUG: Log what's in memory before clearing existing_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(existing_keys)} entries BEFORE clear") - logger.info(f"🔍 VFS_CLEAR: First 10 keys: {existing_keys[:10]}") + logger.debug("Memory backend has %s entries before clear", len(existing_keys)) + logger.debug("First memory backend keys before clear: %s", existing_keys[:10]) memory_backend.clear_files_only() - # DEBUG: Log what's in memory after clearing remaining_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(remaining_keys)} entries AFTER clear (directories only)") - logger.info(f"🔍 VFS_CLEAR: First 10 remaining keys: {remaining_keys[:10]}") + logger.debug( + "Memory backend has %s entries after clear (directories only)", + len(remaining_keys), + ) + logger.debug("First memory backend keys after clear: %s", remaining_keys[:10]) logger.info("Memory backend reset - files cleared, directories preserved") diff --git a/src/polystore/bioformats_java.py b/src/polystore/bioformats_java.py new file mode 100644 index 0000000..41c7824 --- /dev/null +++ b/src/polystore/bioformats_java.py @@ -0,0 +1,223 @@ +"""Shared Java Bio-Formats bridge for metadata discovery and plane loading.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from threading import Lock +from typing import Any, Callable + +import numpy as np + + +class BioFormatsJavaUnavailableError(RuntimeError): + """Raised when the Java Bio-Formats runtime cannot be initialized.""" + + +@dataclass(frozen=True, slots=True) +class BioFormatsOpenedReader: + """Open Bio-Formats reader plus its OME metadata store.""" + + reader: Any + metadata: Any + + def close(self) -> None: + self.reader.close() + + +class BioFormatsJavaContext: + """Lazy JVM/ImageJ context for Bio-Formats Java access.""" + + _lock = Lock() + _instance: "BioFormatsJavaContext | None" = None + + def __init__(self, imagej_module: Any, scyjava_module: Any): + self.imagej = imagej_module + self.scyjava = scyjava_module + self.ij = None + self.ImageReader = None + self.MetadataTools = None + self.FormatTools = None + + @classmethod + def instance(cls) -> "BioFormatsJavaContext": + with cls._lock: + if cls._instance is None: + cls._instance = cls._create() + return cls._instance + + @classmethod + def _create(cls) -> "BioFormatsJavaContext": + try: + import imagej + import scyjava + except ImportError as exc: + raise BioFormatsJavaUnavailableError( + "Bio-Formats support requires the optional bioformats/fiji dependencies." + ) from exc + return cls(imagej, scyjava) + + def ensure_initialized(self) -> None: + if self.ij is not None: + return + try: + self.ij = self.imagej.init("sc.fiji:fiji", mode="headless") + self.ImageReader = self.scyjava.jimport("loci.formats.ImageReader") + self.MetadataTools = self.scyjava.jimport("loci.formats.MetadataTools") + self.FormatTools = self.scyjava.jimport("loci.formats.FormatTools") + except Exception as exc: + raise BioFormatsJavaUnavailableError( + "Could not initialize Fiji/Bio-Formats through pyimagej." + ) from exc + + def open_reader(self, source_path: str | Path) -> BioFormatsOpenedReader: + self.ensure_initialized() + metadata = self.MetadataTools.createOMEXMLMetadata() + reader = self.ImageReader() + try: + reader.setMetadataStore(metadata) + reader.setId(str(source_path)) + return BioFormatsOpenedReader(reader=reader, metadata=metadata) + except Exception: + reader.close() + raise + + +def java_int(value: Any) -> int | None: + """Convert nullable Java primitive wrappers to Python int.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(int) + + +def java_float(value: Any) -> float | None: + """Convert nullable Java numeric wrappers to Python float.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(float) + + +def java_str(value: Any) -> str | None: + """Convert nullable Java strings to Python strings.""" + if value is None: + return None + return str(value) + + +def _read_java_value(value: Any) -> Any: + return value.value() + + +def _read_java_get_value(value: Any) -> Any: + return value.getValue() + + +@dataclass(frozen=True, slots=True) +class JavaScalarProjector: + """Project nullable Java scalar wrappers to Python scalar values.""" + + readers: tuple[Callable[[Any], Any], ...] + + def unwrap(self, value: Any) -> Any: + for reader in self.readers: + try: + return reader(value) + except AttributeError: + continue + return value + + +@dataclass(frozen=True, slots=True) +class OptionalJavaScalar: + """Nullable Java scalar after wrapper unwrapping.""" + + value: Any | None + + @classmethod + def from_java( + cls, + value: Any, + readers: tuple[Callable[[Any], Any], ...], + ) -> "OptionalJavaScalar": + if value is None: + return cls(None) + return cls(JavaScalarProjector(readers).unwrap(value)) + + def convert(self, converter: Callable[[Any], Any]) -> Any | None: + if self.value is None: + return None + return converter(self.value) + + +JAVA_SCALAR_PROJECTOR = JavaScalarProjector( + readers=( + _read_java_value, + _read_java_get_value, + ) +) + + +def load_bioformats_plane( + *, + source_path: Path, + series_index: int, + plane_index: int, +) -> np.ndarray: + """Load a single 2D Bio-Formats plane through the Java ImageReader.""" + context = BioFormatsJavaContext.instance() + opened = context.open_reader(source_path) + reader = opened.reader + try: + reader.setSeries(series_index) + if reader.getRGBChannelCount() != 1: + raise ValueError( + "Bio-Formats RGB/interleaved planes are not yet representable as " + "OpenHCS scalar channel planes." + ) + raw = bytes(reader.openBytes(plane_index)) + dtype = PixelDtypeCatalog.from_format_tools(context.FormatTools).dtype( + pixel_type=int(reader.getPixelType()), + little_endian=bool(reader.isLittleEndian()), + ) + array = np.frombuffer(raw, dtype=dtype) + return array.reshape((int(reader.getSizeY()), int(reader.getSizeX()))) + finally: + opened.close() + + +@dataclass(frozen=True, slots=True) +class PixelDtypeSpec: + """NumPy dtype projection for one Bio-Formats pixel type.""" + + key: int + dtype_code: str + endian_sensitive: bool = True + + def dtype(self, *, little_endian: bool) -> np.dtype: + if not self.endian_sensitive: + return np.dtype(self.dtype_code) + endian = "<" if little_endian else ">" + return np.dtype(endian + self.dtype_code) + + +@dataclass(frozen=True, slots=True) +class PixelDtypeCatalog: + """Authoritative Bio-Formats pixel-type to NumPy dtype mapping.""" + + specs_by_key: dict[int, PixelDtypeSpec] + + @classmethod + def from_format_tools(cls, format_tools: Any) -> "PixelDtypeCatalog": + specs = ( + PixelDtypeSpec(int(format_tools.INT8), "i1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.UINT8), "u1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.INT16), "i2"), + PixelDtypeSpec(int(format_tools.UINT16), "u2"), + PixelDtypeSpec(int(format_tools.INT32), "i4"), + PixelDtypeSpec(int(format_tools.UINT32), "u4"), + PixelDtypeSpec(int(format_tools.FLOAT), "f4"), + PixelDtypeSpec(int(format_tools.DOUBLE), "f8"), + ) + return cls({spec.key: spec for spec in specs}) + + def dtype(self, *, pixel_type: int, little_endian: bool) -> np.dtype: + try: + return self.specs_by_key[pixel_type].dtype(little_endian=little_endian) + except KeyError as exc: + raise ValueError(f"Unsupported Bio-Formats pixel type: {pixel_type}") from exc diff --git a/src/polystore/bioformats_storage.py b/src/polystore/bioformats_storage.py new file mode 100644 index 0000000..ba17dcf --- /dev/null +++ b/src/polystore/bioformats_storage.py @@ -0,0 +1,258 @@ +"""Structured-reference backend for Bio-Formats-backed virtual workspaces.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from fnmatch import fnmatch +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from .base import PicklableBackend, ReadOnlyBackend +from .constants import Backend +from .exceptions import StorageResolutionError +from .metadata_writer import get_metadata_path + + +@dataclass(frozen=True, slots=True) +class BioFormatsPlaneRef: + """Serializable reference to one Bio-Formats image plane.""" + + source_path: Path + series_index: int + plane_index: int + c: int + z: int + t: int + reader: str = "bioformats" + + @classmethod + def from_mapping( + cls, + payload: Dict[str, Any], + *, + plate_root: Path, + ) -> "BioFormatsPlaneRef": + source_path = Path(payload["source_path"]) + if not source_path.is_absolute(): + source_path = plate_root / source_path + return cls( + source_path=source_path, + series_index=int(payload.get("series_index", 0)), + plane_index=int(payload["plane_index"]), + c=int(payload["c"]), + z=int(payload["z"]), + t=int(payload["t"]), + reader=str(payload.get("reader", "bioformats")), + ) + + +class BioFormatsStorageBackend(ReadOnlyBackend, PicklableBackend): + """Load normalized virtual source keys from structured Bio-Formats refs.""" + + _backend_type = Backend.BIOFORMATS.value + + def __init__(self, plate_root: Path | None = None): + self.plate_root = None if plate_root is None else Path(plate_root) + self._mapping_cache: Optional[Dict[str, Dict[str, Any]]] = None + self._cache_mtime: Optional[float] = None + + def get_connection_params(self) -> Optional[Dict[str, Any]]: + if self.plate_root is None: + return None + return {"plate_root": str(self.plate_root)} + + def set_connection_params(self, params: Optional[Dict[str, Any]]) -> None: + if not params: + self.plate_root = None + self._mapping_cache = None + self._cache_mtime = None + return + self.plate_root = Path(params["plate_root"]) + self._mapping_cache = None + self._cache_mtime = None + + def load(self, file_path: Union[str, Path], **kwargs) -> Any: + ref = self._resolve_ref(file_path) + if ref.reader == "npy": + return _load_npy_plane(ref) + if ref.reader != "bioformats": + raise BioFormatsReaderUnavailableError( + f"Unsupported Bio-Formats reader {ref.reader!r}." + ) + from .bioformats_java import load_bioformats_plane + + return load_bioformats_plane( + source_path=ref.source_path, + series_index=ref.series_index, + plane_index=ref.plane_index, + ) + + def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: + return [self.load(file_path, **kwargs) for file_path in file_paths] + + def list_files( + self, + directory: Union[str, Path], + pattern: Optional[str] = None, + extensions: Optional[Set[str]] = None, + recursive: bool = False, + **kwargs, + ) -> List[str]: + plate_root = self._require_plate_root() + relative_dir = self.relative_to_root(directory) + normalized_dir = _normalize_relative_path(str(relative_dir)) + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) + results = [] + for virtual_path in self._load_mapping().keys(): + if not _virtual_path_in_directory( + virtual_path, + normalized_dir=normalized_dir, + recursive=recursive, + ): + continue + path = Path(virtual_path) + if lowercase_extensions is not None and path.suffix.lower() not in lowercase_extensions: + continue + if pattern is not None and not fnmatch(path.name, pattern): + continue + results.append(str(plate_root / virtual_path)) + return results + + def exists(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + if not relative: + return True + mapping = self._load_mapping() + return relative in mapping or any( + virtual_path.startswith(relative + "/") + for virtual_path in mapping + ) + + def is_file(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return relative in self._load_mapping() + + def is_dir(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return not relative or any( + virtual_path.startswith(relative + "/") + for virtual_path in self._load_mapping() + ) + + def list_dir(self, path: Union[str, Path]) -> List[str]: + relative = self.normalized_relative_path(path) + prefix = "" if not relative else relative + "/" + names = set() + for virtual_path in self._load_mapping(): + if not virtual_path.startswith(prefix): + continue + remainder = virtual_path[len(prefix):] + if remainder: + names.add(remainder.split("/", 1)[0]) + return sorted(names) + + def _resolve_ref(self, path: Union[str, Path]) -> BioFormatsPlaneRef: + plate_root = self._require_plate_root() + relative_path = self.normalized_relative_path(path) + mapping = self._load_mapping() + try: + payload = mapping[relative_path] + except KeyError as exc: + raise StorageResolutionError( + f"Path not in Bio-Formats workspace mapping: {relative_path}" + ) from exc + if not isinstance(payload, dict): + raise StorageResolutionError( + f"Bio-Formats workspace mapping for {relative_path!r} is not structured." + ) + return BioFormatsPlaneRef.from_mapping(payload, plate_root=plate_root) + + def _load_mapping(self) -> Dict[str, Dict[str, Any]]: + plate_root = self._require_plate_root() + metadata_path = get_metadata_path(plate_root) + if not metadata_path.exists(): + raise FileNotFoundError(f"Metadata not found: {metadata_path}") + current_mtime = metadata_path.stat().st_mtime + if self._mapping_cache is not None and self._cache_mtime == current_mtime: + return self._mapping_cache + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + combined_mapping: Dict[str, Dict[str, Any]] = {} + for subdirectory in metadata.get("subdirectories", {}).values(): + if Backend.BIOFORMATS.value not in subdirectory.get("available_backends", {}): + continue + workspace_mapping = subdirectory.get("workspace_mapping", {}) + for virtual_path, ref_payload in workspace_mapping.items(): + if isinstance(ref_payload, dict): + combined_mapping[_normalize_relative_path(str(virtual_path))] = ref_payload + if not combined_mapping: + raise ValueError(f"No Bio-Formats workspace_mapping in {metadata_path}") + self._mapping_cache = combined_mapping + self._cache_mtime = current_mtime + return combined_mapping + + def _require_plate_root(self) -> Path: + if self.plate_root is None: + raise StorageResolutionError("BioFormatsStorageBackend requires plate_root.") + return self.plate_root + + def relative_to_root(self, path: Union[str, Path]) -> Path: + plate_root = self._require_plate_root() + path_obj = Path(path) + if not path_obj.is_absolute(): + return path_obj + try: + return path_obj.relative_to(plate_root) + except ValueError as exc: + raise StorageResolutionError( + f"Path {path_obj} is outside Bio-Formats plate root {plate_root}." + ) from exc + + def normalized_relative_path(self, path: Union[str, Path]) -> str: + return _normalize_relative_path(str(self.relative_to_root(path))) + + +class BioFormatsReaderUnavailableError(RuntimeError): + """Raised when a production Bio-Formats reader has not been configured.""" + + +def _load_npy_plane(ref: BioFormatsPlaneRef) -> Any: + import numpy as np + + array = np.load(ref.source_path) + if array.ndim == 2: + return array + if array.ndim == 5: + return array[ref.t - 1, ref.z - 1, ref.c - 1] + if array.ndim == 3: + return array[ref.plane_index] + raise ValueError( + f"Unsupported npy Bio-Formats fixture shape {array.shape} for {ref.source_path}." + ) + + +def _normalize_relative_path(path: str) -> str: + normalized = path.replace("\\", "/") + return "" if normalized == "." else normalized + + +def _virtual_path_in_directory( + virtual_path: str, + *, + normalized_dir: str, + recursive: bool, +) -> bool: + if recursive: + return not normalized_dir or virtual_path.startswith(normalized_dir + "/") + return _normalize_relative_path(str(Path(virtual_path).parent)) == normalized_dir diff --git a/src/polystore/constants.py b/src/polystore/constants.py index 3a27cfb..0103236 100644 --- a/src/polystore/constants.py +++ b/src/polystore/constants.py @@ -19,6 +19,7 @@ class Backend(Enum): FIJI_STREAM = "fiji_stream" OMERO_LOCAL = "omero_local" VIRTUAL_WORKSPACE = "virtual_workspace" + BIOFORMATS = "bioformats" class TransportMode(Enum): diff --git a/src/polystore/disk.py b/src/polystore/disk.py index 40c33d9..ca24e7c 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -9,6 +9,7 @@ import logging import os import shutil +import importlib from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -23,7 +24,7 @@ def optional_import(module_name): try: - return __import__(module_name) + return importlib.import_module(module_name) except ImportError: return None @@ -44,6 +45,7 @@ def optional_import(module_name): cupy = get_cupy() tf = get_tf() tifffile = optional_import("tifffile") +imageio = optional_import("imageio.v3") # Optional arraybridge integration for memory conversion try: @@ -99,6 +101,7 @@ def _register_formats(self): # Complex formats - use custom handlers (FileFormat.TIFF, tifffile, self._tiff_writer, self._tiff_reader), + (FileFormat.RASTER_IMAGE, imageio, self._image_writer, self._image_reader), (FileFormat.TEXT, True, self._text_writer, self._text_reader), (FileFormat.JSON, True, self._json_writer, self._json_reader), (FileFormat.CSV, True, self._csv_writer, self._csv_reader), @@ -164,6 +167,14 @@ def _tiff_reader(self, path): else: return tifffile.imread(str(path)) + def _image_writer(self, path, data, **kwargs): + """Write standard raster images using imageio.""" + imageio.imwrite(path, np.asarray(data)) + + def _image_reader(self, path): + """Read standard raster images using imageio.""" + return imageio.imread(path) + def _text_writer(self, path, data, **kwargs): """Write text data to file. Accepts and ignores extra kwargs for compatibility.""" path.write_text(str(data)) @@ -261,7 +272,7 @@ def load(self, file_path: Union[str, Path], **kwargs) -> Any: ext = disk_path.suffix.lower() if not self.format_registry.is_registered(ext): - raise ValueError(f"No writer registered for extension '{ext}'") + raise ValueError(f"No reader registered for extension '{ext}'") try: reader = self.format_registry.get_reader(ext) @@ -823,3 +834,6 @@ def _save_rois(self, rois: List, output_path: Path, images_dir: str = None, **kw logger.info(f"Saved {roi_count} ROIs to .roi.zip archive: {output_path}") return str(output_path) + + +DiskBackend = DiskStorageBackend diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 4d52817..2cbeb1c 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -12,32 +12,145 @@ """ import logging -import time -from pathlib import Path -from typing import Any, List, Union +from enum import Enum -import zmq - -from .constants import Backend, TransportMode +from .constants import Backend from .streaming_constants import StreamingDataType -from .streaming import StreamingBackend +from .streaming import ( + FilePath, + RoiStreamPayload, + StreamingBuiltBatch, + StreamingBackend, + StreamingComponentNamesRequest, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +from .streaming.viewer_transport import ViewerStreamRequest from .roi_converters import FijiROIConverter -from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode +from zmqruntime.viewer_protocol import ( + ViewerBatchContextWireField, + ViewerBatchItemWireField, + ViewerBatchWireField, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +class FijiDisplayWireField(str, Enum): + """Fiji-specific display fields inside the shared viewer display payload.""" + + LUT = "lut" + AUTO_CONTRAST = "auto_contrast" + + +class FijiDisplayPayload: + """Display payload projection for Fiji stream messages.""" + + @staticmethod + def auto_contrast_value(display_config) -> bool: + return display_config.auto_contrast + + @classmethod + def from_display_config(cls, display_config) -> dict[str, ViewerWireValue]: + return { + FijiDisplayWireField.LUT.value: display_config.get_lut_name(), + FijiDisplayWireField.AUTO_CONTRAST.value: cls.auto_contrast_value( + display_config + ), + } + + +class FijiMessageMetadata: + """Typed access to optional Fiji message metadata.""" + + @staticmethod + def component_names_metadata(message: ViewerWireMapping) -> ViewerWireValue: + return message[ViewerBatchWireField.COMPONENT_NAMES_METADATA.value] + + +class FijiRoiPayload: + """ROI payload inspection for Fiji logging.""" + + @staticmethod + def count(item_data: ViewerWireMapping) -> int: + if ViewerBatchItemWireField.ROIS.value not in item_data: + raise ValueError("Fiji ROI payload missing required 'rois' field") + return len(item_data[ViewerBatchItemWireField.ROIS.value]) + + class FijiStreamingBackend(StreamingBackend): """Fiji streaming backend with ZMQ publisher pattern (matches Napari architecture).""" _backend_type = Backend.FIJI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'fiji' SHM_PREFIX = 'fiji_' - # __init__, _get_publisher, save, cleanup now inherited from ABC + def _display_payload_extra( + self, + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return ViewerDisplayPayloadExtra.from_mapping( + FijiDisplayPayload.from_display_config(stream_request.display_config) + ) + + def _message_extra( + self, + stream_request: ViewerStreamRequest, + ) -> dict[str, ViewerWireValue]: + message_extra = stream_request.message_extra_payload() + message_extra[ViewerBatchContextWireField.IMAGES_DIR.value] = ( + stream_request.images_dir + ) + return message_extra + + def _component_names_request( + self, + stream_request: ViewerStreamRequest, + ) -> StreamingComponentNamesRequest: + return StreamingComponentNamesRequest.from_stream_request( + stream_request, + log_prefix="🏷️ FIJI BACKEND", + verbose=True, + ) + + def _after_batch_message_built( + self, + stream_request: ViewerStreamRequest, + built_batch: StreamingBuiltBatch, + ) -> None: + logger.info( + "🏷️ FIJI BACKEND: Final component_names_metadata: %s", + FijiMessageMetadata.component_names_metadata(built_batch.message), + ) - def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: + for item in built_batch.batch_images: + logger.info( + "🔍 FIJI BACKEND: Added %s item to batch", + item[ViewerBatchItemWireField.DATA_TYPE.value], + ) + + data_types = [ + item[ViewerBatchItemWireField.DATA_TYPE.value] + for item in built_batch.batch_images + ] + type_counts = { + data_type: data_types.count(data_type) + for data_type in set(data_types) + } + logger.info( + "📤 FIJI BACKEND: Sending batch message with %d items to port %s: %s", + len(built_batch.batch_images), + stream_request.port, + type_counts, + ) + + def _prepare_rois_data( + self, + data: RoiStreamPayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: """ Prepare ROIs data for transmission. @@ -53,125 +166,44 @@ def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: rois_encoded = FijiROIConverter.encode_rois_for_transmission(roi_bytes_list) return { - 'path': str(file_path), - 'rois': rois_encoded, - } - - def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - logger.info(f"🔍 FIJI BACKEND: Detected data type: {data_type} for path: {file_path}") - if data_type == StreamingDataType.SHAPES: - logger.info(f"🔍 FIJI BACKEND: Preparing ROI data for {file_path}") - item_data = self._prepare_rois_data(data, file_path) - data_type_value = "rois" - logger.info(f"🔍 FIJI BACKEND: ROI data prepared: {len(item_data.get('rois', []))} ROIs") - else: - logger.info(f"🔍 FIJI BACKEND: Preparing image data for {file_path}") - item_data = self._create_shared_memory(data, file_path) - data_type_value = "image" - return item_data, data_type_value - - def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], **kwargs) -> None: - """Stream batch of images or ROIs to Fiji via ZMQ.""" - - logger.info(f"📦 FIJI BACKEND: save_batch called with {len(data_list)} items") - - # Filter to only supported file types - data_list, file_paths, skipped = self._filter_streamable_files(data_list, file_paths) - if not data_list: - return - - # Extract kwargs using generic polymorphic names - host = kwargs.get('host', 'localhost') - port = kwargs['port'] - transport_mode = kwargs['transport_mode'] - transport_config = kwargs.get('transport_config') - display_config = kwargs['display_config'] - microscope_handler = kwargs['microscope_handler'] - source = kwargs.get('source', 'unknown_source') # Pre-built source value - images_dir = kwargs.get('images_dir') # Source image subdirectory for ROI mapping - plate_path = kwargs.get('plate_path') - logger.info(f"🏷️ FIJI BACKEND: plate_path = {plate_path}") - logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {microscope_handler}") - display_payload_extra = { - "lut": display_config.get_lut_name(), - "auto_contrast": display_config.auto_contrast if hasattr(display_config, "auto_contrast") else True, - } - message_extra = { - "images_dir": images_dir, + ViewerBatchItemWireField.PATH.value: str(file_path), + ViewerBatchItemWireField.ROIS.value: rois_encoded, } - message, batch_images, image_ids = self._build_batch_message( - data_list, - file_paths, - microscope_handler, - source, - display_config, - self._prepare_batch_item, - plate_path=plate_path, - component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, - display_payload_extra=display_payload_extra, - message_extra=message_extra, - ) - + def _prepare_batch_item( + self, + request: StreamingItemPreparationRequest, + ) -> tuple[ViewerWireMapping, str]: logger.info( - "🏷️ FIJI BACKEND: Final component_names_metadata: %s", - message.get("component_names_metadata", {}), + "🔍 FIJI BACKEND: Detected data type: %s for path: %s", + request.data_type, + request.item_path.value, ) - - for item in batch_images: - logger.info(f"🔍 FIJI BACKEND: Added {item['data_type']} item to batch") - - # Log batch composition - data_types = [item['data_type'] for item in batch_images] - type_counts = {dt: data_types.count(dt) for dt in set(data_types)} - logger.info(f"📤 FIJI BACKEND: Sending batch message with {len(batch_images)} items to port {port}: {type_counts}") - - # Register sent images with queue tracker BEFORE sending - # This prevents race condition with IPC mode where acks arrive before registration - self._register_with_queue_tracker( - port, - image_ids, - transport_mode=transport_mode, - transport_config=transport_config, - ) - - # Create FRESH REQ socket for each send - REQ sockets cannot be reused - # This prevents the "Operation cannot be accomplished in current state" error - # when multiple streams happen concurrently - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - if self._context is None: - self._context = zmq.Context() - - socket = self._context.socket(zmq.REQ) - socket.connect(url) - time.sleep(0.1) # Brief delay for connection to establish - - try: - # Send with REQ socket (BLOCKING - worker waits for Fiji to acknowledge) - # Worker blocks until Fiji receives, copies data from shared memory, and sends ack - # This guarantees no messages are lost and shared memory is only closed after Fiji is done - logger.info(f"📤 FIJI BACKEND: Sending batch of {len(batch_images)} images to Fiji on port {port} (REQ/REP - blocking until ack)") - socket.send_json(message) # Blocking send - - # Wait for acknowledgment from Fiji (REP socket) - # Fiji will only reply after it has copied all data from shared memory - ack_response = socket.recv_json() - logger.info(f"✅ FIJI BACKEND: Received ack from Fiji: {ack_response.get('status', 'unknown')}") - - finally: - # Always close the socket - never reuse REQ sockets - socket.close() - - # Clean up publisher's handles after successful send - # Receiver will unlink the shared memory after copying the data - self._cleanup_shared_memory_blocks(batch_images, unlink=False) + if request.data_type == StreamingDataType.SHAPES: + logger.info( + "🔍 FIJI BACKEND: Preparing ROI data for %s", + request.item_path.value, + ) + item_data = self._prepare_rois_data( + request.data, + request.item_path.value, + ) + data_type_value = StreamingDataType.ROIS.value + logger.info( + "🔍 FIJI BACKEND: ROI data prepared: %d ROIs", + FijiRoiPayload.count(item_data), + ) + else: + logger.info( + "🔍 FIJI BACKEND: Preparing image data for %s", + request.item_path.value, + ) + item_data = self.create_shared_memory_payload( + request.data, + request.item_path.value, + ) + data_type_value = StreamingDataType.IMAGE.value + return item_data, data_type_value # cleanup() now inherited from ABC diff --git a/src/polystore/formats.py b/src/polystore/formats.py index ddfb9a5..3643361 100644 --- a/src/polystore/formats.py +++ b/src/polystore/formats.py @@ -20,6 +20,7 @@ class FileFormat(Enum): # Image formats TIFF = "tiff" + RASTER_IMAGE = "raster_image" # Data formats CSV = "csv" @@ -44,6 +45,7 @@ def extensions(self): FileFormat.TENSORFLOW: [".tf"], FileFormat.ZARR: [".zarr"], FileFormat.TIFF: [".tif", ".tiff"], + FileFormat.RASTER_IMAGE: [".bmp", ".gif", ".jpeg", ".jpg", ".png"], FileFormat.CSV: [".csv"], FileFormat.JSON: [".json"], FileFormat.TEXT: [".txt"], @@ -51,7 +53,14 @@ def extensions(self): } # Default image extensions -DEFAULT_IMAGE_EXTENSIONS = {".tif", ".tiff", ".TIF", ".TIFF"} +DEFAULT_IMAGE_EXTENSIONS = { + extension + for extensions in ( + FILE_FORMAT_EXTENSIONS[FileFormat.TIFF], + FILE_FORMAT_EXTENSIONS[FileFormat.RASTER_IMAGE], + ) + for extension in extensions +} def get_format_from_extension(ext: str) -> FileFormat: diff --git a/src/polystore/memory.py b/src/polystore/memory.py index a59114f..872d581 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -139,6 +139,9 @@ def list_files( if self._memory_store[dir_key] is not None: raise NotADirectoryError(f"Path is not a directory: {directory}") + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) result = [] dir_prefix = dir_key + "/" if not dir_key.endswith("/") else dir_key @@ -159,7 +162,10 @@ def list_files( filename = Path(rel_path).name # If pattern is None, match all files if pattern is None or fnmatch(filename, pattern): - if not extensions or Path(filename).suffix in extensions: + if ( + lowercase_extensions is None + or Path(filename).suffix.lower() in lowercase_extensions + ): # Calculate depth for breadth-first sorting depth = rel_path.count('/') result.append((Path(path), depth)) @@ -651,3 +657,6 @@ def __init__(self, target: str): def __repr__(self): return f"" + + +MemoryBackend = MemoryStorageBackend diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index 630bcc8..a0940bc 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -13,32 +13,74 @@ """ import logging -import time -from pathlib import Path -from typing import Any, List, Union - -import zmq - -from .constants import Backend, TransportMode -from .streaming_constants import StreamingDataType -from .streaming import StreamingBackend +from enum import Enum + +from .constants import Backend +from .streaming import ( + FilePath, + RoiStreamPayload, + StreamingBackend, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +from .streaming.viewer_transport import ViewerStreamRequest from .roi_converters import NapariROIConverter -from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode +from zmqruntime.viewer_protocol import ( + ViewerBatchItemWireField, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +class NapariDisplayWireField(str, Enum): + """Napari-specific display fields inside the shared viewer display payload.""" + + COLORMAP = "colormap" + VARIABLE_SIZE_HANDLING = "variable_size_handling" + + +class NapariDisplayPayload: + """Display payload projection for Napari stream messages.""" + + @staticmethod + def variable_size_handling_value(display_config): + variable_size_handling = display_config.variable_size_handling + if variable_size_handling is None: + return None + return variable_size_handling.value + + @classmethod + def from_display_config(cls, display_config) -> dict[str, ViewerWireValue]: + return { + NapariDisplayWireField.COLORMAP.value: display_config.get_colormap_name(), + NapariDisplayWireField.VARIABLE_SIZE_HANDLING.value: ( + cls.variable_size_handling_value(display_config) + ), + } + + class NapariStreamingBackend(StreamingBackend): """Napari streaming backend with automatic registration.""" _backend_type = Backend.NAPARI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'napari' SHM_PREFIX = 'napari_' - # __init__, _get_publisher, save, cleanup now inherited from ABC + def _display_payload_extra( + self, + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return ViewerDisplayPayloadExtra.from_mapping( + NapariDisplayPayload.from_display_config(stream_request.display_config) + ) - def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: + def _prepare_shapes_data( + self, + data: RoiStreamPayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: """ Prepare shapes data for transmission. @@ -52,107 +94,28 @@ def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: shapes_data = NapariROIConverter.rois_to_shapes(data) return { - 'path': str(file_path), - 'shapes': shapes_data, + ViewerBatchItemWireField.PATH.value: str(file_path), + ViewerBatchItemWireField.SHAPES.value: shapes_data, } - def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - if data_type in (StreamingDataType.SHAPES, StreamingDataType.POINTS): - item_data = self._prepare_shapes_data(data, file_path) - data_type_value = data_type.value + def _prepare_batch_item( + self, + request: StreamingItemPreparationRequest, + ) -> tuple[ViewerWireMapping, str]: + if request.data_type.uses_napari_vector_payload: + item_data = self._prepare_shapes_data( + request.data, + request.item_path.value, + ) + data_type_value = request.data_type.value else: - item_data = self._create_shared_memory(data, file_path) - data_type_value = data_type.value + item_data = self.create_shared_memory_payload( + request.data, + request.item_path.value, + ) + data_type_value = request.data_type.value return item_data, data_type_value - def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], **kwargs) -> None: - """ - Stream multiple images or ROIs to napari as a batch. - - Args: - data_list: List of image data or ROI lists - file_paths: List of path identifiers - **kwargs: Additional metadata - """ - # Filter to only supported file types - data_list, file_paths, skipped = self._filter_streamable_files(data_list, file_paths) - if not data_list: - return - - # Extract kwargs using generic polymorphic names - host = kwargs.get('host', 'localhost') - port = kwargs['port'] - transport_mode = kwargs['transport_mode'] - transport_config = kwargs.get('transport_config') - display_config = kwargs['display_config'] - microscope_handler = kwargs['microscope_handler'] - source = kwargs.get('source', 'unknown_source') # Pre-built source value - plate_path = kwargs.get('plate_path') - display_payload_extra = { - "colormap": display_config.get_colormap_name(), - "variable_size_handling": display_config.variable_size_handling.value - if hasattr(display_config, "variable_size_handling") and display_config.variable_size_handling - else None, - } - - message, batch_images, image_ids = self._build_batch_message( - data_list, - file_paths, - microscope_handler, - source, - display_config, - self._prepare_batch_item, - plate_path=plate_path, - display_payload_extra=display_payload_extra, - ) - - # Register sent images with queue tracker BEFORE sending - # This prevents race condition with IPC mode where acks arrive before registration - self._register_with_queue_tracker( - port, - image_ids, - transport_mode=transport_mode, - transport_config=transport_config, - ) - - # Create FRESH REQ socket for each send - REQ sockets cannot be reused - # This prevents the "Operation cannot be accomplished in current state" error - # when multiple streams happen concurrently - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - if self._context is None: - self._context = zmq.Context() - - socket = self._context.socket(zmq.REQ) - socket.connect(url) - time.sleep(0.1) # Brief delay for connection to establish - - try: - # Send with REQ socket (BLOCKING - worker waits for Napari to acknowledge) - # Worker blocks until Napari receives, copies data from shared memory, and sends ack - # This guarantees no messages are lost and shared memory is only closed after Napari is done - logger.info(f"📤 NAPARI BACKEND: Sending batch of {len(batch_images)} images to Napari on port {port} (REQ/REP - blocking until ack)") - socket.send_json(message) # Blocking send - - # Wait for acknowledgment from Napari (REP socket) - # Napari will only reply after it has copied all data from shared memory - ack_response = socket.recv_json() - logger.info(f"✅ NAPARI BACKEND: Received ack from Napari: {ack_response.get('status', 'unknown')}") - - finally: - # Always close the socket - never reuse REQ sockets - socket.close() - - # Clean up publisher's handles after successful send - # Receiver will unlink the shared memory after copying the data - self._cleanup_shared_memory_blocks(batch_images, unlink=False) - # cleanup() now inherited from ABC def __del__(self): diff --git a/src/polystore/roi.py b/src/polystore/roi.py index fb6bdb6..26c1ef1 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -6,12 +6,14 @@ """ import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import numpy as np +from metaclass_registry import AutoRegisterMeta from .constants import Backend @@ -27,8 +29,14 @@ class ShapeType(Enum): ELLIPSE = "ellipse" +class ROIShape(ABC): + """Nominal base for all ROI shape records.""" + + shape_type: ShapeType + + @dataclass(frozen=True) -class PolygonShape: +class PolygonShape(ROIShape): """Polygon ROI shape defined by vertex coordinates.""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYGON, init=False) @@ -41,7 +49,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PolylineShape: +class PolylineShape(ROIShape): """Polyline ROI shape defined by path coordinates (open path, not closed polygon).""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYLINE, init=False) @@ -54,7 +62,7 @@ def __post_init__(self): @dataclass(frozen=True) -class MaskShape: +class MaskShape(ROIShape): """Binary mask ROI shape.""" mask: np.ndarray # 2D boolean array bbox: Tuple[int, int, int, int] # (min_y, min_x, max_y, max_x) @@ -68,7 +76,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PointShape: +class PointShape(ROIShape): """Point ROI shape.""" y: float x: float @@ -76,7 +84,7 @@ class PointShape: @dataclass(frozen=True) -class EllipseShape: +class EllipseShape(ROIShape): """Ellipse ROI shape.""" center_y: float center_x: float @@ -95,67 +103,253 @@ def __post_init__(self): if not self.shapes: raise ValueError("ROI must have at least one shape") for shape in self.shapes: - if not hasattr(shape, "shape_type"): - raise ValueError(f"Shape {shape} must have shape_type attribute") + if not isinstance(shape, ROIShape): + raise ValueError(f"Shape {shape} must be an ROIShape") -def extract_rois_from_labeled_mask( - labeled_mask: np.ndarray, - min_area: int = 10, - extract_contours: bool = True, -) -> List[ROI]: - """Extract ROIs from a labeled segmentation mask.""" - from skimage import measure - from skimage.measure import regionprops - from scipy.ndimage import find_objects +@dataclass(frozen=True, slots=True) +class LabeledMaskROIExtractionRequest: + """Request to extract ROIs from a labeled mask or stack.""" - if labeled_mask.ndim != 2: - raise ValueError(f"Labeled mask must be 2D, got shape {labeled_mask.shape}") + labeled_mask: np.ndarray + min_area: int = 10 + extract_contours: bool = True + spatial_origin_yx: Optional[Tuple[int, int]] = None + source_spatial_shape_yx: Optional[Tuple[int, int]] = None - if not np.issubdtype(labeled_mask.dtype, np.integer): - labeled_mask = labeled_mask.astype(np.int32) - regions = regionprops(labeled_mask) - slices = find_objects(labeled_mask) +class LabeledMaskROIExtractor(ABC, metaclass=AutoRegisterMeta): + """Registered extraction behavior for one labeled-mask dimensional family.""" - rois = [] - for region in regions: - if region.area < min_area: - continue - - metadata = { - "label": int(region.label), - "area": float(region.area), - "perimeter": float(region.perimeter), - "centroid": tuple(float(c) for c in region.centroid), - "bbox": tuple(int(b) for b in region.bbox), - } + __registry_key__ = "__name__" + __skip_if_no_key__ = True - shapes = [] - if extract_contours: - label_idx = region.label - 1 - if label_idx < len(slices) and slices[label_idx] is not None: - slice_y, slice_x = slices[label_idx] - cropped_mask = labeled_mask[slice_y, slice_x] - binary_mask = (cropped_mask == region.label).astype(np.uint8) - padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) - contours = measure.find_contours(padded_mask, level=0.5) - offset_y = slice_y.start - offset_x = slice_x.start - padding_offset = np.array([offset_y, offset_x]) - 1 - for contour in contours: - if len(contour) >= 3: - contour_full = contour + padding_offset - shapes.append(PolygonShape(coordinates=contour_full)) - else: - binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=region.bbox)) + @classmethod + def for_request( + cls, + request: LabeledMaskROIExtractionRequest, + ) -> "LabeledMaskROIExtractor": + for extractor_type in cls.__registry__.values(): + extractor = extractor_type() + if extractor.accepts(request.labeled_mask): + return extractor + raise ValueError( + "No ROI extractor registered for labeled mask shape " + f"{request.labeled_mask.shape}." + ) - if shapes: - rois.append(ROI(shapes=shapes, metadata=metadata)) + @abstractmethod + def accepts(self, labeled_mask: np.ndarray) -> bool: + """Return whether this extractor owns the mask dimensionality.""" - logger.info(f"Extracted {len(rois)} ROIs from labeled mask") - return rois + @abstractmethod + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + """Extract ROIs from the request.""" + + +class TwoDimensionalLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from a single 2D labeled mask.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim == 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + from skimage import measure + from skimage.measure import regionprops + from scipy.ndimage import find_objects + + labeled_mask = request.labeled_mask + if not np.issubdtype(labeled_mask.dtype, np.integer): + labeled_mask = labeled_mask.astype(np.int32) + + regions = regionprops(labeled_mask) + slices = find_objects(labeled_mask) + origin_y, origin_x = request.spatial_origin_yx or (0, 0) + + rois = [] + for region in regions: + if region.area < request.min_area: + continue + min_y, min_x, max_y, max_x = region.bbox + + metadata = { + "label": int(region.label), + "area": float(region.area), + "perimeter": float(region.perimeter), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), + } + if request.source_spatial_shape_yx is not None: + metadata["source_spatial_shape_yx"] = tuple( + int(value) for value in request.source_spatial_shape_yx + ) + + shapes = [] + if request.extract_contours: + label_idx = region.label - 1 + if label_idx < len(slices) and slices[label_idx] is not None: + slice_y, slice_x = slices[label_idx] + cropped_mask = labeled_mask[slice_y, slice_x] + binary_mask = (cropped_mask == region.label).astype(np.uint8) + padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) + contours = measure.find_contours(padded_mask, level=0.5) + offset_y = slice_y.start + offset_x = slice_x.start + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 + for contour in contours: + if len(contour) >= 3: + contour_full = contour + padding_offset + shapes.append(PolygonShape(coordinates=contour_full)) + else: + binary_mask = labeled_mask == region.label + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) + + if shapes: + rois.append(ROI(shapes=shapes, metadata=metadata)) + + logger.info(f"Extracted {len(rois)} ROIs from labeled mask") + return rois + + +class NonSpatialLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Treat scalar and otherwise non-spatial label payloads as empty ROI sets.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim < 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + return [] + + +class StackedLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from all 2D planes in a labeled-mask stack.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim > 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + stack = request.labeled_mask + plane_shape = stack.shape[-2:] + leading_shape = stack.shape[:-2] + rois: list[ROI] = [] + for plane_indices in np.ndindex(leading_shape): + plane_request = LabeledMaskROIExtractionRequest( + labeled_mask=stack[plane_indices], + min_area=request.min_area, + extract_contours=request.extract_contours, + spatial_origin_yx=request.spatial_origin_yx, + source_spatial_shape_yx=request.source_spatial_shape_yx or plane_shape, + ) + for roi in TwoDimensionalLabeledMaskROIExtractor().extract(plane_request): + rois.append(self._with_plane_metadata(roi, plane_indices, leading_shape)) + return rois + + @staticmethod + def _with_plane_metadata( + roi: ROI, + plane_indices: tuple[int, ...], + leading_shape: tuple[int, ...], + ) -> ROI: + return ROI( + shapes=roi.shapes, + metadata={ + **roi.metadata, + "plane_indices": tuple(int(index) for index in plane_indices), + "plane_shape": tuple(int(size) for size in leading_shape), + }, + ) + + +class ROIJsonShapeDecoder(ABC, metaclass=AutoRegisterMeta): + """Decode one serialized ROI shape variant.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_serialized_shape(cls, shape_dict: Dict[str, Any]) -> "ROIJsonShapeDecoder | None": + shape_type = shape_dict.get("type") + try: + shape_key = ShapeType(shape_type) + except ValueError: + logger.warning(f"Unknown shape type: {shape_type}, skipping") + return None + return cls.__registry__[shape_key]() + + @abstractmethod + def decode(self, shape_dict: Dict[str, Any]) -> Any: + """Return the concrete ROI shape represented by ``shape_dict``.""" + + +class PolygonROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYGON + + def decode(self, shape_dict: Dict[str, Any]) -> PolygonShape: + return PolygonShape(coordinates=np.array(shape_dict["coordinates"])) + + +class PolylineROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYLINE + + def decode(self, shape_dict: Dict[str, Any]) -> PolylineShape: + return PolylineShape(coordinates=np.array(shape_dict["coordinates"])) + + +class MaskROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.MASK + + def decode(self, shape_dict: Dict[str, Any]) -> MaskShape: + return MaskShape( + mask=np.array(shape_dict["mask"], dtype=bool), + bbox=tuple(shape_dict["bbox"]), + ) + + +class PointROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POINT + + def decode(self, shape_dict: Dict[str, Any]) -> PointShape: + return PointShape(y=shape_dict["y"], x=shape_dict["x"]) + + +class EllipseROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.ELLIPSE + + def decode(self, shape_dict: Dict[str, Any]) -> EllipseShape: + return EllipseShape( + center_y=shape_dict["center_y"], + center_x=shape_dict["center_x"], + radius_y=shape_dict["radius_y"], + radius_x=shape_dict["radius_x"], + ) + + +def extract_rois_from_labeled_mask( + labeled_mask: np.ndarray, + min_area: int = 10, + extract_contours: bool = True, + spatial_origin_yx: Optional[Tuple[int, int]] = None, + source_spatial_shape_yx: Optional[Tuple[int, int]] = None, +) -> List[ROI]: + """Extract ROIs from a labeled segmentation mask.""" + request = LabeledMaskROIExtractionRequest( + labeled_mask=np.asarray(labeled_mask), + min_area=min_area, + extract_contours=extract_contours, + spatial_origin_yx=spatial_origin_yx, + source_spatial_shape_yx=source_spatial_shape_yx, + ) + return LabeledMaskROIExtractor.for_request(request).extract(request) def _get_backend_from_filemanager(filemanager: Any, backend: Union[str, Backend]): @@ -203,31 +397,9 @@ def load_rois_from_json(json_path: Path) -> List[ROI]: metadata = roi_dict.get("metadata", {}) shapes = [] for shape_dict in roi_dict.get("shapes", []): - shape_type = shape_dict.get("type") - - if shape_type == "polygon": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolygonShape(coordinates=coordinates)) - elif shape_type == "polyline": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolylineShape(coordinates=coordinates)) - elif shape_type == "mask": - mask = np.array(shape_dict["mask"], dtype=bool) - bbox = tuple(shape_dict["bbox"]) - shapes.append(MaskShape(mask=mask, bbox=bbox)) - elif shape_type == "point": - shapes.append(PointShape(y=shape_dict["y"], x=shape_dict["x"])) - elif shape_type == "ellipse": - shapes.append( - EllipseShape( - center_y=shape_dict["center_y"], - center_x=shape_dict["center_x"], - radius_y=shape_dict["radius_y"], - radius_x=shape_dict["radius_x"], - ) - ) - else: - logger.warning(f"Unknown shape type: {shape_type}, skipping") + decoder = ROIJsonShapeDecoder.for_serialized_shape(shape_dict) + if decoder is not None: + shapes.append(decoder.decode(shape_dict)) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) diff --git a/src/polystore/roi_converters.py b/src/polystore/roi_converters.py index 46e8631..616e4c4 100644 --- a/src/polystore/roi_converters.py +++ b/src/polystore/roi_converters.py @@ -7,63 +7,184 @@ """ import logging -from typing import Any, Dict, List, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, ClassVar, Dict, List, Tuple import numpy as np +from metaclass_registry import AutoRegisterMeta -from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI -from .streaming_constants import NapariShapeType +from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI, ShapeType logger = logging.getLogger(__name__) -class NapariROIConverter: - """Convert ROI objects to Napari shapes format.""" +@dataclass(frozen=True, slots=True) +class NapariShapeTypeAlias: + """Inert alias from Napari wire shape names to ROI shape types.""" + + alias: str + shape_type: ShapeType + + +NAPARI_SHAPE_TYPE_ALIASES = ( + NapariShapeTypeAlias("path", ShapeType.POLYLINE), + NapariShapeTypeAlias("points", ShapeType.POINT), +) + + +class NapariShapeConverter(ABC, metaclass=AutoRegisterMeta): + """Registered conversion behavior for one ROI shape type.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_shape_dict(cls, shape_dict: Dict[str, Any]) -> "NapariShapeConverter": + return cls.__registry__[_shape_type_from_napari(shape_dict["type"])]() + + def append_common_properties( + self, + metadata: Dict[str, Any], + properties: dict[str, list[Any]], + centroid: tuple[Any, Any], + *, + area: Any | None = None, + ) -> None: + properties["label"].append(metadata.get("label", "")) + properties["area"].append(metadata.get("area", 0) if area is None else area) + properties["centroid_y"].append(centroid[0]) + properties["centroid_x"].append(centroid[1]) + + @abstractmethod + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + """Add dimensions to a 2D shape to make it nD.""" + + @abstractmethod + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + """Append this shape to a Napari layer payload.""" + + +def _shape_type_from_napari(shape_type: object) -> ShapeType: + if isinstance(shape_type, ShapeType): + return shape_type + value = str(shape_type.value) if isinstance(shape_type, Enum) else str(shape_type) + for alias in NAPARI_SHAPE_TYPE_ALIASES: + if alias.alias == value: + return alias.shape_type + return ShapeType(value) + + +class CoordinateNapariShapeConverter(NapariShapeConverter): + """Shared converter for coordinate-list shapes.""" - _SHAPE_DIMENSION_HANDLERS = { - "polygon": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "polyline": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "ellipse": lambda shape_dict, prepend_dims: np.hstack( + napari_shape_type: ClassVar[str] + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + coordinates = np.array(shape_dict["coordinates"]) + return np.hstack([np.tile(prepend_dims, (len(coordinates), 1)), coordinates]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + napari_shapes.append(np.array(shape_dict["coordinates"])) + shape_types.append(self.napari_shape_type) + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PolygonNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYGON + napari_shape_type = "polygon" + + +class PolylineNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYLINE + napari_shape_type = "path" + + +class EllipseNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.ELLIPSE + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + center = shape_dict["center"] + radii = shape_dict["radii"] + corners = np.array( [ - np.tile(prepend_dims, (4, 1)), - np.array( - [ - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - ] - ), + [center[0] - radii[0], center[1] - radii[1]], + [center[0] - radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] - radii[1]], ] - ), - "point": lambda shape_dict, prepend_dims: np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1), - } + ) + return np.hstack([np.tile(prepend_dims, (4, 1)), corners]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + center = np.array(shape_dict["center"]) + radii = np.array(shape_dict["radii"]) + napari_shapes.append(np.array([center - radii, center + radii])) + shape_types.append("ellipse") + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PointNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.POINT + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + return np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + coordinates = shape_dict["coordinates"] + napari_shapes.append(np.array([coordinates])) + shape_types.append("point") + self.append_common_properties(metadata, properties, coordinates, area=0) + + +class NapariROIConverter: + """Convert ROI objects to Napari shapes format.""" @staticmethod def add_dimensions_to_shape(shape_dict: Dict[str, Any], prepend_dims: List[float]) -> np.ndarray: """Add dimensions to a 2D shape to make it nD.""" - shape_type = shape_dict["type"] - shape_type_enum = NapariShapeType(shape_type) if isinstance(shape_type, str) else shape_type - handler = NapariROIConverter._SHAPE_DIMENSION_HANDLERS.get(shape_type_enum.value) - if handler is None: - raise ValueError(f"Unsupported shape type: {shape_type}") - return handler(shape_dict, np.array(prepend_dims)) + return NapariShapeConverter.for_shape_dict(shape_dict).add_dimensions( + shape_dict, + np.array(prepend_dims), + ) @staticmethod def rois_to_shapes(rois: List[ROI]) -> List[Dict[str, Any]]: @@ -104,40 +225,12 @@ def shapes_to_napari_format(shapes_data: List[Dict]) -> Tuple[List[np.ndarray], properties = {"label": [], "area": [], "centroid_y": [], "centroid_x": []} for shape_dict in shapes_data: - shape_type = shape_dict.get("type") - metadata = shape_dict.get("metadata", {}) - - if shape_type == "polygon": - coords = np.array(shape_dict["coordinates"]) - napari_shapes.append(coords) - shape_types.append("polygon") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "ellipse": - center = np.array(shape_dict["center"]) - radii = np.array(shape_dict["radii"]) - corners = np.array([center - radii, center + radii]) - napari_shapes.append(corners) - shape_types.append("ellipse") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "point": - coords = np.array([shape_dict["coordinates"]]) - napari_shapes.append(coords) - shape_types.append("point") - point_coords = shape_dict["coordinates"] - properties["label"].append(metadata.get("label", "")) - properties["area"].append(0) - properties["centroid_y"].append(point_coords[0]) - properties["centroid_x"].append(point_coords[1]) + NapariShapeConverter.for_shape_dict(shape_dict).append_napari_format( + shape_dict, + napari_shapes, + shape_types, + properties, + ) return napari_shapes, shape_types, properties diff --git a/src/polystore/streaming/__init__.py b/src/polystore/streaming/__init__.py index 8a8536e..2217b4b 100644 --- a/src/polystore/streaming/__init__.py +++ b/src/polystore/streaming/__init__.py @@ -10,7 +10,32 @@ # This allows both: # from polystore.streaming import StreamingBackend # from polystore.streaming.receivers import FijiBatchProcessor -from polystore.streaming._streaming_backend import StreamingBackend - -__all__ = ["StreamingBackend"] +from polystore.streaming._streaming_backend import ( + FilePath, + RoiStreamPayload, + StreamablePayload, + StreamingBatchItemPreparationAuthority, + StreamingBatchMessageBuilder, + StreamingBatchMessageRequest, + StreamingBuiltBatch, + StreamingPreparedBatchItems, + StreamingBackend, + StreamingComponentNamesRequest, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +__all__ = [ + "FilePath", + "RoiStreamPayload", + "StreamablePayload", + "StreamingBatchItemPreparationAuthority", + "StreamingBatchMessageBuilder", + "StreamingBatchMessageRequest", + "StreamingBuiltBatch", + "StreamingPreparedBatchItems", + "StreamingBackend", + "StreamingComponentNamesRequest", + "StreamingItemPreparationRequest", + "ViewerDisplayPayloadExtra", +] diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 417baa2..1f4976f 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -5,25 +5,508 @@ data to external systems without persistent storage capabilities. """ +from __future__ import annotations + import logging -import os import time import uuid +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from multiprocessing import resource_tracker, shared_memory from pathlib import Path -from typing import Any, Callable, List, Set, Union +from types import MappingProxyType +from typing import TypeAlias import numpy as np +import zmq +from arraybridge import convert_memory, detect_memory_type +from arraybridge.types import MemoryType as ArrayBridgeMemoryType from ..base import DataSink -from ..constants import TransportMode from ..streaming_constants import StreamingDataType from ..roi import ROI, PointShape from ..zmq_config import POLYSTORE_ZMQ_CONFIG +from .viewer_transport import ( + ViewerDisplayConfigABC, + ViewerMicroscopeHandlerABC, + ViewerStreamBackendKwargs, + ViewerStreamRequest, + ViewerTransportDefaults, +) from zmqruntime.ack_listener import GlobalAckListener -from zmqruntime.transport import coerce_transport_mode, get_zmq_transport_url +from zmqruntime.config import ZMQConfig +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerBatchItemPayload, + ViewerBatchItemWireField, + ViewerBatchMessagePayload, + ViewerComponentMetadataPayload, + ViewerDisplayConfigWireField, + ViewerTransportEndpoint, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +FilePath: TypeAlias = str | Path +RoiStreamPayload: TypeAlias = Sequence[ROI] +StreamablePayload: TypeAlias = np.ndarray | Sequence[ViewerWireValue] | RoiStreamPayload +ComponentValue = str | int | float | bool | tuple | None +ViewerDisplayPayloadExtraValues: TypeAlias = Mapping[ + str | ViewerDisplayConfigWireField, + ViewerWireValue, +] +STREAMING_TRANSPORT_DEFAULTS = ViewerTransportDefaults() + + +@dataclass(frozen=True) +class ViewerDisplayPayloadExtra: + """Nominal viewer-specific display payload extension.""" + + values: ViewerDisplayPayloadExtraValues = field( + default_factory=lambda: MappingProxyType({}) + ) + + @classmethod + def from_mapping( + cls, + values: ViewerDisplayPayloadExtraValues, + ) -> "ViewerDisplayPayloadExtra": + return cls(values) + + def to_wire_mapping(self) -> dict[str, ViewerWireValue]: + return dict(self.values) + + +EMPTY_DISPLAY_PAYLOAD_EXTRA = ViewerDisplayPayloadExtra() + + +class StreamingComponentValueDomainAuthority: + """Build batch-level component value domains from stream item metadata.""" + + @staticmethod + def wire_payload( + stream_request: ViewerStreamRequest, + batch_images: Sequence[ViewerWireMapping], + ) -> dict[str, ViewerWireValue]: + component_order = tuple( + str(component) + for component in stream_request.display_config.COMPONENT_ORDER + ) + values_by_component: dict[str, list[ComponentValue]] = { + component: [] for component in component_order + } + for image_payload in batch_images: + metadata = StreamingComponentValueDomainAuthority._metadata(image_payload) + for component in component_order: + if component not in metadata: + continue + value = StreamingComponentValueDomainAuthority._component_value( + metadata[component] + ) + if value not in values_by_component[component]: + values_by_component[component].append(value) + return { + component: values + for component, values in values_by_component.items() + if values + } + + @staticmethod + def _metadata(image_payload: ViewerWireMapping) -> ViewerWireMapping: + metadata = image_payload[ViewerBatchItemWireField.METADATA.value] + if not isinstance(metadata, Mapping): + raise TypeError( + "Streaming batch item metadata must be a mapping, " + f"got {type(metadata).__name__}." + ) + return dict(metadata) + + @staticmethod + def _component_value(value: ViewerWireValue) -> ComponentValue: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, tuple): + return value + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return tuple(value) + raise TypeError( + "Streaming component values must be JSON scalar or tuple-like, " + f"got {type(value).__name__}." + ) + +@dataclass(frozen=True) +class StreamingComponentNamesRequest: + """Component-label metadata requested for one viewer batch.""" + + component_names: Sequence[str] + log_prefix: str | None = None + verbose: bool = False + + @classmethod + def from_stream_request( + cls, + stream_request: ViewerStreamRequest, + log_prefix: str | None = None, + verbose: bool = False, + ) -> "StreamingComponentNamesRequest": + return cls( + component_names=tuple( + str(component) + for component in stream_request.display_config.COMPONENT_ORDER + ), + log_prefix=log_prefix, + verbose=verbose, + ) + + +@dataclass(frozen=True) +class StreamingBatchMessageRequest: + """Inputs for building one viewer batch message.""" + + data_list: list[StreamablePayload] + file_paths: list[FilePath] + stream_request: ViewerStreamRequest + component_names_request: StreamingComponentNamesRequest | None = None + display_payload_extra: ViewerDisplayPayloadExtra = field( + default_factory=ViewerDisplayPayloadExtra + ) + + def resolved_component_names_request(self) -> StreamingComponentNamesRequest: + if self.component_names_request is not None: + return self.component_names_request + return StreamingComponentNamesRequest.from_stream_request( + self.stream_request + ) + + +@dataclass(frozen=True) +class StreamingPreparedBatchItems: + """Prepared per-item viewer payloads before batch-level metadata is attached.""" + + batch_images: list[dict[str, ViewerWireValue]] + image_ids: list[str] + + +@dataclass(frozen=True) +class StreamingBuiltBatch(StreamingPreparedBatchItems): + """Prepared viewer message and per-item transmission bookkeeping.""" + + message: dict[str, ViewerWireValue] + + +@dataclass(frozen=True) +class StreamingItemPath: + """Nominal path identity for one item in a viewer stream batch.""" + + value: FilePath + + @property + def wire_value(self) -> str: + return str(self.value) + + +@dataclass(frozen=True) +class StreamingPayloadFileRequest: + """Shared payload/file identity for viewer item preparation requests.""" + + data: StreamablePayload + item_path: StreamingItemPath + + +@dataclass(frozen=True) +class StreamingItemPreparationRequest(StreamingPayloadFileRequest): + """Inputs needed to prepare one payload for a viewer batch item.""" + + data_type: StreamingDataType + + +@dataclass(frozen=True) +class StreamingSharedMemoryRequest(StreamingPayloadFileRequest): + """Inputs needed to allocate one image payload into shared memory.""" + + shm_prefix: str + + +@dataclass(frozen=True) +class StreamingSharedMemoryPayload: + """Wire payload describing a shared-memory image allocation.""" + + item_path: StreamingItemPath + shape: tuple[int, ...] + dtype: str + shm_name: str + + def to_wire_mapping(self) -> dict[str, ViewerWireValue]: + return { + ViewerBatchItemWireField.PATH.value: self.item_path.wire_value, + ViewerBatchItemWireField.SHAPE.value: self.shape, + ViewerBatchItemWireField.DTYPE.value: self.dtype, + ViewerBatchItemWireField.SHM_NAME.value: self.shm_name, + } + + +@dataclass(frozen=True) +class StreamingSharedMemoryBlock: + """Allocated shared memory and the wire payload that names it.""" + + shared_memory: shared_memory.SharedMemory + payload: StreamingSharedMemoryPayload + + +class StreamingPayloadMemoryAuthority: + """Memory conversion authority for streamable image payloads.""" + + @staticmethod + def to_numpy(data: StreamablePayload) -> np.ndarray: + if isinstance(data, np.ndarray): + return data + if isinstance(data, (list, tuple)): + return np.asarray(data) + return convert_memory( + data, + detect_memory_type(data), + ArrayBridgeMemoryType.NUMPY.value, + gpu_id=0, + ) + + +class StreamingSharedMemoryAuthority: + """Allocate image payloads for viewer transfer through shared memory.""" + + @classmethod + def create( + cls, + request: StreamingSharedMemoryRequest, + ) -> StreamingSharedMemoryBlock: + np_data = StreamingPayloadMemoryAuthority.to_numpy(request.data) + shm_name = cls._name(request.shm_prefix) + shm = shared_memory.SharedMemory( + create=True, + size=np_data.nbytes, + name=shm_name, + ) + resource_tracker.unregister(shm._name, "shared_memory") + + shm_array = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf) + shm_array[:] = np_data[:] + + return StreamingSharedMemoryBlock( + shared_memory=shm, + payload=StreamingSharedMemoryPayload( + item_path=request.item_path, + shape=tuple(int(dimension) for dimension in np_data.shape), + dtype=str(np_data.dtype), + shm_name=shm_name, + ), + ) + + @staticmethod + def _name(shm_prefix: str) -> str: + return f"{shm_prefix}{uuid.uuid4().hex[:12]}" + + +class StreamingDataTypeAuthority: + """Detect the viewer payload kind for one streamed object.""" + + @staticmethod + def detect(data: StreamablePayload) -> StreamingDataType: + is_roi = isinstance(data, list) and len(data) > 0 and isinstance(data[0], ROI) + + if not is_roi: + return StreamingDataType.IMAGE + + all_points = all( + roi.shapes and all(isinstance(shape, PointShape) for shape in roi.shapes) + for roi in data + ) + + return StreamingDataType.POINTS if all_points else StreamingDataType.SHAPES + + +class StreamingComponentNamesMetadataCollector: + """Collect viewer component-label metadata for one batch.""" + + @staticmethod + def collect( + plate_path: FilePath | None, + microscope_handler: ViewerMicroscopeHandlerABC, + request: StreamingComponentNamesRequest, + ) -> dict[str, ViewerWireValue]: + component_names_metadata = {} + + if plate_path is None: + if request.verbose and request.log_prefix: + logger.warning("%s: No plate_path in kwargs", request.log_prefix) + return component_names_metadata + + for component_name in request.component_names: + metadata = microscope_handler.metadata_handler.get_component_values( + plate_path, + component_name, + ) + if request.verbose and request.log_prefix: + logger.info( + "%s: Got %s metadata: %s", + request.log_prefix, + component_name, + metadata, + ) + if metadata: + component_names_metadata[component_name] = metadata + + return component_names_metadata + + +class StreamingDisplayPayloadBuilder: + """Build the shared viewer display-config payload.""" + + @staticmethod + def build( + stream_request: ViewerStreamRequest, + display_payload_extra: ViewerDisplayPayloadExtra, + ) -> ViewerBatchDisplayPayload: + return ViewerBatchDisplayPayload( + component_modes={ + str(component): str(mode.value if isinstance(mode, Enum) else mode) + for component, mode in stream_request.display_config.component_modes().items() + }, + component_order=tuple( + str(component) + for component in stream_request.display_config.COMPONENT_ORDER + ), + extra=display_payload_extra.to_wire_mapping(), + ) + + +class StreamingBatchItemPreparationAuthority: + """Prepare per-item viewer payloads and transmission bookkeeping.""" + + @staticmethod + def prepare( + backend: "StreamingBackend", + request: StreamingBatchMessageRequest, + ) -> StreamingPreparedBatchItems: + batch_images = [] + image_ids = [] + + for index, (data, file_path) in enumerate( + zip(request.data_list, request.file_paths) + ): + item_path = StreamingItemPath(file_path) + image_id = str(uuid.uuid4()) + image_ids.append(image_id) + + data_type = StreamingDataTypeAuthority.detect(data) + explicit_component_metadata = ( + request.stream_request.source.metadata.component_metadata_for_item( + item_path.value, + index, + ) + ) + item_data, data_type_value = backend._prepare_batch_item( + StreamingItemPreparationRequest( + data=data, + item_path=item_path, + data_type=data_type, + ) + ) + + batch_images.append( + ViewerBatchItemPayload.from_parts( + item_payload=item_data, + data_type=data_type_value, + metadata=explicit_component_metadata, + producer_identity=( + request.stream_request.producer_identity.to_payload() + ), + image_id=image_id, + ).to_wire_mapping() + ) + + return StreamingPreparedBatchItems( + batch_images=batch_images, + image_ids=image_ids, + ) + + +class StreamingComponentMetadataPayloadAuthority: + """Resolve the component metadata payload for one viewer batch.""" + + @staticmethod + def payload( + request: StreamingBatchMessageRequest, + prepared_items: StreamingPreparedBatchItems, + ) -> ViewerComponentMetadataPayload: + declared = ViewerComponentMetadataPayload.from_optional_wire_mapping( + request.stream_request.message_extra_payload() + ) + if declared is not None: + return declared + return ViewerComponentMetadataPayload( + component_names_metadata=( + StreamingComponentNamesMetadataCollector.collect( + request.stream_request.source.identity.plate_path, + request.stream_request.source.identity.microscope_handler, + request.resolved_component_names_request(), + ) + ), + component_value_domain=( + StreamingComponentValueDomainAuthority.wire_payload( + request.stream_request, + prepared_items.batch_images, + ) + ), + ) + + +class StreamingBatchMessageBuilder: + """Build complete viewer batch messages from prepared items.""" + + @classmethod + def build( + cls, + backend: "StreamingBackend", + request: StreamingBatchMessageRequest, + ) -> StreamingBuiltBatch: + if len(request.data_list) != len(request.file_paths): + raise ValueError("data_list and file_paths must have the same length") + + prepared_items = StreamingBatchItemPreparationAuthority.prepare( + backend, + request, + ) + + component_metadata_payload = ( + StreamingComponentMetadataPayloadAuthority.payload( + request, + prepared_items, + ) + ) + + display_payload = StreamingDisplayPayloadBuilder.build( + request.stream_request, + request.display_payload_extra, + ) + message = ViewerBatchMessagePayload.from_parts( + images=prepared_items.batch_images, + display_payload=display_payload, + component_metadata=component_metadata_payload, + timestamp=time.time(), + extra=ViewerComponentMetadataPayload.strip_component_metadata( + backend._message_extra(request.stream_request) + ), + ).to_wire_mapping() + + return StreamingBuiltBatch( + message=message, + batch_images=prepared_items.batch_images, + image_ids=prepared_items.image_ids, + ) + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -42,8 +525,8 @@ class StreamingBackend(DataSink): """ # Abstract class attributes that subclasses must define - VIEWER_TYPE: str = None - SHM_PREFIX: str = None + VIEWER_TYPE: str + SHM_PREFIX: str # Class attribute: streaming backends only support image array data and ROIs supports_arbitrary_files: bool = False @@ -59,9 +542,9 @@ def requires_filesystem_validation(self) -> bool: def _filter_streamable_files( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - ) -> tuple[List[Any], List[Union[str, Path]], List[Union[str, Path]]]: + data_list: list[StreamablePayload], + file_paths: list[FilePath], + ) -> tuple[list[StreamablePayload], list[FilePath], list[FilePath]]: """ Filter data to only include files with supported extensions. @@ -97,155 +580,33 @@ def _filter_streamable_files( return filtered_data, filtered_paths, skipped_paths - def __init__(self, transport_config=None): + def __init__(self, transport_config: ZMQConfig = POLYSTORE_ZMQ_CONFIG): """Initialize ZeroMQ and shared memory infrastructure.""" self._publishers = {} self._context = None self._shared_memory_blocks = {} - self._transport_config = transport_config or POLYSTORE_ZMQ_CONFIG - - def _get_publisher(self, host: str, port: int, transport_mode: TransportMode, transport_config=None): - """ - Lazy initialization of ZeroMQ publisher (common for all streaming backends). - - Uses REQ socket for Fiji (synchronous request/reply with blocking) - and PUB socket for Napari (broadcast pattern). - - Args: - host: Host to connect to (ignored for IPC mode) - port: Port to connect to - transport_mode: IPC or TCP transport (required - comes from config) - - Returns: - ZeroMQ publisher socket - """ - # Generate transport URL using centralized function - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - key = url # Use URL as key instead of host:port - if key not in self._publishers: - try: - import zmq - if self._context is None: - self._context = zmq.Context() - - # Use REQ socket for all viewers (synchronous request/reply) - # All viewers must send acknowledgment after processing - publisher = self._context.socket(zmq.REQ) - - publisher.connect(url) - socket_name = "REQ" - logger.info(f"{self.VIEWER_TYPE} streaming {socket_name} socket connected to {url}") - time.sleep(0.1) - self._publishers[key] = publisher - - except ImportError: - logger.error("ZeroMQ not available - streaming disabled") - raise RuntimeError("ZeroMQ required for streaming") - - return self._publishers[key] - - def _parse_component_metadata(self, file_path: Union[str, Path], microscope_handler, - source: str) -> dict: - """ - Parse component metadata from filename (common for all streaming backends). - - Args: - file_path: Path to parse - microscope_handler: Handler with parser - source: Pre-built source value (step_name during execution, subdir when loading from disk) - - Returns: - Component metadata dict with source added - """ - filename = os.path.basename(str(file_path)) - component_metadata = microscope_handler.parser.parse_filename(filename) - - # Add pre-built source value directly - component_metadata['source'] = source - - return component_metadata - - def _detect_data_type(self, data: Any): - """ - Detect if data is ROI (shapes/points) or image (common for all streaming backends). - - Args: - data: Data to check + self._transport_config = transport_config - Returns: - StreamingDataType enum value (IMAGE, SHAPES, or POINTS) - """ - is_roi = isinstance(data, list) and len(data) > 0 and isinstance(data[0], ROI) - - if not is_roi: - return StreamingDataType.IMAGE - - # Check if all ROIs contain only PointShape objects (for points layer) - all_points = all( - roi.shapes and all(isinstance(shape, PointShape) for shape in roi.shapes) - for roi in data + def create_shared_memory_payload( + self, + data: StreamablePayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: + block = StreamingSharedMemoryAuthority.create( + StreamingSharedMemoryRequest( + data=data, + item_path=StreamingItemPath(file_path), + shm_prefix=self.SHM_PREFIX, + ) ) - - return StreamingDataType.POINTS if all_points else StreamingDataType.SHAPES - - def _create_shared_memory(self, data: Any, file_path: Union[str, Path]) -> dict: - """ - Create shared memory for image data (common for all streaming backends). - - Args: - data: Image data to put in shared memory - file_path: Path identifier - - Returns: - Dict with shared memory metadata - """ - # Convert to numpy - np_data = data.cpu().numpy() if hasattr(data, 'cpu') else \ - data.get() if hasattr(data, 'get') else np.asarray(data) - - # Create shared memory with hash-based naming to avoid "File name too long" errors - # Hash the timestamp and object ID to create a short, unique name - from multiprocessing import shared_memory, resource_tracker - import hashlib - timestamp = time.time_ns() - obj_id = id(data) - hash_input = f"{obj_id}_{timestamp}" - hash_suffix = hashlib.md5(hash_input.encode()).hexdigest()[:8] - shm_name = f"{self.SHM_PREFIX}{hash_suffix}" - shm = shared_memory.SharedMemory(create=True, size=np_data.nbytes, name=shm_name) - - # Unregister from resource tracker - we manage cleanup manually - # This prevents resource tracker warnings when worker processes exit - # before the viewer has unlinked the shared memory - try: - resource_tracker.unregister(shm._name, "shared_memory") - except Exception: - pass # Ignore errors if already unregistered - - shm_array = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf) - shm_array[:] = np_data[:] - self._shared_memory_blocks[shm_name] = shm - - return { - 'path': str(file_path), - 'shape': np_data.shape, - 'dtype': str(np_data.dtype), - 'shm_name': shm_name, - } + self._shared_memory_blocks[block.payload.shm_name] = block.shared_memory + return block.payload.to_wire_mapping() def _register_with_queue_tracker( self, - port: int, - image_ids: List[str], - transport_mode: TransportMode | None = None, - transport_config=None, + transport_endpoint: ViewerTransportEndpoint, + image_ids: list[str], + transport_config: ZMQConfig, ) -> None: """ Register sent images with queue tracker (common for all streaming backends). @@ -255,168 +616,143 @@ def _register_with_queue_tracker( image_ids: List of image IDs to register """ listener = GlobalAckListener() - transport_config = transport_config or self._transport_config listener.start( port=transport_config.shared_ack_port, - transport_mode=coerce_transport_mode(transport_mode), + transport_mode=transport_endpoint.resolved_transport_mode(), config=transport_config, ) from zmqruntime.queue_tracker import GlobalQueueTrackerRegistry registry = GlobalQueueTrackerRegistry() - tracker = registry.get_or_create_tracker(port, self.VIEWER_TYPE) + tracker = registry.get_or_create_tracker( + transport_endpoint.port, + self.VIEWER_TYPE, + ) for image_id in image_ids: tracker.register_sent(image_id) - def _build_component_modes(self, display_config) -> dict: - component_modes = {} - for comp_name in display_config.COMPONENT_ORDER: - mode_field = f"{comp_name}_mode" - if hasattr(display_config, mode_field): - mode = getattr(display_config, mode_field) - component_modes[comp_name] = mode.value - return component_modes - - def _build_display_config_base(self, display_config, component_modes: dict) -> dict: - return { - "component_modes": component_modes, - "component_order": display_config.COMPONENT_ORDER, - } + def _cleanup_shared_memory_blocks(self, batch_images, unlink: bool = False) -> None: + for img in batch_images: + shm_name = img.get(ViewerBatchItemWireField.SHM_NAME.value) + if shm_name and shm_name in self._shared_memory_blocks: + try: + shm = self._shared_memory_blocks.pop(shm_name) + shm.close() + if unlink: + shm.unlink() + except Exception as e: + logger.warning(f"Failed to cleanup shared memory {shm_name}: {e}") - def _collect_component_names_metadata( + def _prepare_batch_item( self, - plate_path, - microscope_handler, - component_names: List[str] | None = None, - log_prefix: str | None = None, - verbose: bool = False, - ) -> dict: - component_names = component_names or ["channel", "well", "site"] - component_names_metadata = {} - - if not plate_path or not microscope_handler: - if verbose and log_prefix: - if not plate_path: - logger.warning(f"{log_prefix}: No plate_path in kwargs") - if not microscope_handler: - logger.warning(f"{log_prefix}: No microscope_handler") - return component_names_metadata - - try: - for comp_name in component_names: - method_name = f"get_{comp_name}_values" - method = getattr(microscope_handler.metadata_handler, method_name, None) - if callable(method): - try: - metadata = method(plate_path) - if verbose and log_prefix: - logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") - if metadata: - component_names_metadata[comp_name] = metadata - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get {comp_name} metadata: {e}", exc_info=True) - elif verbose and log_prefix: - logger.info(f"{log_prefix}: No method {method_name} on metadata_handler") - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get component metadata: {e}", exc_info=True) + request: StreamingItemPreparationRequest, + ) -> tuple[ViewerWireMapping, str]: + raise NotImplementedError - return component_names_metadata - - def _prepare_batch_items( + def _display_payload_extra( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], - ) -> tuple[list[dict], list[str]]: - batch_images = [] - image_ids = [] - - for data, file_path in zip(data_list, file_paths): - image_id = str(uuid.uuid4()) - image_ids.append(image_id) + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return EMPTY_DISPLAY_PAYLOAD_EXTRA - data_type = self._detect_data_type(data) - component_metadata = self._parse_component_metadata( - file_path, microscope_handler, source - ) - item_data, data_type_value = prepare_item(data, file_path, data_type) - - batch_images.append( - { - **item_data, - "data_type": data_type_value, - "metadata": component_metadata, - "image_id": image_id, - } - ) + def _message_extra( + self, + stream_request: ViewerStreamRequest, + ) -> dict[str, ViewerWireValue]: + return stream_request.message_extra_payload() - return batch_images, image_ids + def _component_names_request( + self, + stream_request: ViewerStreamRequest, + ) -> StreamingComponentNamesRequest: + return StreamingComponentNamesRequest.from_stream_request(stream_request) - def _build_batch_message( + def _after_batch_message_built( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - display_config, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], - plate_path: Union[str, Path, None] = None, - component_names_kwargs: dict | None = None, - display_payload_extra: dict | None = None, - message_extra: dict | None = None, - ) -> tuple[dict, list[dict], list[str]]: - if len(data_list) != len(file_paths): - raise ValueError("data_list and file_paths must have the same length") + stream_request: ViewerStreamRequest, + built_batch: StreamingBuiltBatch, + ) -> None: + pass - batch_images, image_ids = self._prepare_batch_items( + def save_batch( + self, + data_list: list[StreamablePayload], + file_paths: list[FilePath], + **kwargs, + ) -> None: + """Stream a batch of image or ROI payloads to this viewer.""" + data_list, file_paths, _skipped_paths = self._filter_streamable_files( data_list, file_paths, - microscope_handler, - source, - prepare_item, ) + if not data_list: + return + + stream_request = ViewerStreamBackendKwargs.from_kwargs(kwargs).stream_request + built_batch = StreamingBatchMessageBuilder.build( + self, + StreamingBatchMessageRequest( + data_list=data_list, + file_paths=file_paths, + stream_request=stream_request, + component_names_request=self._component_names_request(stream_request), + display_payload_extra=self._display_payload_extra(stream_request), + ), + ) + self._after_batch_message_built(stream_request, built_batch) - component_modes = self._build_component_modes(display_config) - - component_names_metadata = self._collect_component_names_metadata( - plate_path, - microscope_handler, - **(component_names_kwargs or {}), + transport_config = stream_request.transport_config.resolve( + self._transport_config ) + transport_endpoint = stream_request.viewer_transport + self._register_with_queue_tracker( + transport_endpoint, + built_batch.image_ids, + transport_config=transport_config, + ) + url = transport_endpoint.data_url(transport_config) - display_payload = self._build_display_config_base(display_config, component_modes) - if display_payload_extra: - display_payload.update(display_payload_extra) + if self._context is None: + self._context = zmq.Context() - message = { - "type": "batch", - "images": batch_images, - "display_config": display_payload, - "component_names_metadata": component_names_metadata, - "timestamp": time.time(), - } - if message_extra: - message.update(message_extra) + viewer_name = str(self.VIEWER_TYPE).title() + viewer_label = viewer_name.upper() + ack_policy = STREAMING_TRANSPORT_DEFAULTS.ack_policy(viewer_name) + socket = self._context.socket(zmq.REQ) + ack_policy.apply_socket_options(socket) + socket.connect(url) + time.sleep(0.1) - return message, batch_images, image_ids + try: + logger.info( + "📤 %s BACKEND: Sending batch of %d images to %s on port %s " + "(REQ/REP - blocking until ack)", + viewer_label, + len(built_batch.batch_images), + viewer_name, + transport_endpoint.port, + ) + socket.send_json(built_batch.message) + ack_response = ack_policy.receive( + socket, + lambda: self._cleanup_shared_memory_blocks( + built_batch.batch_images, + unlink=True, + ), + port=transport_endpoint.port, + ) + logger.info( + "✅ %s BACKEND: Received ack from %s: %s", + viewer_label, + viewer_name, + ack_policy.status(ack_response), + ) + finally: + socket.close() - def _cleanup_shared_memory_blocks(self, batch_images, unlink: bool = False) -> None: - for img in batch_images: - shm_name = img.get("shm_name") - if shm_name and shm_name in self._shared_memory_blocks: - try: - shm = self._shared_memory_blocks.pop(shm_name) - shm.close() - if unlink: - shm.unlink() - except Exception as e: - logger.warning(f"Failed to cleanup shared memory {shm_name}: {e}") + self._cleanup_shared_memory_blocks(built_batch.batch_images, unlink=False) - def save(self, data: Any, file_path: Union[str, Path], **kwargs) -> None: + def save(self, data: StreamablePayload | str, file_path: FilePath, **kwargs) -> None: """ Stream single item (common for all streaming backends). diff --git a/src/polystore/streaming/base.py b/src/polystore/streaming/base.py index 5cbfec0..c0bfb70 100644 --- a/src/polystore/streaming/base.py +++ b/src/polystore/streaming/base.py @@ -27,7 +27,6 @@ class TypedData(Generic[T]): """ items: List[T] metadata: Dict[str, Any] - source: str class ComponentAccessor(ABC): diff --git a/src/polystore/streaming/identity.py b/src/polystore/streaming/identity.py new file mode 100644 index 0000000..1c543a6 --- /dev/null +++ b/src/polystore/streaming/identity.py @@ -0,0 +1,241 @@ +"""Nominal stream identity records shared by viewer streaming backends.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import ClassVar, Mapping, Sequence, TypeAlias + + +StreamProducerPayloadValue: TypeAlias = str | int | None +StreamProducerPayloadMapping: TypeAlias = Mapping[str, StreamProducerPayloadValue] +RouteKeyPart: TypeAlias = str | int | float | bool | None + + +class StreamProducerOrigin(str, Enum): + """Nominal stream producer origin values.""" + + PIPELINE = "pipeline" + MANUAL = "manual" + DIRECT = "direct" + + +class FixedStreamProducerIdentityKind(str, Enum): + """Producer identities whose origin and output kind are intentionally equal.""" + + MANUAL = StreamProducerOrigin.MANUAL.value + DIRECT = StreamProducerOrigin.DIRECT.value + + +class StreamProducerIdentityPayload(dict[str, StreamProducerPayloadValue]): + """Wire payload for one stream producer identity.""" + + @classmethod + def from_identity( + cls, + identity: "StreamProducerIdentity", + ) -> "StreamProducerIdentityPayload": + return cls( + origin=identity.origin, + output_kind=identity.output_kind, + output_key=identity.output_key, + step_name=identity.step_name, + pipeline_position=identity.pipeline_position, + step_scope_id=identity.step_scope_id, + invocation_key=identity.invocation_key, + artifact_kind=identity.artifact_kind, + ) + + +@dataclass(frozen=True, slots=True) +class StreamProducerIdentity: + """Producer/output identity for one streamed viewer item.""" + + origin: str + output_kind: str + output_key: str + step_name: str | None = None + pipeline_position: int | None = None + step_scope_id: str | None = None + invocation_key: str | None = None + artifact_kind: str | None = None + + @classmethod + def pipeline_output( + cls, + *, + output_kind: str, + output_key: str, + step_name: str, + pipeline_position: int | None, + step_scope_id: str | None = None, + artifact_kind: str | None = None, + ) -> "StreamProducerIdentity": + """Build identity for one pipeline-produced stream output.""" + return cls( + origin=StreamProducerOrigin.PIPELINE.value, + output_kind=output_kind, + output_key=output_key, + step_name=step_name, + pipeline_position=pipeline_position, + step_scope_id=step_scope_id, + artifact_kind=artifact_kind, + ) + + @classmethod + def fixed_output( + cls, + kind: FixedStreamProducerIdentityKind, + output_key: str, + ) -> "StreamProducerIdentity": + """Build identity for producer kinds whose origin owns the output kind.""" + return cls( + origin=kind.value, + output_kind=kind.value, + output_key=output_key, + ) + + @classmethod + def from_payload( + cls, + payload: "StreamProducerIdentity | StreamProducerPayloadMapping", + ) -> "StreamProducerIdentity": + if isinstance(payload, cls): + return payload + if not isinstance(payload, Mapping): + raise TypeError( + "Stream producer identity must be a mapping or StreamProducerIdentity, " + f"got {type(payload).__name__}." + ) + return cls( + origin=_required_payload_str(payload, "origin"), + output_kind=_required_payload_str(payload, "output_kind"), + output_key=_required_payload_str(payload, "output_key"), + step_name=_optional_payload_str(payload, "step_name"), + pipeline_position=_optional_payload_int(payload, "pipeline_position"), + step_scope_id=_optional_payload_str(payload, "step_scope_id"), + invocation_key=_optional_payload_str(payload, "invocation_key"), + artifact_kind=_optional_payload_str(payload, "artifact_kind"), + ) + + def to_payload(self) -> StreamProducerIdentityPayload: + return StreamProducerIdentityPayload.from_identity(self) + + def route_parts(self) -> tuple[str, ...]: + parts = [ + f"origin_{self.origin}", + f"kind_{self.output_kind}", + f"out_{self.output_key}", + ] + if self.pipeline_position is not None: + parts.append(f"step_{self.pipeline_position}") + if self.step_scope_id: + parts.append(f"scope_{self.step_scope_id}") + if self.step_name: + parts.append(f"name_{self.step_name}") + if self.invocation_key: + parts.append(f"invocation_{self.invocation_key}") + if self.artifact_kind: + parts.append(f"artifact_{self.artifact_kind}") + return tuple(parts) + + +def _required_payload_str( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> str: + if field_name not in payload: + raise ValueError( + f"Stream producer identity missing required field: {field_name}" + ) + value = payload[field_name] + if value in (None, ""): + raise ValueError( + f"Stream producer identity missing required field: {field_name}" + ) + return str(value) + + +def _optional_payload_str( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> str | None: + if field_name not in payload: + return None + return _optional_str(payload[field_name]) + + +def _optional_payload_int( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> int | None: + if field_name not in payload: + return None + value = payload[field_name] + if value is None: + return None + return int(value) + + +def _optional_str(value: StreamProducerPayloadValue) -> str | None: + if value is None: + return None + text = str(value) + return text or None + + +class StreamProducerDisplayNameAuthority: + """Own user-facing labels derived from stream producer identity.""" + + PIPELINE_DISPLAY_INDEX_BASE: ClassVar[int] = 1 + OUTPUT_KEY_OMITTING_KINDS: ClassVar[frozenset[str]] = frozenset( + {"main", "manual", "direct"} + ) + + @staticmethod + def producer_base(producer: StreamProducerIdentity) -> str: + if producer.step_name: + return producer.step_name + return producer.output_key + + @classmethod + def producer_label(cls, producer: StreamProducerIdentity) -> str: + base = cls.producer_base(producer) + if producer.pipeline_position is None: + return base + return f"{producer.pipeline_position + cls.PIPELINE_DISPLAY_INDEX_BASE}. {base}" + + @classmethod + def output_label(cls, producer: StreamProducerIdentity) -> str: + parts = [cls.producer_label(producer)] + if cls.includes_output_key(producer): + parts.append(producer.output_key) + return " ".join(part for part in parts if part) + + @classmethod + def disambiguation_label(cls, producer: StreamProducerIdentity) -> str: + if producer.pipeline_position is not None: + return f"step {producer.pipeline_position + cls.PIPELINE_DISPLAY_INDEX_BASE}" + return producer.output_key or producer.origin + + @classmethod + def includes_output_key(cls, producer: StreamProducerIdentity) -> bool: + if not producer.output_key: + return False + if producer.output_kind in cls.OUTPUT_KEY_OMITTING_KINDS: + return False + return producer.output_key != cls.producer_base(producer) + + +class StreamRouteKeyAuthority: + """Own stable key-token projection for viewer route keys.""" + + @staticmethod + def token(value: RouteKeyPart) -> str: + return str(value).replace("/", "_").replace("\\", "_").replace(" ", "_") + + @classmethod + def join(cls, parts: Sequence[RouteKeyPart]) -> str: + if not parts: + raise ValueError("Cannot build a stream route key with no parts.") + return "_".join(cls.token(part) for part in parts) diff --git a/src/polystore/streaming/receivers/__init__.py b/src/polystore/streaming/receivers/__init__.py index b1876be..c6d58db 100644 --- a/src/polystore/streaming/receivers/__init__.py +++ b/src/polystore/streaming/receivers/__init__.py @@ -9,13 +9,15 @@ WindowProjectionABC, DebouncedBatchEngine, GroupedWindowItems, + WindowProjectionPayloadProvider, + WindowProjectionSource, group_items_by_component_modes, ) from polystore.streaming.receivers.fiji.fiji_batch_processor import FijiBatchProcessor from polystore.streaming.receivers.napari import ( NapariBatchProcessor, normalize_component_layout, - build_layer_key, + build_route_key, ) __all__ = [ @@ -23,9 +25,11 @@ "WindowProjectionABC", "DebouncedBatchEngine", "GroupedWindowItems", + "WindowProjectionPayloadProvider", + "WindowProjectionSource", "group_items_by_component_modes", "FijiBatchProcessor", "NapariBatchProcessor", "normalize_component_layout", - "build_layer_key", + "build_route_key", ] diff --git a/src/polystore/streaming/receivers/core/__init__.py b/src/polystore/streaming/receivers/core/__init__.py index 084d686..786f526 100644 --- a/src/polystore/streaming/receivers/core/__init__.py +++ b/src/polystore/streaming/receivers/core/__init__.py @@ -7,6 +7,8 @@ ) from polystore.streaming.receivers.core.window_projection import ( GroupedWindowItems, + WindowProjectionPayloadProvider, + WindowProjectionSource, group_items_by_component_modes, ) @@ -15,6 +17,7 @@ "WindowProjectionABC", "DebouncedBatchEngine", "GroupedWindowItems", + "WindowProjectionPayloadProvider", + "WindowProjectionSource", "group_items_by_component_modes", ] - diff --git a/src/polystore/streaming/receivers/core/window_projection.py b/src/polystore/streaming/receivers/core/window_projection.py index 4987960..308df65 100644 --- a/src/polystore/streaming/receivers/core/window_projection.py +++ b/src/polystore/streaming/receivers/core/window_projection.py @@ -2,87 +2,194 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable - - -WindowValueNormalizer = Callable[[str, Any, dict[str, Any], str | None], Any] - - -@dataclass(frozen=True) -class GroupedWindowItems: +from typing import Generic, TypeVar + +from polystore.streaming.identity import ( + StreamProducerDisplayNameAuthority, + StreamProducerIdentity, + StreamRouteKeyAuthority, +) +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerBatchItemWireField, + ViewerComponentMode, + ViewerWireMapping, + ViewerWireValue, +) + + +WINDOW_COMPONENT_MODES = ( + ViewerComponentMode.WINDOW, + ViewerComponentMode.CHANNEL, + ViewerComponentMode.SLICE, + ViewerComponentMode.FRAME, +) +WindowLabel = tuple[str, ViewerWireValue] +WindowProjectionItemT = TypeVar("WindowProjectionItemT") +WindowProjectionProviderT = TypeVar( + "WindowProjectionProviderT", + bound="WindowProjectionPayloadProvider", +) + + +class WindowProjectionPayloadProvider(ABC): + """Item that can expose its viewer wire payload for window projection.""" + + @abstractmethod + def window_projection_payload(self) -> Mapping[str, ViewerWireValue]: + """Return the wire payload used for component/window projection.""" + + +class WindowItemPayload(dict[str, ViewerWireValue]): + """Normalized wire payload retained for one projected window item.""" + + @classmethod + def from_mapping( + cls, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowItemPayload": + return cls(dict(payload)) + + +@dataclass(frozen=True, slots=True) +class GroupedWindowItems(Generic[WindowProjectionItemT]): """Projection result for a single batch.""" window_components: list[str] channel_components: list[str] slice_components: list[str] frame_components: list[str] - windows: dict[str, list[dict[str, Any]]] - fixed_window_labels: dict[str, list[tuple[str, Any]]] - - -def _default_normalizer( - component_name: str, - value: Any, - item: dict[str, Any], - images_dir: str | None, -) -> Any: - """Normalize window component values for stable keying across payload types.""" - data_type = item.get("data_type") - if component_name == "source" and images_dir and data_type == "rois": - value_str = str(value) - if "_results" in value_str or "/" in value_str: - return Path(images_dir).name - return value + windows: dict[str, list[WindowProjectionItemT]] + fixed_window_labels: dict[str, tuple[WindowLabel, ...]] + + +@dataclass(frozen=True, slots=True) +class WindowProjectionSource(Generic[WindowProjectionItemT]): + """Validated receiver item source used by window projection.""" + + item: WindowProjectionItemT + payload: Mapping[str, ViewerWireValue] + metadata: ViewerWireMapping + producer: StreamProducerIdentity + + @classmethod + def from_wire_payload( + cls, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowProjectionSource[WindowItemPayload]": + window_payload = WindowItemPayload.from_mapping(payload) + return cls.from_item(window_payload, window_payload) + + @classmethod + def from_wire_payloads( + cls, + payloads: Sequence[Mapping[str, ViewerWireValue]], + ) -> list["WindowProjectionSource[WindowItemPayload]"]: + return [cls.from_wire_payload(payload) for payload in payloads] + + @classmethod + def from_payload_provider( + cls, + item: WindowProjectionProviderT, + ) -> "WindowProjectionSource[WindowProjectionProviderT]": + return cls.from_item(item, item.window_projection_payload()) + + @classmethod + def from_payload_providers( + cls, + items: Sequence[WindowProjectionProviderT], + ) -> list["WindowProjectionSource[WindowProjectionProviderT]"]: + return [cls.from_payload_provider(item) for item in items] + + @classmethod + def from_item( + cls, + item: WindowProjectionItemT, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowProjectionSource[WindowProjectionItemT]": + metadata = cls._required_mapping( + payload, + ViewerBatchItemWireField.METADATA.value, + ) + producer_identity = cls._required_mapping( + payload, + ViewerBatchItemWireField.PRODUCER_IDENTITY.value, + ) + return cls( + item=item, + payload=payload, + metadata=metadata, + producer=StreamProducerIdentity.from_payload(producer_identity), + ) + + @staticmethod + def _required_mapping( + payload: Mapping[str, ViewerWireValue], + field_name: str, + ) -> ViewerWireMapping: + if field_name not in payload: + raise ValueError( + f"Viewer window projection item missing required field {field_name!r}." + ) + value = payload[field_name] + if not isinstance(value, Mapping): + raise TypeError( + f"Viewer window projection item field {field_name!r} must be a mapping, " + f"got {type(value).__name__}." + ) + return dict(value) def group_items_by_component_modes( - items: list[dict[str, Any]], - component_modes: dict[str, str], - component_order: list[str], - *, - images_dir: str | None = None, - normalizer: WindowValueNormalizer | None = None, -) -> GroupedWindowItems: + items: Sequence[WindowProjectionSource[WindowProjectionItemT]], + display_layout: ViewerBatchDisplayPayload, +) -> GroupedWindowItems[WindowProjectionItemT]: """Project items into window groups using declared component modes.""" - if normalizer is None: - normalizer = _default_normalizer - - result: dict[str, list[str]] = { - "window": [], - "channel": [], - "slice": [], - "frame": [], - } - for comp_name in component_order: - mode = component_modes[comp_name] - result[mode].append(comp_name) - - window_components = result["window"] - channel_components = result["channel"] - slice_components = result["slice"] - frame_components = result["frame"] - - windows: dict[str, list[dict[str, Any]]] = {} - fixed_window_labels: dict[str, list[tuple[str, Any]]] = {} + mode_groups = display_layout.component_mode_groups(WINDOW_COMPONENT_MODES) + mode_groups.require_all_supported("window projection") + + window_components = list( + mode_groups.components_for_mode(ViewerComponentMode.WINDOW) + ) + channel_components = list( + mode_groups.components_for_mode(ViewerComponentMode.CHANNEL) + ) + slice_components = list( + mode_groups.components_for_mode(ViewerComponentMode.SLICE) + ) + frame_components = list( + mode_groups.components_for_mode(ViewerComponentMode.FRAME) + ) + + windows: dict[str, list[WindowProjectionItemT]] = {} + fixed_window_labels: dict[str, tuple[WindowLabel, ...]] = {} for item in items: - meta = item.get("metadata", {}) - key_parts: list[str] = [] - fixed_labels: list[tuple[str, Any]] = [] + key_parts: list[str] = list(item.producer.route_parts()) + fixed_labels: list[WindowLabel] = [ + ( + "producer", + StreamProducerDisplayNameAuthority.output_label(item.producer), + ) + ] for comp in window_components: - if comp not in meta: - continue - value = normalizer(comp, meta[comp], item, images_dir) + if comp not in item.metadata: + raise ValueError( + f"Viewer window projection item missing window component {comp!r}." + ) + value = item.metadata[comp] key_parts.append(f"{comp}_{value}") fixed_labels.append((comp, value)) - window_key = "_".join(key_parts) if key_parts else "default_window" - windows.setdefault(window_key, []).append(item) + window_key = StreamRouteKeyAuthority.join(key_parts) if window_key not in fixed_window_labels: - fixed_window_labels[window_key] = fixed_labels + windows[window_key] = [] + fixed_window_labels[window_key] = tuple(fixed_labels) + windows[window_key].append(item.item) return GroupedWindowItems( window_components=window_components, @@ -92,4 +199,3 @@ def group_items_by_component_modes( windows=windows, fixed_window_labels=fixed_window_labels, ) - diff --git a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py index 60e0a3b..b1a95f8 100644 --- a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py +++ b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py @@ -57,6 +57,7 @@ def add_items( display_config: Dict[str, Any], images_dir: str, component_names_metadata: Dict[str, Any], + component_value_domain: Dict[str, Any], ): """ Add items to the batch for processing. @@ -65,13 +66,15 @@ def add_items( window_key: Unique identifier for the Fiji window items: List of items to add (images) display_config: Display configuration dict - images_dir: Source image subdirectory + images_dir: Artifact image directory context. component_names_metadata: Component name mappings for dimension labels + component_value_domain: Component value domains for axis cardinality """ context = { "display_config": display_config, "images_dir": images_dir, "component_names_metadata": component_names_metadata, + "component_value_domain": component_value_domain, "window_key": window_key, } self._engine.enqueue(items=items, context=context) @@ -90,15 +93,17 @@ def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) - display_config = context["display_config"] images_dir = context["images_dir"] component_names_metadata = context["component_names_metadata"] + component_value_domain = context["component_value_domain"] window_key = context["window_key"] logger.info( "FijiBatchProcessor: Processing batch of %d items for window '%s'", len(items), window_key, ) - self.fiji_server._process_items_from_batch( + self.fiji_server.batch_processor.process_wire_items( items=items, - display_config_dict=display_config, + display_config=display_config, images_dir=images_dir, component_names_metadata=component_names_metadata, + component_value_domain=component_value_domain, ) diff --git a/src/polystore/streaming/receivers/napari/__init__.py b/src/polystore/streaming/receivers/napari/__init__.py index 9ece1bf..6472c4c 100644 --- a/src/polystore/streaming/receivers/napari/__init__.py +++ b/src/polystore/streaming/receivers/napari/__init__.py @@ -3,7 +3,7 @@ from polystore.streaming.receivers.napari.napari_batch_processor import NapariBatchProcessor from polystore.streaming.receivers.napari.layer_key import ( normalize_component_layout, - build_layer_key, + build_route_key, ) -__all__ = ["NapariBatchProcessor", "normalize_component_layout", "build_layer_key"] +__all__ = ["NapariBatchProcessor", "normalize_component_layout", "build_route_key"] diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index dec6fff..c382853 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -1,46 +1,93 @@ -"""Canonical napari layer-key construction from component metadata.""" +"""Canonical napari route-key construction.""" from __future__ import annotations -from typing import Any +from collections.abc import Mapping +from polystore.streaming.identity import StreamProducerIdentity, StreamRouteKeyAuthority from polystore.streaming_constants import StreamingDataType +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerComponentMode, + ViewerDisplayConfigWireField, + ViewerWireMapping, + ViewerWireValue, +) -def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], list[str]]: - """Return canonical (component_modes, component_order) from display config.""" +def normalize_component_layout( + display_config: ViewerBatchDisplayPayload | ViewerWireMapping, +) -> ViewerBatchDisplayPayload: + """Return canonical display layout from a viewer display-config payload.""" + if isinstance(display_config, ViewerBatchDisplayPayload): + return display_config if isinstance(display_config, dict): - component_modes = display_config["component_modes"] - component_order = display_config["component_order"] - return component_modes, component_order - - component_order = list(display_config.COMPONENT_ORDER) - component_modes: dict[str, str] = {} - for component in component_order: - mode_field = f"{component}_mode" - mode_value = display_config.__getattribute__(mode_field) - component_modes[component] = mode_value.value - return component_modes, component_order - - -def build_layer_key( - component_info: dict[str, Any], - component_modes: dict[str, str], - component_order: list[str], + return ViewerBatchDisplayPayload( + component_modes=_required_mapping( + display_config, + ViewerDisplayConfigWireField.COMPONENT_MODES.value, + ), + component_order=_required_sequence( + display_config, + ViewerDisplayConfigWireField.COMPONENT_ORDER.value, + ), + ) + + raise TypeError( + "Napari component layout requires ViewerBatchDisplayPayload or mapping, " + f"got {type(display_config).__name__}." + ) + + +def build_route_key( + producer_identity: StreamProducerIdentity | Mapping[str, ViewerWireValue], + component_info: Mapping[str, ViewerWireValue], + display_layout: ViewerBatchDisplayPayload, data_type: StreamingDataType, ) -> str: - """Build canonical layer key from slice-mode components and payload type.""" - layer_key_parts: list[str] = [] - for component in component_order: - mode = component_modes[component] - if mode == "slice" and component in component_info: - layer_key_parts.append(f"{component}_{component_info[component]}") + """Build hidden route key from producer identity, slice components, and type.""" + producer = StreamProducerIdentity.from_payload(producer_identity) + route_parts: list[str] = list(producer.route_parts()) + for component in display_layout.components_for_mode(ViewerComponentMode.SLICE): + if component not in component_info: + raise ValueError( + f"Napari route key missing slice component {component!r}." + ) + route_parts.append(f"{component}_{component_info[component]}") + + route_key = StreamRouteKeyAuthority.join(route_parts) + + return f"{route_key}{data_type.napari_layer_suffix}" + - layer_key = "_".join(layer_key_parts) if layer_key_parts else "default_layer" +def _required_mapping( + payload: Mapping[str, ViewerWireValue], + field_name: str, +) -> dict[str, str]: + if field_name not in payload: + raise ValueError(f"Display config missing required field {field_name!r}.") + value = payload[field_name] + if not isinstance(value, Mapping): + raise TypeError( + f"Display config field {field_name!r} must be a mapping, " + f"got {type(value).__name__}." + ) + return { + str(component): str(mode) + for component, mode in value.items() + } - if data_type == StreamingDataType.SHAPES: - return f"{layer_key}_shapes" - if data_type == StreamingDataType.POINTS: - return f"{layer_key}_points" - return layer_key +def _required_sequence( + payload: Mapping[str, ViewerWireValue], + field_name: str, +) -> list[str]: + if field_name not in payload: + raise ValueError(f"Display config missing required field {field_name!r}.") + value = payload[field_name] + if isinstance(value, str) or not isinstance(value, list | tuple): + raise TypeError( + f"Display config field {field_name!r} must be a sequence, " + f"got {type(value).__name__}." + ) + return [str(component) for component in value] diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index b8dcbdd..e6e80d5 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,20 +1,45 @@ import logging -from typing import Any, Dict, List, Optional - -from polystore.streaming.receivers.core import DebouncedBatchEngine +from dataclasses import dataclass +from collections.abc import Sequence +from typing import Generic, Optional, TypeVar logger = logging.getLogger(__name__) +NapariBatchItemT = TypeVar("NapariBatchItemT") +NapariDisplayPayloadT = TypeVar("NapariDisplayPayloadT") +NapariComponentNamesMetadataT = TypeVar("NapariComponentNamesMetadataT") + + +@dataclass(frozen=True) +class NapariBatchDisplayRequest( + Generic[ + NapariBatchItemT, + NapariDisplayPayloadT, + NapariComponentNamesMetadataT, + ] +): + """Nominal request for one debounced Napari display update.""" + + layer_key: str + items: Sequence[NapariBatchItemT] + display_payload: NapariDisplayPayloadT + component_names_metadata: NapariComponentNamesMetadataT + + def dispatch_to(self, napari_server) -> None: + napari_server.display_layer_batch( + layer_key=self.layer_key, + items=self.items, + display_payload=self.display_payload, + component_names_metadata=self.component_names_metadata, + ) class NapariBatchProcessor: """ - Batch processor for Napari viewer with configurable batching strategies. - - Accumulates items and displays them based on batch_size configuration: - - None: Wait for all items in operation, then display once - - N: Display every N items incrementally - - Uses debouncing to collect items arriving in rapid succession. + Batch processor for Napari viewer display operations. + + Napari layer mutation must run on the Qt event-loop thread. OpenHCS owns that + Qt-thread debounce before this processor is called, so this class only + adapts batch payloads into the server display operation. """ def __init__( @@ -29,22 +54,15 @@ def __init__( Args: napari_server: Reference to NapariViewerServer for display operations - batch_size: Number of items to batch before displaying - None = wait for all (default), N = display every N items - debounce_delay_ms: Wait time after last item before processing (ms) - max_debounce_wait_ms: Maximum total wait time before forcing display (ms) + batch_size: Reserved for compatibility with viewer configuration + debounce_delay_ms: Qt-thread debounce delay owned by the caller + max_debounce_wait_ms: Reserved for compatibility with viewer configuration """ self.napari_server = napari_server self.batch_size = batch_size self.debounce_delay_ms = debounce_delay_ms self.max_debounce_wait_ms = max_debounce_wait_ms - self._engine = DebouncedBatchEngine( - process_fn=self._process_batch, - debounce_delay_ms=debounce_delay_ms, - max_debounce_wait_ms=max_debounce_wait_ms, - ) - logger.info( f"NapariBatchProcessor: Created with batch_size={batch_size}, " f"debounce={debounce_delay_ms}ms, max_wait={max_debounce_wait_ms}ms" @@ -53,27 +71,25 @@ def __init__( def add_items( self, layer_key: str, - items: List[Dict[str, Any]], - display_config: Dict[str, Any], - component_names_metadata: Dict[str, Any], + items: Sequence[NapariBatchItemT], + display_payload: NapariDisplayPayloadT, + component_names_metadata: NapariComponentNamesMetadataT, ): """ - Add items to the batch for processing. + Display items already released by the Qt-thread debounce. Args: layer_key: Unique identifier for the layer items: List of items to add (images or ROIs) - display_config: Display configuration dict + display_payload: Viewer-owned display payload object component_names_metadata: Component name mappings for dimension labels """ - self._engine.enqueue( + NapariBatchDisplayRequest( + layer_key=layer_key, items=items, - context={ - "display_config": display_config, - "component_names_metadata": component_names_metadata, - "layer_key": layer_key, - }, - ) + display_payload=display_payload, + component_names_metadata=component_names_metadata, + ).dispatch_to(self.napari_server) logger.debug( "NapariBatchProcessor: Added %d items to batch for layer '%s'", len(items), @@ -81,14 +97,4 @@ def add_items( ) def flush(self) -> None: - """Force immediate processing of the pending batch.""" - self._engine.flush() - - def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> None: - """Process callback used by shared debounced batch engine.""" - self.napari_server._display_layer_batch( - layer_key=context["layer_key"], - items=items, - display_config=context["display_config"], - component_names_metadata=context["component_names_metadata"], - ) + """Compatibility no-op; OpenHCS owns the Qt-thread debounce timer.""" diff --git a/src/polystore/streaming/viewer_transport.py b/src/polystore/streaming/viewer_transport.py new file mode 100644 index 0000000..5d2207d --- /dev/null +++ b/src/polystore/streaming/viewer_transport.py @@ -0,0 +1,332 @@ +"""Nominal transport helpers for blocking viewer stream backends.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import ( + ClassVar, + TypeAlias, +) + +from polystore.registry import AutoRegisterMeta +from polystore.streaming.identity import StreamProducerIdentity +from zmqruntime.config import ZMQConfig +from zmqruntime.viewer_protocol import ( + ViewerAckPolicy, + ViewerTransportEndpoint, + ViewerTransportMode, + ViewerWireMapping, + ViewerWireValue, +) + + +DisplayComponentToken: TypeAlias = str | Enum +DisplayModeToken: TypeAlias = str | Enum | None +ViewerComponentMetadataByPath: TypeAlias = ( + Mapping[str, ViewerWireMapping] + | Sequence[ViewerWireMapping] + | None +) + + +class ViewerDisplayConfigABC(ABC): + """Display-config surface required by viewer streaming backends.""" + + COMPONENT_ORDER: Sequence[DisplayComponentToken] + + @abstractmethod + def component_modes(self) -> Mapping[DisplayComponentToken, DisplayModeToken]: + """Return mode assignments by display component.""" + + +class ViewerFilenameParserABC(ABC): + """Filename parser surface needed by viewer streaming metadata.""" + + @abstractmethod + def parse_filename(self, filename: str) -> ViewerWireMapping | None: + """Return component metadata parsed from a filename.""" + + +class ViewerMetadataHandlerABC(ABC): + """Metadata-handler surface needed by viewer component labels.""" + + @abstractmethod + def get_component_values( + self, + plate_path: str | Path | None, + component_name: str, + ) -> ViewerWireValue: + """Return display-name metadata for one component.""" + + +class ViewerMicroscopeHandlerABC(ABC): + """Microscope-handler surface used by viewer streaming.""" + + parser: ViewerFilenameParserABC + metadata_handler: ViewerMetadataHandlerABC + + +class ViewerTransportConfigSelection(ABC, metaclass=AutoRegisterMeta): + """Nominal selection of the transport config used for one stream request.""" + + __registry_key__ = "registry_key" + registry_key: ClassVar[str | None] = None + + @classmethod + def select(cls, value) -> "ViewerTransportConfigSelection": + if isinstance(value, cls): + return value + for selection_type in cls.__registry__.values(): + if selection_type.accepts(value): + return selection_type.from_raw(value) + raise TypeError( + "transport_config must be a ZMQConfig, " + "ViewerTransportConfigSelection, or None." + ) + + @classmethod + @abstractmethod + def accepts(cls, value) -> bool: + """Return whether this registered selection can adapt the raw value.""" + + @classmethod + @abstractmethod + def from_raw(cls, value) -> "ViewerTransportConfigSelection": + """Adapt the raw value into a concrete transport-config selection.""" + + @abstractmethod + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + """Return the concrete config for this request.""" + + +@dataclass(frozen=True) +class DefaultViewerTransportConfig(ViewerTransportConfigSelection): + """Use the backend's configured transport settings.""" + + registry_key: ClassVar[str] = "default" + + @classmethod + def accepts(cls, value) -> bool: + return value is None + + @classmethod + def from_raw(cls, value) -> "DefaultViewerTransportConfig": + return cls() + + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + return default_transport_config + + +@dataclass(frozen=True) +class ExplicitViewerTransportConfig(ViewerTransportConfigSelection): + """Use a caller-supplied transport config for this request.""" + + registry_key: ClassVar[str] = "explicit" + + config: ZMQConfig + + @classmethod + def accepts(cls, value) -> bool: + return isinstance(value, ZMQConfig) + + @classmethod + def from_raw(cls, value) -> "ExplicitViewerTransportConfig": + return cls(value) + + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + return self.config + + +@dataclass(frozen=True) +class ViewerTransportDefaults: + """Declared transport defaults shared by viewer streaming backends.""" + + ack_timeout_ms: int = 30_000 + + def ack_policy(self, viewer_name: str) -> ViewerAckPolicy: + return ViewerAckPolicy( + viewer_name=viewer_name, + timeout_ms=self.ack_timeout_ms, + ) + + +@dataclass(frozen=True) +class ViewerStreamSourceMetadata: + """Component metadata authority for streamed source items.""" + + component_metadata: ViewerWireMapping | None = None + component_metadata_by_path: ViewerComponentMetadataByPath = None + + def __post_init__(self) -> None: + if ( + self.component_metadata is not None + and self.component_metadata_by_path is not None + ): + raise ValueError( + "Viewer stream source context accepts either component_metadata " + "or component_metadata_by_path, not both." + ) + + def component_metadata_for_item( + self, + file_path: str | Path, + index: int, + ) -> dict[str, ViewerWireValue]: + """Return explicit component metadata for one batch item.""" + if self.component_metadata_by_path is None: + return self._batch_component_metadata(file_path) + + if isinstance(self.component_metadata_by_path, Mapping): + return self._mapping_component_metadata( + file_path, + self.component_metadata_by_path, + ) + + if index < len(self.component_metadata_by_path): + return self.component_metadata_by_path[index] + + raise IndexError( + "Viewer stream component_metadata_by_path has no entry for " + f"item {index} at {file_path!r}." + ) + + def _batch_component_metadata( + self, + file_path: str | Path, + ) -> dict[str, ViewerWireValue]: + if self.component_metadata is None: + raise ValueError( + "Viewer stream item requires explicit component_metadata or " + f"component_metadata_by_path; got no metadata for {file_path!r}." + ) + return self._metadata_payload( + self.component_metadata, + f"batch metadata for {file_path!r}", + ) + + def _mapping_component_metadata( + self, + file_path: str | Path, + metadata_by_path: Mapping[str, ViewerWireMapping], + ) -> dict[str, ViewerWireValue]: + path = Path(file_path) + for key in (str(file_path), path.as_posix(), path.name): + if key in metadata_by_path: + return self._metadata_payload( + metadata_by_path[key], + f"path metadata for {file_path!r}", + ) + raise KeyError( + "Viewer stream component_metadata_by_path has no entry for " + f"{file_path!r}." + ) + + @staticmethod + def _metadata_payload( + value: ViewerWireMapping, + source_label: str, + ) -> dict[str, ViewerWireValue]: + if not isinstance(value, Mapping): + raise TypeError( + "Viewer stream component metadata must be a mapping " + f"for {source_label}; got {type(value).__name__}." + ) + return dict(value) + + +@dataclass(frozen=True) +class ViewerStreamSourceIdentity: + """Stable source identity shared by all stream batches for one plate.""" + + microscope_handler: ViewerMicroscopeHandlerABC + plate_path: str | Path | None = None + + +class ViewerStreamKwarg(str, Enum): + """Raw kwarg names accepted at the top-level viewer stream boundary.""" + + STREAM_REQUEST = "stream_request" + + +@dataclass(frozen=True) +class ViewerStreamSource: + """Source provenance and metadata authority for one viewer stream.""" + + identity: ViewerStreamSourceIdentity + metadata: ViewerStreamSourceMetadata = field( + default_factory=ViewerStreamSourceMetadata + ) + + +@dataclass(frozen=True) +class ViewerStreamRequest: + """Typed view of backend kwargs at the viewer streaming boundary.""" + + viewer_transport: ViewerTransportEndpoint + display_config: ViewerDisplayConfigABC + source: ViewerStreamSource + producer_identity: StreamProducerIdentity + transport_config: ViewerTransportConfigSelection = DefaultViewerTransportConfig() + message_extra: ViewerWireMapping | None = None + images_dir: str | None = None + + @property + def host(self) -> str: + return self.viewer_transport.host + + @property + def port(self) -> int: + return self.viewer_transport.port + + @property + def transport_mode(self) -> ViewerTransportMode: + return self.viewer_transport.transport_mode + + def message_extra_payload(self) -> dict[str, ViewerWireValue]: + return ViewerMessageExtraAuthority.payload(self.message_extra) + + +ViewerStreamKwargPayloadMapping: TypeAlias = Mapping[ + str, + "ViewerStreamRequest", +] + + +@dataclass(frozen=True) +class ViewerStreamBackendKwargs: + """The only accepted FileManager kwarg payload for viewer stream backends.""" + + stream_request: ViewerStreamRequest + + @classmethod + def from_kwargs( + cls, + kwargs: ViewerStreamKwargPayloadMapping, + ) -> "ViewerStreamBackendKwargs": + expected = frozenset((ViewerStreamKwarg.STREAM_REQUEST.value,)) + actual = frozenset(kwargs) + if actual != expected: + raise ValueError( + "Viewer stream backends require exactly one kwarg: stream_request" + ) + value = kwargs[ViewerStreamKwarg.STREAM_REQUEST.value] + if not isinstance(value, ViewerStreamRequest): + raise TypeError("stream_request must be a ViewerStreamRequest instance") + return cls(value) + + def to_kwargs(self) -> dict[str, ViewerStreamRequest]: + return {ViewerStreamKwarg.STREAM_REQUEST.value: self.stream_request} + + +class ViewerMessageExtraAuthority: + """Formal boundary for absent caller-supplied viewer message extras.""" + + @staticmethod + def payload(message_extra: Mapping[str, ViewerWireValue] | None) -> dict[str, ViewerWireValue]: + if message_extra is None: + return {} + return dict(message_extra) diff --git a/src/polystore/streaming_constants.py b/src/polystore/streaming_constants.py index d7f0596..05c834c 100644 --- a/src/polystore/streaming_constants.py +++ b/src/polystore/streaming_constants.py @@ -15,6 +15,21 @@ class StreamingDataType(Enum): POINTS = "points" # Napari points layer (e.g., skeleton tracings) ROIS = "rois" # Fiji ROI payloads + @property + def uses_napari_vector_payload(self) -> bool: + """Whether napari should receive this type through vector layer payloads.""" + return self in (type(self).SHAPES, type(self).POINTS) + + @property + def napari_layer_suffix(self) -> str: + """Layer-key suffix contributed by this data type.""" + return { + type(self).IMAGE: "", + type(self).SHAPES: "_shapes", + type(self).POINTS: "_points", + type(self).ROIS: "", + }[self] + class NapariShapeType(Enum): """Napari shape types for ROI visualization.""" diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index 45081a3..e47c657 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -3,17 +3,133 @@ import logging import json from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, List, Mapping, Optional, Set, Union from fnmatch import fnmatch from .disk import DiskStorageBackend from .metadata_writer import get_metadata_path from .exceptions import StorageResolutionError from .base import ReadOnlyBackend +from .constants import Backend +from .registry import AutoRegisterMeta logger = logging.getLogger(__name__) +class VirtualWorkspaceSourceRefResolver(ABC, metaclass=AutoRegisterMeta): + """Nominal loader family for virtual-workspace source references.""" + + __registry_key__ = "resolver_key" + __skip_if_no_key__ = True + resolver_key: ClassVar[str | None] = None + priority: ClassVar[int] + + @classmethod + def for_ref(cls, source_ref: Any) -> "VirtualWorkspaceSourceRefResolver": + for resolver_type in sorted( + cls.__registry__.values(), + key=lambda registered_type: registered_type.priority, + ): + resolver = resolver_type() + if resolver.accepts(source_ref): + return resolver + raise StorageResolutionError( + f"Unsupported virtual workspace source reference: {source_ref!r}" + ) + + @abstractmethod + def accepts(self, source_ref: Any) -> bool: + """Return whether this resolver owns the reference shape.""" + + @abstractmethod + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + """Return the concrete source path for existence and diagnostics.""" + + @abstractmethod + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + """Load the payload addressed by this source reference.""" + + +class PathSourceRefResolver(VirtualWorkspaceSourceRefResolver): + """Resolve legacy string path mappings.""" + + resolver_key = "path" + priority = 100 + + def accepts(self, source_ref: Any) -> bool: + return isinstance(source_ref, (str, Path)) + + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + path = Path(source_ref) + return path if path.is_absolute() else plate_root / path + + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + return disk_backend.load(self.source_path(plate_root, source_ref), **kwargs) + + +class DiskSourceRefResolver(VirtualWorkspaceSourceRefResolver): + """Resolve structured disk refs, including single-plane TIFF pages.""" + + resolver_key = "disk" + priority = 10 + + def accepts(self, source_ref: Any) -> bool: + return ( + isinstance(source_ref, Mapping) + and source_ref.get("backend", Backend.DISK.value) == Backend.DISK.value + and isinstance(source_ref.get("source_path"), (str, Path)) + ) + + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + path = Path(source_ref["source_path"]) + return path if path.is_absolute() else plate_root / path + + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + payload = disk_backend.load(self.source_path(plate_root, source_ref), **kwargs) + plane_index = source_ref.get("plane_index") + if plane_index is None: + return payload + return _payload_plane(payload, int(plane_index), source_ref) + + +def _payload_plane(payload: Any, plane_index: int, source_ref: Mapping[str, Any]) -> Any: + if not hasattr(payload, "ndim") or not hasattr(payload, "shape"): + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but the loaded " + f"payload has no array shape." + ) + if payload.ndim < 3: + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but loaded " + f"payload shape {payload.shape!r} is not a stack." + ) + if plane_index < 0 or plane_index >= payload.shape[0]: + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but loaded " + f"payload shape is {payload.shape!r}." + ) + return payload[plane_index] + + class VirtualWorkspaceBackend(ReadOnlyBackend): """ Read-only path translation layer for virtual workspace. @@ -53,7 +169,7 @@ def __init__(self, plate_root: Path): """ self.plate_root = Path(plate_root) self.disk_backend = DiskStorageBackend() - self._mapping_cache: Optional[Dict[str, str]] = None + self._mapping_cache: Optional[Dict[str, Any]] = None self._cache_mtime: Optional[float] = None # Load mapping eagerly - fail loud if metadata missing @@ -76,7 +192,7 @@ def _normalize_relative_path(path_str: str) -> str: normalized = path_str.replace('\\', '/') return '' if normalized == '.' else normalized - def _load_mapping(self) -> Dict[str, str]: + def _load_mapping(self) -> Dict[str, Any]: """ Load workspace_mapping from metadata with mtime-based caching. @@ -122,7 +238,7 @@ def _load_mapping(self) -> Dict[str, str]: logger.info(f"Loaded {len(combined_mapping)} mappings for {self.plate_root}") return combined_mapping - def _resolve_path(self, path: Union[str, Path]) -> str: + def _resolve_ref(self, path: Union[str, Path]) -> Any: """ Resolve virtual path to real plate path using plate-relative mapping. @@ -163,20 +279,30 @@ def _resolve_path(self, path: Union[str, Path]) -> str: f"This path must be accessed through the virtual workspace mapping." ) - real_relative = self._mapping_cache[relative_str] - real_absolute = self.plate_root / real_relative - logger.debug(f"Resolved virtual → real: {relative_str} → {real_relative}") - return str(real_absolute) + source_ref = self._mapping_cache[relative_str] + logger.debug("Resolved virtual source ref: %s -> %r", relative_str, source_ref) + return source_ref + + def _resolve_path(self, path: Union[str, Path]) -> str: + """Resolve a virtual path to the concrete source path for diagnostics.""" + source_ref = self._resolve_ref(path) + resolver = VirtualWorkspaceSourceRefResolver.for_ref(source_ref) + return str(resolver.source_path(self.plate_root, source_ref)) def load(self, file_path: Union[str, Path], **kwargs) -> Any: """Load file from virtual workspace.""" - real_path = self._resolve_path(file_path) - return self.disk_backend.load(real_path, **kwargs) + source_ref = self._resolve_ref(file_path) + resolver = VirtualWorkspaceSourceRefResolver.for_ref(source_ref) + return resolver.load( + self.disk_backend, + self.plate_root, + source_ref, + **kwargs, + ) def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: """Load multiple files from virtual workspace.""" - real_paths = [self._resolve_path(fp) for fp in file_paths] - return self.disk_backend.load_batch(real_paths, **kwargs) + return [self.load(file_path, **kwargs) for file_path in file_paths] def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, extensions: Optional[Set[str]] = None, recursive: bool = False, @@ -205,10 +331,20 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, if self._mapping_cache is None: self._load_mapping() - logger.info(f"VirtualWorkspace.list_files called: directory={directory}, recursive={recursive}, pattern={pattern}, extensions={extensions}") - logger.info(f" plate_root={self.plate_root}") - logger.info(f" relative_dir_str='{relative_dir_str}'") - logger.info(f" mapping has {len(self._mapping_cache)} entries") + logger.debug( + "VirtualWorkspace.list_files directory=%s recursive=%s pattern=%s extensions=%s", + directory, + recursive, + pattern, + extensions, + ) + logger.debug(" plate_root=%s", self.plate_root) + logger.debug(" relative_dir_str=%r", relative_dir_str) + logger.debug(" mapping has %s entries", len(self._mapping_cache)) + + lowercase_extensions = ( + None if extensions is None else {ext.lower() for ext in extensions} + ) # Filter paths in this directory results = [] @@ -230,20 +366,20 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, vpath = Path(virtual_relative) if pattern and not fnmatch(vpath.name, pattern): continue - if extensions and vpath.suffix not in extensions: + if lowercase_extensions and vpath.suffix.lower() not in lowercase_extensions: continue # Return absolute path results.append(str(self.plate_root / virtual_relative)) - logger.info(f" VirtualWorkspace.list_files returning {len(results)} files") + logger.debug(" VirtualWorkspace.list_files returning %s files", len(results)) if len(results) == 0 and len(self._mapping_cache) > 0: # Log first few mapping keys to help debug sample_keys = list(self._mapping_cache.keys())[:3] - logger.info(f" Sample mapping keys: {sample_keys}") + logger.debug(" Sample mapping keys: %s", sample_keys) if not recursive and relative_dir_str == '': sample_parents = [str(Path(k).parent).replace('\\', '/') for k in sample_keys] - logger.info(f" Sample parent dirs: {sample_parents}") + logger.debug(" Sample parent dirs: %s", sample_parents) logger.info(f" Expected parent to match: '{relative_dir_str}'") return sorted(results) diff --git a/tests/test_memory_backend.py b/tests/test_memory_backend.py index f55996b..ec8a080 100644 --- a/tests/test_memory_backend.py +++ b/tests/test_memory_backend.py @@ -109,6 +109,17 @@ def test_list_files_with_extension_filter(self): npy_files = self.backend.list_files("/test", extensions={".npy"}) assert len(npy_files) == 2 + def test_list_files_extension_filter_is_case_insensitive(self): + """Test extension filtering matches backend contract case-insensitively.""" + self.backend.save(np.array([1]), "/test/image.TIF") + self.backend.save(np.array([2]), "/test/image.tif") + self.backend.save("text", "/test/notes.TXT") + + tif_files = self.backend.list_files("/test", extensions={".tif"}) + + assert len(tif_files) == 2 + assert {path.name for path in tif_files} == {"image.TIF", "image.tif"} + def test_list_files_recursive(self): """Test recursive file listing.""" # Create files in multiple levels diff --git a/tests/test_roi.py b/tests/test_roi.py new file mode 100644 index 0000000..565022f --- /dev/null +++ b/tests/test_roi.py @@ -0,0 +1,79 @@ +import numpy as np + +from polystore.roi import MaskShape +from polystore.roi import PolygonShape +from polystore.roi import load_rois_from_json +from polystore.roi import extract_rois_from_labeled_mask + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_polygons(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=True, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert rois[0].metadata["bbox"] == (12, 23, 16, 27) + assert rois[0].metadata["centroid"] == (13.5, 24.5) + assert isinstance(rois[0].shapes[0], PolygonShape) + assert float(rois[0].shapes[0].coordinates[:, 0].min()) >= 11.5 + assert float(rois[0].shapes[0].coordinates[:, 1].min()) >= 22.5 + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_mask_bbox(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=False, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], MaskShape) + assert rois[0].shapes[0].bbox == (12, 23, 16, 27) + + +def test_extract_rois_from_labeled_mask_records_source_canvas_shape(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + source_spatial_shape_yx=(100, 200), + ) + + assert len(rois) == 1 + assert rois[0].metadata["source_spatial_shape_yx"] == (100, 200) + + +def test_load_rois_from_json_decodes_shapes_through_nominal_registry(tmp_path): + roi_path = tmp_path / "rois.json" + roi_path.write_text( + """ + [ + { + "metadata": {"label": 1}, + "shapes": [ + {"type": "polygon", "coordinates": [[1, 2], [3, 4], [5, 6]]}, + {"type": "mask", "mask": [[true, false], [false, true]], "bbox": [10, 20, 12, 22]} + ] + } + ] + """ + ) + + rois = load_rois_from_json(roi_path) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], PolygonShape) + assert isinstance(rois[0].shapes[1], MaskShape) + assert rois[0].shapes[1].bbox == (10, 20, 12, 22) diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py new file mode 100644 index 0000000..a90d013 --- /dev/null +++ b/tests/test_streaming_metadata.py @@ -0,0 +1,293 @@ +from types import SimpleNamespace + +import pytest + +from polystore.streaming._streaming_backend import StreamingBackend +from polystore.streaming._streaming_backend import StreamingBatchItemPreparationAuthority +from polystore.streaming._streaming_backend import StreamingBatchMessageBuilder +from polystore.streaming._streaming_backend import StreamingBatchMessageRequest +from polystore.streaming._streaming_backend import StreamingItemPath +from polystore.streaming._streaming_backend import StreamingItemPreparationRequest +from polystore.streaming.identity import StreamProducerIdentity +from polystore.streaming.viewer_transport import ViewerDisplayConfigABC +from polystore.streaming.viewer_transport import ViewerMicroscopeHandlerABC +from polystore.streaming.viewer_transport import ViewerStreamRequest +from polystore.streaming.viewer_transport import ViewerStreamSource +from polystore.streaming.viewer_transport import ViewerStreamSourceIdentity +from polystore.streaming.viewer_transport import ViewerStreamSourceMetadata +from zmqruntime.config import TransportMode +from zmqruntime.viewer_protocol import ViewerAckPolicy +from zmqruntime.viewer_protocol import ViewerTransportEndpoint + + +class MetadataProbeStreamingBackend(StreamingBackend): + VIEWER_TYPE = "probe" + SHM_PREFIX = "probe_" + + def _prepare_batch_item(self, request: StreamingItemPreparationRequest): + return {"path": request.item_path.wire_value, "payload": "ok"}, "image" + + def save_batch(self, data_list, file_paths, **kwargs): + raise NotImplementedError + + +class DisplayConfigStub(ViewerDisplayConfigABC): + COMPONENT_ORDER = ("well", "site", "channel") + + def component_modes(self): + return { + "well": "stack", + "site": "stack", + "channel": "stack", + } + + +PRODUCER_IDENTITY = StreamProducerIdentity( + origin="pipeline", + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", +) + + +EMPTY_SOURCE_METADATA = ViewerStreamSourceMetadata() + + +def stream_request( + microscope_handler, + source_metadata=EMPTY_SOURCE_METADATA, + *, + plate_path=None, + message_extra=None, +): + return ViewerStreamRequest( + viewer_transport=ViewerTransportEndpoint( + host="127.0.0.1", + port=5555, + transport_mode=TransportMode.TCP, + ), + display_config=DisplayConfigStub(), + source=ViewerStreamSource( + identity=ViewerStreamSourceIdentity( + microscope_handler=microscope_handler, + plate_path=plate_path, + ), + metadata=source_metadata, + ), + producer_identity=PRODUCER_IDENTITY, + message_extra=message_extra, + ) + + +def batch_message_request(data_list, file_paths, viewer_request): + return StreamingBatchMessageRequest( + data_list=data_list, + file_paths=file_paths, + stream_request=viewer_request, + ) + + +def microscope_handler_with_parser(parser): + class MicroscopeHandlerStub(ViewerMicroscopeHandlerABC): + pass + + microscope_handler = MicroscopeHandlerStub() + microscope_handler.parser = parser + microscope_handler.metadata_handler = SimpleNamespace( + get_component_values=lambda _plate_path, _component_name: None + ) + return microscope_handler + + +def test_streaming_source_metadata_rejects_missing_component_metadata() -> None: + with pytest.raises(ValueError, match="explicit component_metadata"): + ViewerStreamSourceMetadata().component_metadata_for_item( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + 0, + ) + + +def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + with pytest.raises(ValueError, match="explicit component_metadata"): + StreamingBatchItemPreparationAuthority.prepare( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + stream_request(microscope_handler), + ) + ) + + +def test_streaming_batch_items_accept_per_path_component_metadata() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + prepared_items = StreamingBatchItemPreparationAuthority.prepare( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata_by_path={ + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip": { + "well": "A01", + "site": 1, + "channel": 1, + }, + } + ), + ), + ) + ) + + assert prepared_items.batch_images[0]["metadata"] == { + "well": "A01", + "site": 1, + "channel": 1, + } + assert ( + prepared_items.batch_images[0]["producer_identity"] + == PRODUCER_IDENTITY.to_payload() + ) + + +def test_streaming_item_component_metadata_preserves_explicit_fields() -> None: + metadata = ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ).component_metadata_for_item( + StreamingItemPath("A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip").value, + 0, + ) + + assert metadata == { + "well": "A01", + "site": 1, + "channel": 1, + } + + +def test_streaming_batch_message_declares_component_value_domain() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + built_batch = StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object(), object()], + ["A01_s001_w1_z001_t001.tif", "A01_s002_w2_z001_t001.tif"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata_by_path=( + {"well": "A01", "site": 1, "channel": 1}, + {"well": "A01", "site": 2, "channel": 2}, + ), + ), + ), + ), + ) + + assert built_batch.message["component_value_domain"] == { + "well": ["A01"], + "site": [1, 2], + "channel": [1, 2], + } + + +def test_streaming_batch_message_honors_declared_component_metadata_payload() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + built_batch = StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001.tif"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ), + message_extra={ + "component_value_domain": {"well": ["A01", "B01"]}, + "component_names_metadata": {"well": {"A01": "control"}}, + }, + ), + ), + ) + + assert built_batch.message["component_value_domain"] == {"well": ["A01", "B01"]} + assert built_batch.message["component_names_metadata"] == { + "well": {"A01": "control"} + } + + +def test_streaming_batch_message_rejects_partial_declared_component_metadata_payload() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + with pytest.raises(ValueError, match="component_names_metadata"): + StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001.tif"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ), + message_extra={"component_value_domain": {"well": ["A01"]}}, + ), + ), + ) + + +def test_streaming_component_metadata_rejects_invalid_explicit_metadata() -> None: + with pytest.raises(TypeError, match="must be a mapping"): + ViewerStreamSourceMetadata( + component_metadata=["not", "metadata"], + ).component_metadata_for_item( + StreamingItemPath("A01_s001_w1_z001_t001.TIF").value, + 0, + ) + + +class ViewerAckSocketStub: + def __init__(self, response): + self.response = response + + def recv_json(self): + return self.response + + +def test_viewer_ack_policy_rejects_error_status_and_cleans_up() -> None: + cleanup_calls = [] + policy = ViewerAckPolicy(viewer_name="Napari", timeout_ms=30_000) + + with pytest.raises(RuntimeError, match="Napari rejected stream batch"): + policy.receive( + ViewerAckSocketStub( + {"status": "error", "message": "missing component_value_domain"} + ), + lambda: cleanup_calls.append("cleanup"), + port=5555, + ) + + assert cleanup_calls == ["cleanup"] diff --git a/tests/test_streaming_receiver_core.py b/tests/test_streaming_receiver_core.py index 6f7cd63..eb7a76b 100644 --- a/tests/test_streaming_receiver_core.py +++ b/tests/test_streaming_receiver_core.py @@ -4,75 +4,230 @@ import time from polystore.streaming_constants import StreamingDataType +from polystore.streaming.identity import ( + FixedStreamProducerIdentityKind, + StreamProducerDisplayNameAuthority, + StreamProducerIdentity, +) from polystore.streaming.receivers.core import ( DebouncedBatchEngine, + WindowProjectionSource, group_items_by_component_modes, ) from polystore.streaming.receivers.napari import ( normalize_component_layout, - build_layer_key, + build_route_key, ) +from zmqruntime.viewer_protocol import ViewerBatchDisplayPayload + +class PipelineProducerFixture: + """Nominal producer fixtures for receiver-core tests.""" + + MAIN_KIND = "main" + MAIN_KEY = "main" + ARTIFACT_KIND = "artifact" + + @classmethod + def main_output( + cls, + *, + step_name: str, + pipeline_position: int, + ) -> StreamProducerIdentity: + return StreamProducerIdentity.pipeline_output( + output_kind=cls.MAIN_KIND, + output_key=cls.MAIN_KEY, + step_name=step_name, + pipeline_position=pipeline_position, + ) + @classmethod + def artifact_output( + cls, + *, + output_key: str, + step_name: str, + pipeline_position: int, + artifact_kind: str | None = None, + ) -> StreamProducerIdentity: + return StreamProducerIdentity.pipeline_output( + output_kind=cls.ARTIFACT_KIND, + output_key=output_key, + step_name=step_name, + pipeline_position=pipeline_position, + artifact_kind=artifact_kind, + ) -def test_group_items_by_component_modes_source_normalization_for_rois() -> None: + +def test_group_items_by_component_modes_keys_windows_by_producer_identity() -> None: + image_identity = PipelineProducerFixture.main_output( + step_name="RawLoad", + pipeline_position=1, + ) + roi_identity = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + artifact_kind="object_labels", + ) items = [ { "data_type": "rois", - "metadata": {"source": "/tmp/foo_results", "well": "A01", "channel": 1}, + "metadata": {"well": "A01", "channel": 1}, + "producer_identity": roi_identity.to_payload(), }, { "data_type": "image", - "metadata": {"source": "step_1", "well": "A01", "channel": 1}, + "metadata": {"well": "A01", "channel": 1}, + "producer_identity": image_identity.to_payload(), }, ] - component_modes = {"source": "window", "well": "frame", "channel": "channel"} - component_order = ["source", "well", "channel"] + component_modes = {"well": "frame", "channel": "channel"} + component_order = ["well", "channel"] grouped = group_items_by_component_modes( - items, - component_modes=component_modes, - component_order=component_order, - images_dir="/my/plate/images", + WindowProjectionSource.from_wire_payloads(items), + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), ) - assert grouped.window_components == ["source"] + assert grouped.window_components == [] assert grouped.channel_components == ["channel"] assert grouped.frame_components == ["well"] - assert "source_images" in grouped.windows - assert "source_step_1" in grouped.windows + assert grouped.slice_components == [] + assert grouped.fixed_window_labels[ + "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels" + ] == (("producer", "3. Segment Nuclei"),) + assert set(grouped.windows) == { + "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels", + "origin_pipeline_kind_main_out_main_step_1_name_RawLoad", + } + + +def test_group_items_by_component_modes_rejects_missing_metadata() -> None: + producer = PipelineProducerFixture.main_output( + step_name="RawLoad", + pipeline_position=1, + ) + + try: + group_items_by_component_modes( + WindowProjectionSource.from_wire_payloads( + [{"producer_identity": producer.to_payload()}] + ), + display_layout=ViewerBatchDisplayPayload( + component_modes={"well": "window"}, + component_order=["well"], + ), + ) + except ValueError as error: + assert "metadata" in str(error) + else: + raise AssertionError("missing metadata must fail loudly") + + +def test_stream_producer_display_name_authority_matches_pipeline_editor_indexing() -> None: + main_output = PipelineProducerFixture.main_output( + step_name="ConvertObjectsToImage", + pipeline_position=8, + ) + artifact_output = PipelineProducerFixture.artifact_output( + output_key="NucleiObjects3D", + step_name="ConvertObjectsToImage", + pipeline_position=8, + artifact_kind="object_labels", + ) + manual_output = StreamProducerIdentity.fixed_output( + FixedStreamProducerIdentityKind.MANUAL, + "selected_rois", + ) + assert ( + StreamProducerDisplayNameAuthority.producer_label(main_output) + == "9. ConvertObjectsToImage" + ) + assert ( + StreamProducerDisplayNameAuthority.output_label(main_output) + == "9. ConvertObjectsToImage" + ) + assert ( + StreamProducerDisplayNameAuthority.output_label(artifact_output) + == "9. ConvertObjectsToImage NucleiObjects3D" + ) + assert StreamProducerDisplayNameAuthority.output_label(manual_output) == "selected_rois" + assert ( + StreamProducerDisplayNameAuthority.disambiguation_label(main_output) + == "step 9" + ) -def test_napari_layer_key_builder_uses_slice_components_and_payload_type() -> None: + +def test_napari_route_key_builder_uses_producer_slice_components_and_payload_type() -> None: + producer = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + ) component_modes = {"well": "slice", "channel": "stack", "site": "slice"} component_order = ["well", "channel", "site"] component_info = {"well": "A01", "channel": 2, "site": 3} - key_image = build_layer_key( + key_image = build_route_key( + producer_identity=producer, component_info=component_info, - component_modes=component_modes, - component_order=component_order, + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), data_type=StreamingDataType.IMAGE, ) - key_shapes = build_layer_key( + key_shapes = build_route_key( + producer_identity=producer, component_info=component_info, - component_modes=component_modes, - component_order=component_order, + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), data_type=StreamingDataType.SHAPES, ) - assert key_image == "well_A01_site_3" - assert key_shapes == "well_A01_site_3_shapes" + assert key_image == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3" + assert key_shapes == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3_shapes" + + +def test_napari_route_key_builder_rejects_missing_slice_component() -> None: + producer = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + ) + + try: + build_route_key( + producer_identity=producer, + component_info={"well": "A01"}, + display_layout=ViewerBatchDisplayPayload( + component_modes={"well": "slice", "site": "slice"}, + component_order=["well", "site"], + ), + data_type=StreamingDataType.IMAGE, + ) + except ValueError as error: + assert "site" in str(error) + else: + raise AssertionError("missing slice component must fail loudly") def test_normalize_component_layout_dict_config() -> None: - component_modes, component_order = normalize_component_layout( + display_layout = normalize_component_layout( { "component_modes": {"well": "slice", "channel": "stack"}, "component_order": ["well", "channel"], } ) - assert component_order == ["well", "channel"] - assert component_modes["well"] == "slice" + assert list(display_layout.component_order) == ["well", "channel"] + assert display_layout.component_modes["well"] == "slice" def test_debounced_batch_engine_flush_processes_pending_once() -> None: diff --git a/tests/test_viewer_transport.py b/tests/test_viewer_transport.py new file mode 100644 index 0000000..753de6f --- /dev/null +++ b/tests/test_viewer_transport.py @@ -0,0 +1,151 @@ +import pytest + +from polystore.streaming.identity import StreamProducerIdentity +from polystore.streaming.viewer_transport import ExplicitViewerTransportConfig +from polystore.streaming.viewer_transport import ViewerStreamBackendKwargs +from polystore.streaming.viewer_transport import ViewerStreamKwarg +from polystore.streaming.viewer_transport import ViewerDisplayConfigABC +from polystore.streaming.viewer_transport import ViewerFilenameParserABC +from polystore.streaming.viewer_transport import ViewerMetadataHandlerABC +from polystore.streaming.viewer_transport import ViewerMicroscopeHandlerABC +from polystore.streaming.viewer_transport import ViewerStreamRequest +from polystore.streaming.viewer_transport import ViewerStreamSource +from polystore.streaming.viewer_transport import ViewerStreamSourceIdentity +from polystore.streaming.viewer_transport import ViewerStreamSourceMetadata +from zmqruntime.config import TransportMode, ZMQConfig +from zmqruntime.viewer_protocol import ViewerTransportEndpoint + + +class DisplayConfigFixture(ViewerDisplayConfigABC): + COMPONENT_ORDER = ("well", "site", "channel") + + def component_modes(self): + return { + "well": "stack", + "site": "slice", + "channel": "channel", + } + + +class FilenameParserFixture(ViewerFilenameParserABC): + def parse_filename(self, filename): + return {"filename": filename} + + +class MetadataHandlerFixture(ViewerMetadataHandlerABC): + def get_component_values(self, plate_path, component_name): + return f"{plate_path}:{component_name}" + + +class MicroscopeHandlerFixture(ViewerMicroscopeHandlerABC): + parser = FilenameParserFixture() + metadata_handler = MetadataHandlerFixture() + + +EMPTY_SOURCE_METADATA = ViewerStreamSourceMetadata() + + +def stream_source( + source_metadata=EMPTY_SOURCE_METADATA, + *, + plate_path="/tmp/plate", +): + return ViewerStreamSource( + identity=ViewerStreamSourceIdentity( + microscope_handler=MicroscopeHandlerFixture(), + plate_path=plate_path, + ), + metadata=source_metadata, + ) + + +def required_stream_request(**kwargs): + values = { + "viewer_transport": ViewerTransportEndpoint( + host="127.0.0.1", + port=5555, + transport_mode=TransportMode.TCP, + ), + "display_config": DisplayConfigFixture(), + "source": stream_source(), + "producer_identity": StreamProducerIdentity.pipeline_output( + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", + pipeline_position=2, + ), + } + values.update(kwargs) + return ViewerStreamRequest(**values) + + +def test_viewer_stream_kwargs_declares_explicit_backend_request() -> None: + stream_kwargs = required_stream_request( + source=stream_source( + ViewerStreamSourceMetadata( + component_metadata_by_path=( + {"well": "A01", "site": 1}, + {"well": "A01", "site": 2}, + ), + ), + plate_path="/tmp/plate", + ), + message_extra={"component_value_domain": {"well": ["A01"]}}, + images_dir="/tmp/images", + ) + + backend_kwargs = ViewerStreamBackendKwargs(stream_kwargs).to_kwargs() + + assert backend_kwargs == {ViewerStreamKwarg.STREAM_REQUEST.value: stream_kwargs} + assert ViewerStreamBackendKwargs.from_kwargs(backend_kwargs).stream_request is stream_kwargs + assert stream_kwargs.host == "127.0.0.1" + assert stream_kwargs.port == 5555 + assert stream_kwargs.transport_mode is TransportMode.TCP + assert stream_kwargs.producer_identity == StreamProducerIdentity.pipeline_output( + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", + pipeline_position=2, + ) + assert stream_kwargs.source.metadata.component_metadata is None + assert stream_kwargs.source.metadata.component_metadata_by_path == ( + {"well": "A01", "site": 1}, + {"well": "A01", "site": 2}, + ) + default_config = ZMQConfig(default_port=9001) + assert stream_kwargs.transport_config.resolve(default_config) is default_config + + +def test_viewer_stream_kwargs_rejects_mixed_source_metadata() -> None: + with pytest.raises(ValueError, match="either component_metadata"): + ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1}, + component_metadata_by_path=( + {"well": "A01", "site": 1}, + ), + ) + + +def test_viewer_stream_backend_rejects_flat_kwargs() -> None: + with pytest.raises(ValueError, match="stream_request"): + ViewerStreamBackendKwargs.from_kwargs( + {"display_config": DisplayConfigFixture()} + ) + + +def test_viewer_stream_kwargs_preserves_explicit_transport_config() -> None: + explicit_config = ZMQConfig(shared_ack_port=8111) + default_config = ZMQConfig(shared_ack_port=8222) + + stream_kwargs = required_stream_request( + transport_config=ExplicitViewerTransportConfig(explicit_config) + ) + + assert stream_kwargs.transport_config.resolve(default_config) is explicit_config + + +def test_viewer_stream_backend_rejects_non_request_payload() -> None: + with pytest.raises(TypeError, match="ViewerStreamRequest"): + ViewerStreamBackendKwargs.from_kwargs( + {ViewerStreamKwarg.STREAM_REQUEST.value: DisplayConfigFixture()} + )