From f2c41d1fcf83466c3b2fcda061f8934296654f2e Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Tue, 28 Apr 2026 20:32:56 -0400 Subject: [PATCH 01/11] Support standard raster image formats --- pyproject.toml | 3 ++- src/polystore/disk.py | 15 +++++++++++++-- src/polystore/formats.py | 11 ++++++++++- src/polystore/virtual_workspace.py | 6 +++++- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1dd9dfc..08c8602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", + "imageio>=2.37.0", "zarr>=2.18.0,<3.0", # Required for ZarrStorageBackend "ome-zarr>=0.11.0", # Required for OME-ZARR HCS compliance ] @@ -197,4 +198,4 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"__init__.py" = ["F401"] # unused imports \ No newline at end of file +"__init__.py" = ["F401"] # unused imports diff --git a/src/polystore/disk.py b/src/polystore/disk.py index 40c33d9..fe82b86 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -9,6 +9,7 @@ import logging import os import shutil +import importlib from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -23,7 +24,7 @@ def optional_import(module_name): try: - return __import__(module_name) + return importlib.import_module(module_name) except ImportError: return None @@ -44,6 +45,7 @@ def optional_import(module_name): cupy = get_cupy() tf = get_tf() tifffile = optional_import("tifffile") +imageio = optional_import("imageio.v3") # Optional arraybridge integration for memory conversion try: @@ -99,6 +101,7 @@ def _register_formats(self): # Complex formats - use custom handlers (FileFormat.TIFF, tifffile, self._tiff_writer, self._tiff_reader), + (FileFormat.RASTER_IMAGE, imageio, self._image_writer, self._image_reader), (FileFormat.TEXT, True, self._text_writer, self._text_reader), (FileFormat.JSON, True, self._json_writer, self._json_reader), (FileFormat.CSV, True, self._csv_writer, self._csv_reader), @@ -164,6 +167,14 @@ def _tiff_reader(self, path): else: return tifffile.imread(str(path)) + def _image_writer(self, path, data, **kwargs): + """Write standard raster images using imageio.""" + imageio.imwrite(path, np.asarray(data)) + + def _image_reader(self, path): + """Read standard raster images using imageio.""" + return imageio.imread(path) + def _text_writer(self, path, data, **kwargs): """Write text data to file. Accepts and ignores extra kwargs for compatibility.""" path.write_text(str(data)) @@ -261,7 +272,7 @@ def load(self, file_path: Union[str, Path], **kwargs) -> Any: ext = disk_path.suffix.lower() if not self.format_registry.is_registered(ext): - raise ValueError(f"No writer registered for extension '{ext}'") + raise ValueError(f"No reader registered for extension '{ext}'") try: reader = self.format_registry.get_reader(ext) diff --git a/src/polystore/formats.py b/src/polystore/formats.py index ddfb9a5..3643361 100644 --- a/src/polystore/formats.py +++ b/src/polystore/formats.py @@ -20,6 +20,7 @@ class FileFormat(Enum): # Image formats TIFF = "tiff" + RASTER_IMAGE = "raster_image" # Data formats CSV = "csv" @@ -44,6 +45,7 @@ def extensions(self): FileFormat.TENSORFLOW: [".tf"], FileFormat.ZARR: [".zarr"], FileFormat.TIFF: [".tif", ".tiff"], + FileFormat.RASTER_IMAGE: [".bmp", ".gif", ".jpeg", ".jpg", ".png"], FileFormat.CSV: [".csv"], FileFormat.JSON: [".json"], FileFormat.TEXT: [".txt"], @@ -51,7 +53,14 @@ def extensions(self): } # Default image extensions -DEFAULT_IMAGE_EXTENSIONS = {".tif", ".tiff", ".TIF", ".TIFF"} +DEFAULT_IMAGE_EXTENSIONS = { + extension + for extensions in ( + FILE_FORMAT_EXTENSIONS[FileFormat.TIFF], + FILE_FORMAT_EXTENSIONS[FileFormat.RASTER_IMAGE], + ) + for extension in extensions +} def get_format_from_extension(ext: str) -> FileFormat: diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index 45081a3..c7bc61b 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -210,6 +210,10 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, logger.info(f" relative_dir_str='{relative_dir_str}'") logger.info(f" mapping has {len(self._mapping_cache)} entries") + lowercase_extensions = ( + None if extensions is None else {ext.lower() for ext in extensions} + ) + # Filter paths in this directory results = [] for virtual_relative in self._mapping_cache.keys(): @@ -230,7 +234,7 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, vpath = Path(virtual_relative) if pattern and not fnmatch(vpath.name, pattern): continue - if extensions and vpath.suffix not in extensions: + if lowercase_extensions and vpath.suffix.lower() not in lowercase_extensions: continue # Return absolute path From 637107ed87707a0577c96f7cf755272a88170d1a Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Wed, 29 Apr 2026 05:49:00 -0400 Subject: [PATCH 02/11] Make memory extension filtering case-insensitive --- src/polystore/memory.py | 8 +++++++- tests/test_memory_backend.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/polystore/memory.py b/src/polystore/memory.py index a59114f..5f3f6df 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -139,6 +139,9 @@ def list_files( if self._memory_store[dir_key] is not None: raise NotADirectoryError(f"Path is not a directory: {directory}") + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) result = [] dir_prefix = dir_key + "/" if not dir_key.endswith("/") else dir_key @@ -159,7 +162,10 @@ def list_files( filename = Path(rel_path).name # If pattern is None, match all files if pattern is None or fnmatch(filename, pattern): - if not extensions or Path(filename).suffix in extensions: + if ( + lowercase_extensions is None + or Path(filename).suffix.lower() in lowercase_extensions + ): # Calculate depth for breadth-first sorting depth = rel_path.count('/') result.append((Path(path), depth)) diff --git a/tests/test_memory_backend.py b/tests/test_memory_backend.py index f55996b..ec8a080 100644 --- a/tests/test_memory_backend.py +++ b/tests/test_memory_backend.py @@ -109,6 +109,17 @@ def test_list_files_with_extension_filter(self): npy_files = self.backend.list_files("/test", extensions={".npy"}) assert len(npy_files) == 2 + def test_list_files_extension_filter_is_case_insensitive(self): + """Test extension filtering matches backend contract case-insensitively.""" + self.backend.save(np.array([1]), "/test/image.TIF") + self.backend.save(np.array([2]), "/test/image.tif") + self.backend.save("text", "/test/notes.TXT") + + tif_files = self.backend.list_files("/test", extensions={".tif"}) + + assert len(tif_files) == 2 + assert {path.name for path in tif_files} == {"image.TIF", "image.tif"} + def test_list_files_recursive(self): """Test recursive file listing.""" # Create files in multiple levels From 3365356d20eed6af0ff6ddde9cccd9bf02d595e6 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Mon, 4 May 2026 17:03:42 -0400 Subject: [PATCH 03/11] Reduce VFS debug log noise --- src/polystore/base.py | 13 +++++++------ src/polystore/virtual_workspace.py | 20 +++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/polystore/base.py b/src/polystore/base.py index 2b033fc..e18849e 100644 --- a/src/polystore/base.py +++ b/src/polystore/base.py @@ -546,15 +546,16 @@ def reset_memory_backend() -> None: # Clear files from existing memory backend while preserving directories memory_backend = storage_registry[Backend.MEMORY.value] - # DEBUG: Log what's in memory before clearing existing_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(existing_keys)} entries BEFORE clear") - logger.info(f"🔍 VFS_CLEAR: First 10 keys: {existing_keys[:10]}") + logger.debug("Memory backend has %s entries before clear", len(existing_keys)) + logger.debug("First memory backend keys before clear: %s", existing_keys[:10]) memory_backend.clear_files_only() - # DEBUG: Log what's in memory after clearing remaining_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(remaining_keys)} entries AFTER clear (directories only)") - logger.info(f"🔍 VFS_CLEAR: First 10 remaining keys: {remaining_keys[:10]}") + logger.debug( + "Memory backend has %s entries after clear (directories only)", + len(remaining_keys), + ) + logger.debug("First memory backend keys after clear: %s", remaining_keys[:10]) logger.info("Memory backend reset - files cleared, directories preserved") diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index c7bc61b..bec8be5 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -205,10 +205,16 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, if self._mapping_cache is None: self._load_mapping() - logger.info(f"VirtualWorkspace.list_files called: directory={directory}, recursive={recursive}, pattern={pattern}, extensions={extensions}") - logger.info(f" plate_root={self.plate_root}") - logger.info(f" relative_dir_str='{relative_dir_str}'") - logger.info(f" mapping has {len(self._mapping_cache)} entries") + logger.debug( + "VirtualWorkspace.list_files directory=%s recursive=%s pattern=%s extensions=%s", + directory, + recursive, + pattern, + extensions, + ) + logger.debug(" plate_root=%s", self.plate_root) + logger.debug(" relative_dir_str=%r", relative_dir_str) + logger.debug(" mapping has %s entries", len(self._mapping_cache)) lowercase_extensions = ( None if extensions is None else {ext.lower() for ext in extensions} @@ -240,14 +246,14 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, # Return absolute path results.append(str(self.plate_root / virtual_relative)) - logger.info(f" VirtualWorkspace.list_files returning {len(results)} files") + logger.debug(" VirtualWorkspace.list_files returning %s files", len(results)) if len(results) == 0 and len(self._mapping_cache) > 0: # Log first few mapping keys to help debug sample_keys = list(self._mapping_cache.keys())[:3] - logger.info(f" Sample mapping keys: {sample_keys}") + logger.debug(" Sample mapping keys: %s", sample_keys) if not recursive and relative_dir_str == '': sample_parents = [str(Path(k).parent).replace('\\', '/') for k in sample_keys] - logger.info(f" Sample parent dirs: {sample_parents}") + logger.debug(" Sample parent dirs: %s", sample_parents) logger.info(f" Expected parent to match: '{relative_dir_str}'") return sorted(results) From 000ab4ea425ec2f8367d4f1c3d35a69450848b71 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Wed, 20 May 2026 17:00:26 -0400 Subject: [PATCH 04/11] Move napari batching to Qt event loop --- .../napari/napari_batch_processor.py | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index b8dcbdd..ad485af 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,20 +1,16 @@ import logging from typing import Any, Dict, List, Optional -from polystore.streaming.receivers.core import DebouncedBatchEngine - logger = logging.getLogger(__name__) class NapariBatchProcessor: """ - Batch processor for Napari viewer with configurable batching strategies. - - Accumulates items and displays them based on batch_size configuration: - - None: Wait for all items in operation, then display once - - N: Display every N items incrementally - - Uses debouncing to collect items arriving in rapid succession. + Batch processor for Napari viewer display operations. + + Napari layer mutation must run on the Qt event-loop thread. OpenHCS owns that + Qt-thread debounce before this processor is called, so this class only + adapts batch payloads into the server display operation. """ def __init__( @@ -29,22 +25,15 @@ def __init__( Args: napari_server: Reference to NapariViewerServer for display operations - batch_size: Number of items to batch before displaying - None = wait for all (default), N = display every N items - debounce_delay_ms: Wait time after last item before processing (ms) - max_debounce_wait_ms: Maximum total wait time before forcing display (ms) + batch_size: Reserved for compatibility with viewer configuration + debounce_delay_ms: Qt-thread debounce delay owned by the caller + max_debounce_wait_ms: Reserved for compatibility with viewer configuration """ self.napari_server = napari_server self.batch_size = batch_size self.debounce_delay_ms = debounce_delay_ms self.max_debounce_wait_ms = max_debounce_wait_ms - self._engine = DebouncedBatchEngine( - process_fn=self._process_batch, - debounce_delay_ms=debounce_delay_ms, - max_debounce_wait_ms=max_debounce_wait_ms, - ) - logger.info( f"NapariBatchProcessor: Created with batch_size={batch_size}, " f"debounce={debounce_delay_ms}ms, max_wait={max_debounce_wait_ms}ms" @@ -58,7 +47,7 @@ def add_items( component_names_metadata: Dict[str, Any], ): """ - Add items to the batch for processing. + Display items already released by the Qt-thread debounce. Args: layer_key: Unique identifier for the layer @@ -66,9 +55,9 @@ def add_items( display_config: Display configuration dict component_names_metadata: Component name mappings for dimension labels """ - self._engine.enqueue( - items=items, - context={ + self._process_batch( + items, + { "display_config": display_config, "component_names_metadata": component_names_metadata, "layer_key": layer_key, @@ -81,12 +70,11 @@ def add_items( ) def flush(self) -> None: - """Force immediate processing of the pending batch.""" - self._engine.flush() + """Compatibility no-op; OpenHCS owns the Qt-thread debounce timer.""" def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> None: """Process callback used by shared debounced batch engine.""" - self.napari_server._display_layer_batch( + self.napari_server.display_layer_batch( layer_key=context["layer_key"], items=items, display_config=context["display_config"], From e5909ba51def2e703bb0a9ccb6360262972416b3 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Wed, 20 May 2026 18:22:26 -0400 Subject: [PATCH 05/11] Handle unparsed streaming artifact metadata --- src/polystore/__init__.py | 6 +- src/polystore/disk.py | 3 + src/polystore/memory.py | 3 + src/polystore/streaming/_streaming_backend.py | 34 ++++++-- tests/test_streaming_metadata.py | 78 +++++++++++++++++++ 5 files changed, 115 insertions(+), 9 deletions(-) create mode 100644 tests/test_streaming_metadata.py diff --git a/src/polystore/__init__.py b/src/polystore/__init__.py index 5c38d68..123c449 100644 --- a/src/polystore/__init__.py +++ b/src/polystore/__init__.py @@ -26,10 +26,10 @@ get_backend, ) from .constants import Backend, MemoryType, TransportMode -from .disk import DiskStorageBackend +from .disk import DiskBackend, DiskStorageBackend from .filemanager import FileManager from .formats import FileFormat, DEFAULT_IMAGE_EXTENSIONS -from .memory import MemoryStorageBackend +from .memory import MemoryBackend, MemoryStorageBackend from .metadata_writer import ( AtomicMetadataWriter, MetadataWriteError, @@ -76,7 +76,9 @@ "register_cleanup_callback", "STORAGE_BACKENDS", "DiskStorageBackend", + "DiskBackend", "MemoryStorageBackend", + "MemoryBackend", "FileManager", "file_lock", "atomic_write_json", diff --git a/src/polystore/disk.py b/src/polystore/disk.py index fe82b86..ca24e7c 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -834,3 +834,6 @@ def _save_rois(self, rois: List, output_path: Path, images_dir: str = None, **kw logger.info(f"Saved {roi_count} ROIs to .roi.zip archive: {output_path}") return str(output_path) + + +DiskBackend = DiskStorageBackend diff --git a/src/polystore/memory.py b/src/polystore/memory.py index 5f3f6df..872d581 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -657,3 +657,6 @@ def __init__(self, target: str): def __repr__(self): return f"" + + +MemoryBackend = MemoryStorageBackend diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 417baa2..932a2c4 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -9,8 +9,9 @@ import os import time import uuid +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, List, Set, Union +from typing import Any, Callable, List, Mapping, Set, Union import numpy as np from ..base import DataSink @@ -24,6 +25,27 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class StreamingComponentMetadata: + """Message metadata for one streamed item.""" + + parsed_filename_metadata: Mapping[str, Any] | None + source: str + + def to_payload(self) -> dict[str, Any]: + if self.parsed_filename_metadata is None: + metadata: dict[str, Any] = {} + elif isinstance(self.parsed_filename_metadata, Mapping): + metadata = dict(self.parsed_filename_metadata) + else: + raise TypeError( + "Streaming filename parser must return a mapping or None, " + f"got {type(self.parsed_filename_metadata).__name__}." + ) + metadata["source"] = self.source + return metadata + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -165,12 +187,10 @@ def _parse_component_metadata(self, file_path: Union[str, Path], microscope_hand Component metadata dict with source added """ filename = os.path.basename(str(file_path)) - component_metadata = microscope_handler.parser.parse_filename(filename) - - # Add pre-built source value directly - component_metadata['source'] = source - - return component_metadata + return StreamingComponentMetadata( + microscope_handler.parser.parse_filename(filename), + source, + ).to_payload() def _detect_data_type(self, data: Any): """ diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py new file mode 100644 index 0000000..8141f1d --- /dev/null +++ b/tests/test_streaming_metadata.py @@ -0,0 +1,78 @@ +from types import SimpleNamespace + +import pytest + +from polystore.streaming._streaming_backend import StreamingBackend + + +class MetadataProbeStreamingBackend(StreamingBackend): + VIEWER_TYPE = "probe" + SHM_PREFIX = "probe_" + + def save_batch(self, data_list, file_paths, **kwargs): + raise NotImplementedError + + +def test_streaming_component_metadata_accepts_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + ) + + assert metadata == {"source": "IdentifyPrimaryObjects"} + + +def test_streaming_batch_items_accept_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + batch_images, image_ids = backend._prepare_batch_items( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + microscope_handler, + "IdentifyPrimaryObjects", + lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + ) + + assert len(image_ids) == 1 + assert batch_images[0]["metadata"] == {"source": "IdentifyPrimaryObjects"} + assert batch_images[0]["payload"] == "ok" + + +def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace( + parse_filename=lambda _filename: {"well": "A01", "channel": 1} + ) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001.TIF", + microscope_handler, + source="Crop", + ) + + assert metadata == {"well": "A01", "channel": 1, "source": "Crop"} + + +def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: ["not", "metadata"]) + ) + + with pytest.raises(TypeError, match="mapping or None"): + backend._parse_component_metadata( + "A01_s001_w1_z001_t001.TIF", + microscope_handler, + source="Crop", + ) From 28f670c623e8ec1f53cd93fc10e17096c47f4a06 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Thu, 21 May 2026 00:05:51 -0400 Subject: [PATCH 06/11] Improve ROI streaming metadata handling --- pyproject.toml | 1 + src/polystore/fiji_stream.py | 5 +- src/polystore/napari_stream.py | 8 +- src/polystore/roi.py | 142 +++++++--- src/polystore/roi_converters.py | 247 ++++++++++++------ src/polystore/streaming/_streaming_backend.py | 180 ++++++------- .../streaming/receivers/napari/layer_key.py | 15 +- src/polystore/streaming_constants.py | 15 ++ tests/test_roi.py | 79 ++++++ tests/test_streaming_metadata.py | 62 +++-- 10 files changed, 501 insertions(+), 253 deletions(-) create mode 100644 tests/test_roi.py diff --git a/pyproject.toml b/pyproject.toml index 08c8602..5d6f7da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ classifiers = [ ] dependencies = [ + "arraybridge>=0.2.9", "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 4d52817..08132bc 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -31,12 +31,9 @@ class FijiStreamingBackend(StreamingBackend): """Fiji streaming backend with ZMQ publisher pattern (matches Napari architecture).""" _backend_type = Backend.FIJI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'fiji' SHM_PREFIX = 'fiji_' - # __init__, _get_publisher, save, cleanup now inherited from ABC - def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: """ Prepare ROIs data for transmission. @@ -90,6 +87,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * source = kwargs.get('source', 'unknown_source') # Pre-built source value images_dir = kwargs.get('images_dir') # Source image subdirectory for ROI mapping plate_path = kwargs.get('plate_path') + component_metadata = kwargs.get('component_metadata') logger.info(f"🏷️ FIJI BACKEND: plate_path = {plate_path}") logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {microscope_handler}") display_payload_extra = { @@ -108,6 +106,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * display_config, self._prepare_batch_item, plate_path=plate_path, + component_metadata=component_metadata, component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, display_payload_extra=display_payload_extra, message_extra=message_extra, diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index 630bcc8..d762cd6 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -20,7 +20,6 @@ import zmq from .constants import Backend, TransportMode -from .streaming_constants import StreamingDataType from .streaming import StreamingBackend from .roi_converters import NapariROIConverter from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode @@ -32,12 +31,9 @@ class NapariStreamingBackend(StreamingBackend): """Napari streaming backend with automatic registration.""" _backend_type = Backend.NAPARI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'napari' SHM_PREFIX = 'napari_' - # __init__, _get_publisher, save, cleanup now inherited from ABC - def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: """ Prepare shapes data for transmission. @@ -57,7 +53,7 @@ def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: } def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - if data_type in (StreamingDataType.SHAPES, StreamingDataType.POINTS): + if data_type.uses_napari_vector_payload: item_data = self._prepare_shapes_data(data, file_path) data_type_value = data_type.value else: @@ -88,6 +84,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * microscope_handler = kwargs['microscope_handler'] source = kwargs.get('source', 'unknown_source') # Pre-built source value plate_path = kwargs.get('plate_path') + component_metadata = kwargs.get('component_metadata') display_payload_extra = { "colormap": display_config.get_colormap_name(), "variable_size_handling": display_config.variable_size_handling.value @@ -103,6 +100,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * display_config, self._prepare_batch_item, plate_path=plate_path, + component_metadata=component_metadata, display_payload_extra=display_payload_extra, ) diff --git a/src/polystore/roi.py b/src/polystore/roi.py index fb6bdb6..d841591 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -6,12 +6,14 @@ """ import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import numpy as np +from metaclass_registry import AutoRegisterMeta from .constants import Backend @@ -27,8 +29,14 @@ class ShapeType(Enum): ELLIPSE = "ellipse" +class ROIShape(ABC): + """Nominal base for all ROI shape records.""" + + shape_type: ShapeType + + @dataclass(frozen=True) -class PolygonShape: +class PolygonShape(ROIShape): """Polygon ROI shape defined by vertex coordinates.""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYGON, init=False) @@ -41,7 +49,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PolylineShape: +class PolylineShape(ROIShape): """Polyline ROI shape defined by path coordinates (open path, not closed polygon).""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYLINE, init=False) @@ -54,7 +62,7 @@ def __post_init__(self): @dataclass(frozen=True) -class MaskShape: +class MaskShape(ROIShape): """Binary mask ROI shape.""" mask: np.ndarray # 2D boolean array bbox: Tuple[int, int, int, int] # (min_y, min_x, max_y, max_x) @@ -68,7 +76,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PointShape: +class PointShape(ROIShape): """Point ROI shape.""" y: float x: float @@ -76,7 +84,7 @@ class PointShape: @dataclass(frozen=True) -class EllipseShape: +class EllipseShape(ROIShape): """Ellipse ROI shape.""" center_y: float center_x: float @@ -95,14 +103,82 @@ def __post_init__(self): if not self.shapes: raise ValueError("ROI must have at least one shape") for shape in self.shapes: - if not hasattr(shape, "shape_type"): - raise ValueError(f"Shape {shape} must have shape_type attribute") + if not isinstance(shape, ROIShape): + raise ValueError(f"Shape {shape} must be an ROIShape") + + +class ROIJsonShapeDecoder(ABC, metaclass=AutoRegisterMeta): + """Decode one serialized ROI shape variant.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_serialized_shape(cls, shape_dict: Dict[str, Any]) -> "ROIJsonShapeDecoder | None": + shape_type = shape_dict.get("type") + try: + shape_key = ShapeType(shape_type) + except ValueError: + logger.warning(f"Unknown shape type: {shape_type}, skipping") + return None + return cls.__registry__[shape_key]() + + @abstractmethod + def decode(self, shape_dict: Dict[str, Any]) -> Any: + """Return the concrete ROI shape represented by ``shape_dict``.""" + + +class PolygonROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYGON + + def decode(self, shape_dict: Dict[str, Any]) -> PolygonShape: + return PolygonShape(coordinates=np.array(shape_dict["coordinates"])) + + +class PolylineROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYLINE + + def decode(self, shape_dict: Dict[str, Any]) -> PolylineShape: + return PolylineShape(coordinates=np.array(shape_dict["coordinates"])) + + +class MaskROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.MASK + + def decode(self, shape_dict: Dict[str, Any]) -> MaskShape: + return MaskShape( + mask=np.array(shape_dict["mask"], dtype=bool), + bbox=tuple(shape_dict["bbox"]), + ) + + +class PointROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POINT + + def decode(self, shape_dict: Dict[str, Any]) -> PointShape: + return PointShape(y=shape_dict["y"], x=shape_dict["x"]) + + +class EllipseROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.ELLIPSE + + def decode(self, shape_dict: Dict[str, Any]) -> EllipseShape: + return EllipseShape( + center_y=shape_dict["center_y"], + center_x=shape_dict["center_x"], + radius_y=shape_dict["radius_y"], + radius_x=shape_dict["radius_x"], + ) def extract_rois_from_labeled_mask( labeled_mask: np.ndarray, min_area: int = 10, extract_contours: bool = True, + spatial_origin_yx: Optional[Tuple[int, int]] = None, + source_spatial_shape_yx: Optional[Tuple[int, int]] = None, ) -> List[ROI]: """Extract ROIs from a labeled segmentation mask.""" from skimage import measure @@ -117,19 +193,33 @@ def extract_rois_from_labeled_mask( regions = regionprops(labeled_mask) slices = find_objects(labeled_mask) + origin_y, origin_x = spatial_origin_yx or (0, 0) rois = [] for region in regions: if region.area < min_area: continue + min_y, min_x, max_y, max_x = region.bbox metadata = { "label": int(region.label), "area": float(region.area), "perimeter": float(region.perimeter), - "centroid": tuple(float(c) for c in region.centroid), - "bbox": tuple(int(b) for b in region.bbox), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), } + if source_spatial_shape_yx is not None: + metadata["source_spatial_shape_yx"] = tuple( + int(value) for value in source_spatial_shape_yx + ) shapes = [] if extract_contours: @@ -142,14 +232,14 @@ def extract_rois_from_labeled_mask( contours = measure.find_contours(padded_mask, level=0.5) offset_y = slice_y.start offset_x = slice_x.start - padding_offset = np.array([offset_y, offset_x]) - 1 + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 for contour in contours: if len(contour) >= 3: contour_full = contour + padding_offset shapes.append(PolygonShape(coordinates=contour_full)) else: binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=region.bbox)) + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) @@ -203,31 +293,9 @@ def load_rois_from_json(json_path: Path) -> List[ROI]: metadata = roi_dict.get("metadata", {}) shapes = [] for shape_dict in roi_dict.get("shapes", []): - shape_type = shape_dict.get("type") - - if shape_type == "polygon": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolygonShape(coordinates=coordinates)) - elif shape_type == "polyline": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolylineShape(coordinates=coordinates)) - elif shape_type == "mask": - mask = np.array(shape_dict["mask"], dtype=bool) - bbox = tuple(shape_dict["bbox"]) - shapes.append(MaskShape(mask=mask, bbox=bbox)) - elif shape_type == "point": - shapes.append(PointShape(y=shape_dict["y"], x=shape_dict["x"])) - elif shape_type == "ellipse": - shapes.append( - EllipseShape( - center_y=shape_dict["center_y"], - center_x=shape_dict["center_x"], - radius_y=shape_dict["radius_y"], - radius_x=shape_dict["radius_x"], - ) - ) - else: - logger.warning(f"Unknown shape type: {shape_type}, skipping") + decoder = ROIJsonShapeDecoder.for_serialized_shape(shape_dict) + if decoder is not None: + shapes.append(decoder.decode(shape_dict)) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) diff --git a/src/polystore/roi_converters.py b/src/polystore/roi_converters.py index 46e8631..616e4c4 100644 --- a/src/polystore/roi_converters.py +++ b/src/polystore/roi_converters.py @@ -7,63 +7,184 @@ """ import logging -from typing import Any, Dict, List, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, ClassVar, Dict, List, Tuple import numpy as np +from metaclass_registry import AutoRegisterMeta -from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI -from .streaming_constants import NapariShapeType +from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI, ShapeType logger = logging.getLogger(__name__) -class NapariROIConverter: - """Convert ROI objects to Napari shapes format.""" +@dataclass(frozen=True, slots=True) +class NapariShapeTypeAlias: + """Inert alias from Napari wire shape names to ROI shape types.""" + + alias: str + shape_type: ShapeType + + +NAPARI_SHAPE_TYPE_ALIASES = ( + NapariShapeTypeAlias("path", ShapeType.POLYLINE), + NapariShapeTypeAlias("points", ShapeType.POINT), +) + + +class NapariShapeConverter(ABC, metaclass=AutoRegisterMeta): + """Registered conversion behavior for one ROI shape type.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_shape_dict(cls, shape_dict: Dict[str, Any]) -> "NapariShapeConverter": + return cls.__registry__[_shape_type_from_napari(shape_dict["type"])]() + + def append_common_properties( + self, + metadata: Dict[str, Any], + properties: dict[str, list[Any]], + centroid: tuple[Any, Any], + *, + area: Any | None = None, + ) -> None: + properties["label"].append(metadata.get("label", "")) + properties["area"].append(metadata.get("area", 0) if area is None else area) + properties["centroid_y"].append(centroid[0]) + properties["centroid_x"].append(centroid[1]) + + @abstractmethod + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + """Add dimensions to a 2D shape to make it nD.""" + + @abstractmethod + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + """Append this shape to a Napari layer payload.""" + + +def _shape_type_from_napari(shape_type: object) -> ShapeType: + if isinstance(shape_type, ShapeType): + return shape_type + value = str(shape_type.value) if isinstance(shape_type, Enum) else str(shape_type) + for alias in NAPARI_SHAPE_TYPE_ALIASES: + if alias.alias == value: + return alias.shape_type + return ShapeType(value) + + +class CoordinateNapariShapeConverter(NapariShapeConverter): + """Shared converter for coordinate-list shapes.""" - _SHAPE_DIMENSION_HANDLERS = { - "polygon": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "polyline": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "ellipse": lambda shape_dict, prepend_dims: np.hstack( + napari_shape_type: ClassVar[str] + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + coordinates = np.array(shape_dict["coordinates"]) + return np.hstack([np.tile(prepend_dims, (len(coordinates), 1)), coordinates]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + napari_shapes.append(np.array(shape_dict["coordinates"])) + shape_types.append(self.napari_shape_type) + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PolygonNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYGON + napari_shape_type = "polygon" + + +class PolylineNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYLINE + napari_shape_type = "path" + + +class EllipseNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.ELLIPSE + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + center = shape_dict["center"] + radii = shape_dict["radii"] + corners = np.array( [ - np.tile(prepend_dims, (4, 1)), - np.array( - [ - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - ] - ), + [center[0] - radii[0], center[1] - radii[1]], + [center[0] - radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] - radii[1]], ] - ), - "point": lambda shape_dict, prepend_dims: np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1), - } + ) + return np.hstack([np.tile(prepend_dims, (4, 1)), corners]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + center = np.array(shape_dict["center"]) + radii = np.array(shape_dict["radii"]) + napari_shapes.append(np.array([center - radii, center + radii])) + shape_types.append("ellipse") + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PointNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.POINT + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + return np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + coordinates = shape_dict["coordinates"] + napari_shapes.append(np.array([coordinates])) + shape_types.append("point") + self.append_common_properties(metadata, properties, coordinates, area=0) + + +class NapariROIConverter: + """Convert ROI objects to Napari shapes format.""" @staticmethod def add_dimensions_to_shape(shape_dict: Dict[str, Any], prepend_dims: List[float]) -> np.ndarray: """Add dimensions to a 2D shape to make it nD.""" - shape_type = shape_dict["type"] - shape_type_enum = NapariShapeType(shape_type) if isinstance(shape_type, str) else shape_type - handler = NapariROIConverter._SHAPE_DIMENSION_HANDLERS.get(shape_type_enum.value) - if handler is None: - raise ValueError(f"Unsupported shape type: {shape_type}") - return handler(shape_dict, np.array(prepend_dims)) + return NapariShapeConverter.for_shape_dict(shape_dict).add_dimensions( + shape_dict, + np.array(prepend_dims), + ) @staticmethod def rois_to_shapes(rois: List[ROI]) -> List[Dict[str, Any]]: @@ -104,40 +225,12 @@ def shapes_to_napari_format(shapes_data: List[Dict]) -> Tuple[List[np.ndarray], properties = {"label": [], "area": [], "centroid_y": [], "centroid_x": []} for shape_dict in shapes_data: - shape_type = shape_dict.get("type") - metadata = shape_dict.get("metadata", {}) - - if shape_type == "polygon": - coords = np.array(shape_dict["coordinates"]) - napari_shapes.append(coords) - shape_types.append("polygon") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "ellipse": - center = np.array(shape_dict["center"]) - radii = np.array(shape_dict["radii"]) - corners = np.array([center - radii, center + radii]) - napari_shapes.append(corners) - shape_types.append("ellipse") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "point": - coords = np.array([shape_dict["coordinates"]]) - napari_shapes.append(coords) - shape_types.append("point") - point_coords = shape_dict["coordinates"] - properties["label"].append(metadata.get("label", "")) - properties["area"].append(0) - properties["centroid_y"].append(point_coords[0]) - properties["centroid_x"].append(point_coords[1]) + NapariShapeConverter.for_shape_dict(shape_dict).append_napari_format( + shape_dict, + napari_shapes, + shape_types, + properties, + ) return napari_shapes, shape_types, properties diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 932a2c4..1f3a70a 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -13,6 +13,8 @@ from pathlib import Path from typing import Any, Callable, List, Mapping, Set, Union import numpy as np +from arraybridge import convert_memory, detect_memory_type +from arraybridge.types import MemoryType as ArrayBridgeMemoryType from ..base import DataSink from ..constants import TransportMode @@ -20,32 +22,62 @@ from ..roi import ROI, PointShape from ..zmq_config import POLYSTORE_ZMQ_CONFIG from zmqruntime.ack_listener import GlobalAckListener -from zmqruntime.transport import coerce_transport_mode, get_zmq_transport_url +from zmqruntime.transport import coerce_transport_mode logger = logging.getLogger(__name__) +PrepareStreamingItem = Callable[[Any, Union[str, Path], Any], tuple[dict, str]] + + @dataclass(frozen=True) class StreamingComponentMetadata: """Message metadata for one streamed item.""" - parsed_filename_metadata: Mapping[str, Any] | None + parsed_filename_metadata: Mapping[str, Any] source: str def to_payload(self) -> dict[str, Any]: - if self.parsed_filename_metadata is None: - metadata: dict[str, Any] = {} - elif isinstance(self.parsed_filename_metadata, Mapping): + if isinstance(self.parsed_filename_metadata, Mapping): metadata = dict(self.parsed_filename_metadata) else: raise TypeError( - "Streaming filename parser must return a mapping or None, " + "Streaming component metadata must be a mapping, " f"got {type(self.parsed_filename_metadata).__name__}." ) metadata["source"] = self.source return metadata +@dataclass(frozen=True) +class StreamingBatchRequest: + """Shared provenance for one streaming batch.""" + + data_list: List[Any] + file_paths: List[Union[str, Path]] + microscope_handler: Any + source: str + prepare_item: PrepareStreamingItem + component_metadata: Mapping[str, Any] | None = None + + +class StreamingPayloadMemoryAuthority: + """Memory conversion authority for streamable image payloads.""" + + @staticmethod + def to_numpy(data: Any) -> np.ndarray: + if isinstance(data, np.ndarray): + return data + if isinstance(data, (list, tuple)): + return np.asarray(data) + return convert_memory( + data, + detect_memory_type(data), + ArrayBridgeMemoryType.NUMPY.value, + gpu_id=0, + ) + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -126,55 +158,13 @@ def __init__(self, transport_config=None): self._shared_memory_blocks = {} self._transport_config = transport_config or POLYSTORE_ZMQ_CONFIG - def _get_publisher(self, host: str, port: int, transport_mode: TransportMode, transport_config=None): - """ - Lazy initialization of ZeroMQ publisher (common for all streaming backends). - - Uses REQ socket for Fiji (synchronous request/reply with blocking) - and PUB socket for Napari (broadcast pattern). - - Args: - host: Host to connect to (ignored for IPC mode) - port: Port to connect to - transport_mode: IPC or TCP transport (required - comes from config) - - Returns: - ZeroMQ publisher socket - """ - # Generate transport URL using centralized function - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - key = url # Use URL as key instead of host:port - if key not in self._publishers: - try: - import zmq - if self._context is None: - self._context = zmq.Context() - - # Use REQ socket for all viewers (synchronous request/reply) - # All viewers must send acknowledgment after processing - publisher = self._context.socket(zmq.REQ) - - publisher.connect(url) - socket_name = "REQ" - logger.info(f"{self.VIEWER_TYPE} streaming {socket_name} socket connected to {url}") - time.sleep(0.1) - self._publishers[key] = publisher - - except ImportError: - logger.error("ZeroMQ not available - streaming disabled") - raise RuntimeError("ZeroMQ required for streaming") - - return self._publishers[key] - - def _parse_component_metadata(self, file_path: Union[str, Path], microscope_handler, - source: str) -> dict: + def _parse_component_metadata( + self, + file_path: Union[str, Path], + microscope_handler, + source: str, + component_metadata: Mapping[str, Any] | None = None, + ) -> dict: """ Parse component metadata from filename (common for all streaming backends). @@ -187,10 +177,17 @@ def _parse_component_metadata(self, file_path: Union[str, Path], microscope_hand Component metadata dict with source added """ filename = os.path.basename(str(file_path)) - return StreamingComponentMetadata( - microscope_handler.parser.parse_filename(filename), - source, - ).to_payload() + parsed_metadata = ( + component_metadata + if component_metadata is not None + else microscope_handler.parser.parse_filename(filename) + ) + if parsed_metadata is None: + raise ValueError( + "Streaming component metadata requires explicit component_metadata " + f"or a parser-readable filename; got {filename!r}." + ) + return StreamingComponentMetadata(parsed_metadata, source).to_payload() def _detect_data_type(self, data: Any): """ @@ -226,9 +223,7 @@ def _create_shared_memory(self, data: Any, file_path: Union[str, Path]) -> dict: Returns: Dict with shared memory metadata """ - # Convert to numpy - np_data = data.cpu().numpy() if hasattr(data, 'cpu') else \ - data.get() if hasattr(data, 'get') else np.asarray(data) + np_data = StreamingPayloadMemoryAuthority.to_numpy(data) # Create shared memory with hash-based naming to avoid "File name too long" errors # Hash the timestamp and object ID to create a short, unique name @@ -289,13 +284,7 @@ def _register_with_queue_tracker( tracker.register_sent(image_id) def _build_component_modes(self, display_config) -> dict: - component_modes = {} - for comp_name in display_config.COMPONENT_ORDER: - mode_field = f"{comp_name}_mode" - if hasattr(display_config, mode_field): - mode = getattr(display_config, mode_field) - component_modes[comp_name] = mode.value - return component_modes + return display_config.component_modes() def _build_display_config_base(self, display_config, component_modes: dict) -> dict: return { @@ -324,20 +313,14 @@ def _collect_component_names_metadata( try: for comp_name in component_names: - method_name = f"get_{comp_name}_values" - method = getattr(microscope_handler.metadata_handler, method_name, None) - if callable(method): - try: - metadata = method(plate_path) - if verbose and log_prefix: - logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") - if metadata: - component_names_metadata[comp_name] = metadata - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get {comp_name} metadata: {e}", exc_info=True) - elif verbose and log_prefix: - logger.info(f"{log_prefix}: No method {method_name} on metadata_handler") + metadata = microscope_handler.metadata_handler.get_component_values( + plate_path, + comp_name, + ) + if verbose and log_prefix: + logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") + if metadata: + component_names_metadata[comp_name] = metadata except Exception as e: if verbose and log_prefix: logger.warning(f"{log_prefix}: Could not get component metadata: {e}", exc_info=True) @@ -346,24 +329,23 @@ def _collect_component_names_metadata( def _prepare_batch_items( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], + request: StreamingBatchRequest, ) -> tuple[list[dict], list[str]]: batch_images = [] image_ids = [] - for data, file_path in zip(data_list, file_paths): + for data, file_path in zip(request.data_list, request.file_paths): image_id = str(uuid.uuid4()) image_ids.append(image_id) data_type = self._detect_data_type(data) component_metadata = self._parse_component_metadata( - file_path, microscope_handler, source + file_path, + request.microscope_handler, + request.source, + request.component_metadata, ) - item_data, data_type_value = prepare_item(data, file_path, data_type) + item_data, data_type_value = request.prepare_item(data, file_path, data_type) batch_images.append( { @@ -383,9 +365,10 @@ def _build_batch_message( microscope_handler, source: str, display_config, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], + prepare_item: PrepareStreamingItem, plate_path: Union[str, Path, None] = None, component_names_kwargs: dict | None = None, + component_metadata: Mapping[str, Any] | None = None, display_payload_extra: dict | None = None, message_extra: dict | None = None, ) -> tuple[dict, list[dict], list[str]]: @@ -393,11 +376,14 @@ def _build_batch_message( raise ValueError("data_list and file_paths must have the same length") batch_images, image_ids = self._prepare_batch_items( - data_list, - file_paths, - microscope_handler, - source, - prepare_item, + StreamingBatchRequest( + data_list=data_list, + file_paths=file_paths, + microscope_handler=microscope_handler, + source=source, + prepare_item=prepare_item, + component_metadata=component_metadata, + ) ) component_modes = self._build_component_modes(display_config) diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index dec6fff..51b7d67 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -14,13 +14,7 @@ def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], lis component_order = display_config["component_order"] return component_modes, component_order - component_order = list(display_config.COMPONENT_ORDER) - component_modes: dict[str, str] = {} - for component in component_order: - mode_field = f"{component}_mode" - mode_value = display_config.__getattribute__(mode_field) - component_modes[component] = mode_value.value - return component_modes, component_order + return display_config.component_modes(), list(display_config.COMPONENT_ORDER) def build_layer_key( @@ -38,9 +32,4 @@ def build_layer_key( layer_key = "_".join(layer_key_parts) if layer_key_parts else "default_layer" - if data_type == StreamingDataType.SHAPES: - return f"{layer_key}_shapes" - if data_type == StreamingDataType.POINTS: - return f"{layer_key}_points" - return layer_key - + return f"{layer_key}{data_type.napari_layer_suffix}" diff --git a/src/polystore/streaming_constants.py b/src/polystore/streaming_constants.py index d7f0596..05c834c 100644 --- a/src/polystore/streaming_constants.py +++ b/src/polystore/streaming_constants.py @@ -15,6 +15,21 @@ class StreamingDataType(Enum): POINTS = "points" # Napari points layer (e.g., skeleton tracings) ROIS = "rois" # Fiji ROI payloads + @property + def uses_napari_vector_payload(self) -> bool: + """Whether napari should receive this type through vector layer payloads.""" + return self in (type(self).SHAPES, type(self).POINTS) + + @property + def napari_layer_suffix(self) -> str: + """Layer-key suffix contributed by this data type.""" + return { + type(self).IMAGE: "", + type(self).SHAPES: "_shapes", + type(self).POINTS: "_points", + type(self).ROIS: "", + }[self] + class NapariShapeType(Enum): """Napari shape types for ROI visualization.""" diff --git a/tests/test_roi.py b/tests/test_roi.py new file mode 100644 index 0000000..565022f --- /dev/null +++ b/tests/test_roi.py @@ -0,0 +1,79 @@ +import numpy as np + +from polystore.roi import MaskShape +from polystore.roi import PolygonShape +from polystore.roi import load_rois_from_json +from polystore.roi import extract_rois_from_labeled_mask + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_polygons(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=True, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert rois[0].metadata["bbox"] == (12, 23, 16, 27) + assert rois[0].metadata["centroid"] == (13.5, 24.5) + assert isinstance(rois[0].shapes[0], PolygonShape) + assert float(rois[0].shapes[0].coordinates[:, 0].min()) >= 11.5 + assert float(rois[0].shapes[0].coordinates[:, 1].min()) >= 22.5 + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_mask_bbox(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=False, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], MaskShape) + assert rois[0].shapes[0].bbox == (12, 23, 16, 27) + + +def test_extract_rois_from_labeled_mask_records_source_canvas_shape(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + source_spatial_shape_yx=(100, 200), + ) + + assert len(rois) == 1 + assert rois[0].metadata["source_spatial_shape_yx"] == (100, 200) + + +def test_load_rois_from_json_decodes_shapes_through_nominal_registry(tmp_path): + roi_path = tmp_path / "rois.json" + roi_path.write_text( + """ + [ + { + "metadata": {"label": 1}, + "shapes": [ + {"type": "polygon", "coordinates": [[1, 2], [3, 4], [5, 6]]}, + {"type": "mask", "mask": [[true, false], [false, true]], "bbox": [10, 20, 12, 22]} + ] + } + ] + """ + ) + + rois = load_rois_from_json(roi_path) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], PolygonShape) + assert isinstance(rois[0].shapes[1], MaskShape) + assert rois[0].shapes[1].bbox == (10, 20, 12, 22) diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py index 8141f1d..95d3b00 100644 --- a/tests/test_streaming_metadata.py +++ b/tests/test_streaming_metadata.py @@ -3,6 +3,7 @@ import pytest from polystore.streaming._streaming_backend import StreamingBackend +from polystore.streaming._streaming_backend import StreamingBatchRequest class MetadataProbeStreamingBackend(StreamingBackend): @@ -13,38 +14,36 @@ def save_batch(self, data_list, file_paths, **kwargs): raise NotImplementedError -def test_streaming_component_metadata_accepts_unparsed_artifact_filename() -> None: +def test_streaming_component_metadata_rejects_unparsed_artifact_filename() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( parser=SimpleNamespace(parse_filename=lambda _filename: None) ) - metadata = backend._parse_component_metadata( - "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", - microscope_handler, - source="IdentifyPrimaryObjects", - ) - - assert metadata == {"source": "IdentifyPrimaryObjects"} + with pytest.raises(ValueError, match="explicit component_metadata"): + backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + ) -def test_streaming_batch_items_accept_unparsed_artifact_filename() -> None: +def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( parser=SimpleNamespace(parse_filename=lambda _filename: None) ) - batch_images, image_ids = backend._prepare_batch_items( - [object()], - ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], - microscope_handler, - "IdentifyPrimaryObjects", - lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), - ) - - assert len(image_ids) == 1 - assert batch_images[0]["metadata"] == {"source": "IdentifyPrimaryObjects"} - assert batch_images[0]["payload"] == "ok" + with pytest.raises(ValueError, match="explicit component_metadata"): + backend._prepare_batch_items( + StreamingBatchRequest( + data_list=[object()], + file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + microscope_handler=microscope_handler, + source="IdentifyPrimaryObjects", + prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + ) + ) def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: @@ -64,13 +63,34 @@ def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None assert metadata == {"well": "A01", "channel": 1, "source": "Crop"} +def test_streaming_component_metadata_prefers_explicit_metadata() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ) + + assert metadata == { + "well": "A01", + "site": 1, + "channel": 1, + "source": "IdentifyPrimaryObjects", + } + + def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( parser=SimpleNamespace(parse_filename=lambda _filename: ["not", "metadata"]) ) - with pytest.raises(TypeError, match="mapping or None"): + with pytest.raises(TypeError, match="must be a mapping"): backend._parse_component_metadata( "A01_s001_w1_z001_t001.TIF", microscope_handler, From ce8dc58fd180064dd6a5bdd9054210a7d8f96780 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Sun, 24 May 2026 13:57:49 -0400 Subject: [PATCH 07/11] Support stacked labeled mask ROI extraction --- src/polystore/roi.py | 234 +++++++++++++++++++++++++++++++------------ 1 file changed, 169 insertions(+), 65 deletions(-) diff --git a/src/polystore/roi.py b/src/polystore/roi.py index d841591..26c1ef1 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -107,6 +107,167 @@ def __post_init__(self): raise ValueError(f"Shape {shape} must be an ROIShape") +@dataclass(frozen=True, slots=True) +class LabeledMaskROIExtractionRequest: + """Request to extract ROIs from a labeled mask or stack.""" + + labeled_mask: np.ndarray + min_area: int = 10 + extract_contours: bool = True + spatial_origin_yx: Optional[Tuple[int, int]] = None + source_spatial_shape_yx: Optional[Tuple[int, int]] = None + + +class LabeledMaskROIExtractor(ABC, metaclass=AutoRegisterMeta): + """Registered extraction behavior for one labeled-mask dimensional family.""" + + __registry_key__ = "__name__" + __skip_if_no_key__ = True + + @classmethod + def for_request( + cls, + request: LabeledMaskROIExtractionRequest, + ) -> "LabeledMaskROIExtractor": + for extractor_type in cls.__registry__.values(): + extractor = extractor_type() + if extractor.accepts(request.labeled_mask): + return extractor + raise ValueError( + "No ROI extractor registered for labeled mask shape " + f"{request.labeled_mask.shape}." + ) + + @abstractmethod + def accepts(self, labeled_mask: np.ndarray) -> bool: + """Return whether this extractor owns the mask dimensionality.""" + + @abstractmethod + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + """Extract ROIs from the request.""" + + +class TwoDimensionalLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from a single 2D labeled mask.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim == 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + from skimage import measure + from skimage.measure import regionprops + from scipy.ndimage import find_objects + + labeled_mask = request.labeled_mask + if not np.issubdtype(labeled_mask.dtype, np.integer): + labeled_mask = labeled_mask.astype(np.int32) + + regions = regionprops(labeled_mask) + slices = find_objects(labeled_mask) + origin_y, origin_x = request.spatial_origin_yx or (0, 0) + + rois = [] + for region in regions: + if region.area < request.min_area: + continue + min_y, min_x, max_y, max_x = region.bbox + + metadata = { + "label": int(region.label), + "area": float(region.area), + "perimeter": float(region.perimeter), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), + } + if request.source_spatial_shape_yx is not None: + metadata["source_spatial_shape_yx"] = tuple( + int(value) for value in request.source_spatial_shape_yx + ) + + shapes = [] + if request.extract_contours: + label_idx = region.label - 1 + if label_idx < len(slices) and slices[label_idx] is not None: + slice_y, slice_x = slices[label_idx] + cropped_mask = labeled_mask[slice_y, slice_x] + binary_mask = (cropped_mask == region.label).astype(np.uint8) + padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) + contours = measure.find_contours(padded_mask, level=0.5) + offset_y = slice_y.start + offset_x = slice_x.start + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 + for contour in contours: + if len(contour) >= 3: + contour_full = contour + padding_offset + shapes.append(PolygonShape(coordinates=contour_full)) + else: + binary_mask = labeled_mask == region.label + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) + + if shapes: + rois.append(ROI(shapes=shapes, metadata=metadata)) + + logger.info(f"Extracted {len(rois)} ROIs from labeled mask") + return rois + + +class NonSpatialLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Treat scalar and otherwise non-spatial label payloads as empty ROI sets.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim < 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + return [] + + +class StackedLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from all 2D planes in a labeled-mask stack.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim > 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + stack = request.labeled_mask + plane_shape = stack.shape[-2:] + leading_shape = stack.shape[:-2] + rois: list[ROI] = [] + for plane_indices in np.ndindex(leading_shape): + plane_request = LabeledMaskROIExtractionRequest( + labeled_mask=stack[plane_indices], + min_area=request.min_area, + extract_contours=request.extract_contours, + spatial_origin_yx=request.spatial_origin_yx, + source_spatial_shape_yx=request.source_spatial_shape_yx or plane_shape, + ) + for roi in TwoDimensionalLabeledMaskROIExtractor().extract(plane_request): + rois.append(self._with_plane_metadata(roi, plane_indices, leading_shape)) + return rois + + @staticmethod + def _with_plane_metadata( + roi: ROI, + plane_indices: tuple[int, ...], + leading_shape: tuple[int, ...], + ) -> ROI: + return ROI( + shapes=roi.shapes, + metadata={ + **roi.metadata, + "plane_indices": tuple(int(index) for index in plane_indices), + "plane_shape": tuple(int(size) for size in leading_shape), + }, + ) + + class ROIJsonShapeDecoder(ABC, metaclass=AutoRegisterMeta): """Decode one serialized ROI shape variant.""" @@ -181,71 +342,14 @@ def extract_rois_from_labeled_mask( source_spatial_shape_yx: Optional[Tuple[int, int]] = None, ) -> List[ROI]: """Extract ROIs from a labeled segmentation mask.""" - from skimage import measure - from skimage.measure import regionprops - from scipy.ndimage import find_objects - - if labeled_mask.ndim != 2: - raise ValueError(f"Labeled mask must be 2D, got shape {labeled_mask.shape}") - - if not np.issubdtype(labeled_mask.dtype, np.integer): - labeled_mask = labeled_mask.astype(np.int32) - - regions = regionprops(labeled_mask) - slices = find_objects(labeled_mask) - origin_y, origin_x = spatial_origin_yx or (0, 0) - - rois = [] - for region in regions: - if region.area < min_area: - continue - min_y, min_x, max_y, max_x = region.bbox - - metadata = { - "label": int(region.label), - "area": float(region.area), - "perimeter": float(region.perimeter), - "centroid": ( - float(region.centroid[0] + origin_y), - float(region.centroid[1] + origin_x), - ), - "bbox": ( - int(min_y + origin_y), - int(min_x + origin_x), - int(max_y + origin_y), - int(max_x + origin_x), - ), - } - if source_spatial_shape_yx is not None: - metadata["source_spatial_shape_yx"] = tuple( - int(value) for value in source_spatial_shape_yx - ) - - shapes = [] - if extract_contours: - label_idx = region.label - 1 - if label_idx < len(slices) and slices[label_idx] is not None: - slice_y, slice_x = slices[label_idx] - cropped_mask = labeled_mask[slice_y, slice_x] - binary_mask = (cropped_mask == region.label).astype(np.uint8) - padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) - contours = measure.find_contours(padded_mask, level=0.5) - offset_y = slice_y.start - offset_x = slice_x.start - padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 - for contour in contours: - if len(contour) >= 3: - contour_full = contour + padding_offset - shapes.append(PolygonShape(coordinates=contour_full)) - else: - binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) - - if shapes: - rois.append(ROI(shapes=shapes, metadata=metadata)) - - logger.info(f"Extracted {len(rois)} ROIs from labeled mask") - return rois + request = LabeledMaskROIExtractionRequest( + labeled_mask=np.asarray(labeled_mask), + min_area=min_area, + extract_contours=extract_contours, + spatial_origin_yx=spatial_origin_yx, + source_spatial_shape_yx=source_spatial_shape_yx, + ) + return LabeledMaskROIExtractor.for_request(request).extract(request) def _get_backend_from_filemanager(filemanager: Any, backend: Union[str, Backend]): From 924b950d5de6dc8019cd8f2e3ed78f191b553092 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Tue, 26 May 2026 18:17:32 -0400 Subject: [PATCH 08/11] Add Bio-Formats storage backend --- src/polystore/backend_registry.py | 3 +- src/polystore/bioformats_java.py | 223 ++++++++++++++++++++++++ src/polystore/bioformats_storage.py | 258 ++++++++++++++++++++++++++++ src/polystore/constants.py | 1 + 4 files changed, 483 insertions(+), 2 deletions(-) create mode 100644 src/polystore/bioformats_java.py create mode 100644 src/polystore/bioformats_storage.py diff --git a/src/polystore/backend_registry.py b/src/polystore/backend_registry.py index ad8ac52..eb4cb21 100644 --- a/src/polystore/backend_registry.py +++ b/src/polystore/backend_registry.py @@ -74,7 +74,7 @@ def create_storage_registry() -> Dict[str, DataSink]: # Backends that require context-specific initialization (e.g., plate_root) # These are registered lazily when needed, not at startup - SKIP_BACKENDS = {'virtual_workspace', 'omero_local'} + SKIP_BACKENDS = {'virtual_workspace', 'omero_local', 'bioformats'} registry = {} for backend_type in STORAGE_BACKENDS.keys(): @@ -157,4 +157,3 @@ def cleanup_all_backends() -> None: _backend_instances.clear() logger.info("All backend instances cleaned up") - diff --git a/src/polystore/bioformats_java.py b/src/polystore/bioformats_java.py new file mode 100644 index 0000000..41c7824 --- /dev/null +++ b/src/polystore/bioformats_java.py @@ -0,0 +1,223 @@ +"""Shared Java Bio-Formats bridge for metadata discovery and plane loading.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from threading import Lock +from typing import Any, Callable + +import numpy as np + + +class BioFormatsJavaUnavailableError(RuntimeError): + """Raised when the Java Bio-Formats runtime cannot be initialized.""" + + +@dataclass(frozen=True, slots=True) +class BioFormatsOpenedReader: + """Open Bio-Formats reader plus its OME metadata store.""" + + reader: Any + metadata: Any + + def close(self) -> None: + self.reader.close() + + +class BioFormatsJavaContext: + """Lazy JVM/ImageJ context for Bio-Formats Java access.""" + + _lock = Lock() + _instance: "BioFormatsJavaContext | None" = None + + def __init__(self, imagej_module: Any, scyjava_module: Any): + self.imagej = imagej_module + self.scyjava = scyjava_module + self.ij = None + self.ImageReader = None + self.MetadataTools = None + self.FormatTools = None + + @classmethod + def instance(cls) -> "BioFormatsJavaContext": + with cls._lock: + if cls._instance is None: + cls._instance = cls._create() + return cls._instance + + @classmethod + def _create(cls) -> "BioFormatsJavaContext": + try: + import imagej + import scyjava + except ImportError as exc: + raise BioFormatsJavaUnavailableError( + "Bio-Formats support requires the optional bioformats/fiji dependencies." + ) from exc + return cls(imagej, scyjava) + + def ensure_initialized(self) -> None: + if self.ij is not None: + return + try: + self.ij = self.imagej.init("sc.fiji:fiji", mode="headless") + self.ImageReader = self.scyjava.jimport("loci.formats.ImageReader") + self.MetadataTools = self.scyjava.jimport("loci.formats.MetadataTools") + self.FormatTools = self.scyjava.jimport("loci.formats.FormatTools") + except Exception as exc: + raise BioFormatsJavaUnavailableError( + "Could not initialize Fiji/Bio-Formats through pyimagej." + ) from exc + + def open_reader(self, source_path: str | Path) -> BioFormatsOpenedReader: + self.ensure_initialized() + metadata = self.MetadataTools.createOMEXMLMetadata() + reader = self.ImageReader() + try: + reader.setMetadataStore(metadata) + reader.setId(str(source_path)) + return BioFormatsOpenedReader(reader=reader, metadata=metadata) + except Exception: + reader.close() + raise + + +def java_int(value: Any) -> int | None: + """Convert nullable Java primitive wrappers to Python int.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(int) + + +def java_float(value: Any) -> float | None: + """Convert nullable Java numeric wrappers to Python float.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(float) + + +def java_str(value: Any) -> str | None: + """Convert nullable Java strings to Python strings.""" + if value is None: + return None + return str(value) + + +def _read_java_value(value: Any) -> Any: + return value.value() + + +def _read_java_get_value(value: Any) -> Any: + return value.getValue() + + +@dataclass(frozen=True, slots=True) +class JavaScalarProjector: + """Project nullable Java scalar wrappers to Python scalar values.""" + + readers: tuple[Callable[[Any], Any], ...] + + def unwrap(self, value: Any) -> Any: + for reader in self.readers: + try: + return reader(value) + except AttributeError: + continue + return value + + +@dataclass(frozen=True, slots=True) +class OptionalJavaScalar: + """Nullable Java scalar after wrapper unwrapping.""" + + value: Any | None + + @classmethod + def from_java( + cls, + value: Any, + readers: tuple[Callable[[Any], Any], ...], + ) -> "OptionalJavaScalar": + if value is None: + return cls(None) + return cls(JavaScalarProjector(readers).unwrap(value)) + + def convert(self, converter: Callable[[Any], Any]) -> Any | None: + if self.value is None: + return None + return converter(self.value) + + +JAVA_SCALAR_PROJECTOR = JavaScalarProjector( + readers=( + _read_java_value, + _read_java_get_value, + ) +) + + +def load_bioformats_plane( + *, + source_path: Path, + series_index: int, + plane_index: int, +) -> np.ndarray: + """Load a single 2D Bio-Formats plane through the Java ImageReader.""" + context = BioFormatsJavaContext.instance() + opened = context.open_reader(source_path) + reader = opened.reader + try: + reader.setSeries(series_index) + if reader.getRGBChannelCount() != 1: + raise ValueError( + "Bio-Formats RGB/interleaved planes are not yet representable as " + "OpenHCS scalar channel planes." + ) + raw = bytes(reader.openBytes(plane_index)) + dtype = PixelDtypeCatalog.from_format_tools(context.FormatTools).dtype( + pixel_type=int(reader.getPixelType()), + little_endian=bool(reader.isLittleEndian()), + ) + array = np.frombuffer(raw, dtype=dtype) + return array.reshape((int(reader.getSizeY()), int(reader.getSizeX()))) + finally: + opened.close() + + +@dataclass(frozen=True, slots=True) +class PixelDtypeSpec: + """NumPy dtype projection for one Bio-Formats pixel type.""" + + key: int + dtype_code: str + endian_sensitive: bool = True + + def dtype(self, *, little_endian: bool) -> np.dtype: + if not self.endian_sensitive: + return np.dtype(self.dtype_code) + endian = "<" if little_endian else ">" + return np.dtype(endian + self.dtype_code) + + +@dataclass(frozen=True, slots=True) +class PixelDtypeCatalog: + """Authoritative Bio-Formats pixel-type to NumPy dtype mapping.""" + + specs_by_key: dict[int, PixelDtypeSpec] + + @classmethod + def from_format_tools(cls, format_tools: Any) -> "PixelDtypeCatalog": + specs = ( + PixelDtypeSpec(int(format_tools.INT8), "i1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.UINT8), "u1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.INT16), "i2"), + PixelDtypeSpec(int(format_tools.UINT16), "u2"), + PixelDtypeSpec(int(format_tools.INT32), "i4"), + PixelDtypeSpec(int(format_tools.UINT32), "u4"), + PixelDtypeSpec(int(format_tools.FLOAT), "f4"), + PixelDtypeSpec(int(format_tools.DOUBLE), "f8"), + ) + return cls({spec.key: spec for spec in specs}) + + def dtype(self, *, pixel_type: int, little_endian: bool) -> np.dtype: + try: + return self.specs_by_key[pixel_type].dtype(little_endian=little_endian) + except KeyError as exc: + raise ValueError(f"Unsupported Bio-Formats pixel type: {pixel_type}") from exc diff --git a/src/polystore/bioformats_storage.py b/src/polystore/bioformats_storage.py new file mode 100644 index 0000000..ba17dcf --- /dev/null +++ b/src/polystore/bioformats_storage.py @@ -0,0 +1,258 @@ +"""Structured-reference backend for Bio-Formats-backed virtual workspaces.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from fnmatch import fnmatch +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from .base import PicklableBackend, ReadOnlyBackend +from .constants import Backend +from .exceptions import StorageResolutionError +from .metadata_writer import get_metadata_path + + +@dataclass(frozen=True, slots=True) +class BioFormatsPlaneRef: + """Serializable reference to one Bio-Formats image plane.""" + + source_path: Path + series_index: int + plane_index: int + c: int + z: int + t: int + reader: str = "bioformats" + + @classmethod + def from_mapping( + cls, + payload: Dict[str, Any], + *, + plate_root: Path, + ) -> "BioFormatsPlaneRef": + source_path = Path(payload["source_path"]) + if not source_path.is_absolute(): + source_path = plate_root / source_path + return cls( + source_path=source_path, + series_index=int(payload.get("series_index", 0)), + plane_index=int(payload["plane_index"]), + c=int(payload["c"]), + z=int(payload["z"]), + t=int(payload["t"]), + reader=str(payload.get("reader", "bioformats")), + ) + + +class BioFormatsStorageBackend(ReadOnlyBackend, PicklableBackend): + """Load normalized virtual source keys from structured Bio-Formats refs.""" + + _backend_type = Backend.BIOFORMATS.value + + def __init__(self, plate_root: Path | None = None): + self.plate_root = None if plate_root is None else Path(plate_root) + self._mapping_cache: Optional[Dict[str, Dict[str, Any]]] = None + self._cache_mtime: Optional[float] = None + + def get_connection_params(self) -> Optional[Dict[str, Any]]: + if self.plate_root is None: + return None + return {"plate_root": str(self.plate_root)} + + def set_connection_params(self, params: Optional[Dict[str, Any]]) -> None: + if not params: + self.plate_root = None + self._mapping_cache = None + self._cache_mtime = None + return + self.plate_root = Path(params["plate_root"]) + self._mapping_cache = None + self._cache_mtime = None + + def load(self, file_path: Union[str, Path], **kwargs) -> Any: + ref = self._resolve_ref(file_path) + if ref.reader == "npy": + return _load_npy_plane(ref) + if ref.reader != "bioformats": + raise BioFormatsReaderUnavailableError( + f"Unsupported Bio-Formats reader {ref.reader!r}." + ) + from .bioformats_java import load_bioformats_plane + + return load_bioformats_plane( + source_path=ref.source_path, + series_index=ref.series_index, + plane_index=ref.plane_index, + ) + + def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: + return [self.load(file_path, **kwargs) for file_path in file_paths] + + def list_files( + self, + directory: Union[str, Path], + pattern: Optional[str] = None, + extensions: Optional[Set[str]] = None, + recursive: bool = False, + **kwargs, + ) -> List[str]: + plate_root = self._require_plate_root() + relative_dir = self.relative_to_root(directory) + normalized_dir = _normalize_relative_path(str(relative_dir)) + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) + results = [] + for virtual_path in self._load_mapping().keys(): + if not _virtual_path_in_directory( + virtual_path, + normalized_dir=normalized_dir, + recursive=recursive, + ): + continue + path = Path(virtual_path) + if lowercase_extensions is not None and path.suffix.lower() not in lowercase_extensions: + continue + if pattern is not None and not fnmatch(path.name, pattern): + continue + results.append(str(plate_root / virtual_path)) + return results + + def exists(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + if not relative: + return True + mapping = self._load_mapping() + return relative in mapping or any( + virtual_path.startswith(relative + "/") + for virtual_path in mapping + ) + + def is_file(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return relative in self._load_mapping() + + def is_dir(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return not relative or any( + virtual_path.startswith(relative + "/") + for virtual_path in self._load_mapping() + ) + + def list_dir(self, path: Union[str, Path]) -> List[str]: + relative = self.normalized_relative_path(path) + prefix = "" if not relative else relative + "/" + names = set() + for virtual_path in self._load_mapping(): + if not virtual_path.startswith(prefix): + continue + remainder = virtual_path[len(prefix):] + if remainder: + names.add(remainder.split("/", 1)[0]) + return sorted(names) + + def _resolve_ref(self, path: Union[str, Path]) -> BioFormatsPlaneRef: + plate_root = self._require_plate_root() + relative_path = self.normalized_relative_path(path) + mapping = self._load_mapping() + try: + payload = mapping[relative_path] + except KeyError as exc: + raise StorageResolutionError( + f"Path not in Bio-Formats workspace mapping: {relative_path}" + ) from exc + if not isinstance(payload, dict): + raise StorageResolutionError( + f"Bio-Formats workspace mapping for {relative_path!r} is not structured." + ) + return BioFormatsPlaneRef.from_mapping(payload, plate_root=plate_root) + + def _load_mapping(self) -> Dict[str, Dict[str, Any]]: + plate_root = self._require_plate_root() + metadata_path = get_metadata_path(plate_root) + if not metadata_path.exists(): + raise FileNotFoundError(f"Metadata not found: {metadata_path}") + current_mtime = metadata_path.stat().st_mtime + if self._mapping_cache is not None and self._cache_mtime == current_mtime: + return self._mapping_cache + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + combined_mapping: Dict[str, Dict[str, Any]] = {} + for subdirectory in metadata.get("subdirectories", {}).values(): + if Backend.BIOFORMATS.value not in subdirectory.get("available_backends", {}): + continue + workspace_mapping = subdirectory.get("workspace_mapping", {}) + for virtual_path, ref_payload in workspace_mapping.items(): + if isinstance(ref_payload, dict): + combined_mapping[_normalize_relative_path(str(virtual_path))] = ref_payload + if not combined_mapping: + raise ValueError(f"No Bio-Formats workspace_mapping in {metadata_path}") + self._mapping_cache = combined_mapping + self._cache_mtime = current_mtime + return combined_mapping + + def _require_plate_root(self) -> Path: + if self.plate_root is None: + raise StorageResolutionError("BioFormatsStorageBackend requires plate_root.") + return self.plate_root + + def relative_to_root(self, path: Union[str, Path]) -> Path: + plate_root = self._require_plate_root() + path_obj = Path(path) + if not path_obj.is_absolute(): + return path_obj + try: + return path_obj.relative_to(plate_root) + except ValueError as exc: + raise StorageResolutionError( + f"Path {path_obj} is outside Bio-Formats plate root {plate_root}." + ) from exc + + def normalized_relative_path(self, path: Union[str, Path]) -> str: + return _normalize_relative_path(str(self.relative_to_root(path))) + + +class BioFormatsReaderUnavailableError(RuntimeError): + """Raised when a production Bio-Formats reader has not been configured.""" + + +def _load_npy_plane(ref: BioFormatsPlaneRef) -> Any: + import numpy as np + + array = np.load(ref.source_path) + if array.ndim == 2: + return array + if array.ndim == 5: + return array[ref.t - 1, ref.z - 1, ref.c - 1] + if array.ndim == 3: + return array[ref.plane_index] + raise ValueError( + f"Unsupported npy Bio-Formats fixture shape {array.shape} for {ref.source_path}." + ) + + +def _normalize_relative_path(path: str) -> str: + normalized = path.replace("\\", "/") + return "" if normalized == "." else normalized + + +def _virtual_path_in_directory( + virtual_path: str, + *, + normalized_dir: str, + recursive: bool, +) -> bool: + if recursive: + return not normalized_dir or virtual_path.startswith(normalized_dir + "/") + return _normalize_relative_path(str(Path(virtual_path).parent)) == normalized_dir diff --git a/src/polystore/constants.py b/src/polystore/constants.py index 3a27cfb..0103236 100644 --- a/src/polystore/constants.py +++ b/src/polystore/constants.py @@ -19,6 +19,7 @@ class Backend(Enum): FIJI_STREAM = "fiji_stream" OMERO_LOCAL = "omero_local" VIRTUAL_WORKSPACE = "virtual_workspace" + BIOFORMATS = "bioformats" class TransportMode(Enum): From db409a9838f0d9a299ff3ab0f47f3c48fd272e41 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Fri, 12 Jun 2026 12:59:27 -0400 Subject: [PATCH 09/11] Refine viewer transport and virtual workspace refs Centralize Napari/Fiji viewer stream kwargs and REQ/REP ack policy in a shared transport helper, add per-path component metadata support for streamed batches, simplify Napari batch dispatch payloads, and let virtual workspaces resolve structured source refs including TIFF plane refs.\n\nValidation: PYTHONPATH=src:../arraybridge/src:../metaclass-registry/src uv run --no-sync pytest tests/test_streaming_metadata.py -q --- src/polystore/fiji_stream.py | 122 ++++++++++---- src/polystore/napari_stream.py | 93 +++++++---- src/polystore/streaming/_streaming_backend.py | 45 +++++- .../napari/napari_batch_processor.py | 46 +++--- src/polystore/streaming/viewer_transport.py | 123 ++++++++++++++ src/polystore/virtual_workspace.py | 150 ++++++++++++++++-- tests/test_streaming_metadata.py | 31 ++++ 7 files changed, 511 insertions(+), 99 deletions(-) create mode 100644 src/polystore/streaming/viewer_transport.py diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 08132bc..b79faf7 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -22,9 +22,57 @@ from .streaming_constants import StreamingDataType from .streaming import StreamingBackend from .roi_converters import FijiROIConverter +from .streaming.viewer_transport import ( + ViewerAckPolicy, + ViewerStreamKwargs, + ViewerTransportConfigAuthority, + ViewerTransportDefaults, +) from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode logger = logging.getLogger(__name__) +FIJI_TRANSPORT_DEFAULTS = ViewerTransportDefaults() +FIJI_ACK_POLICY = ViewerAckPolicy( + viewer_name="Fiji", + timeout_ms=FIJI_TRANSPORT_DEFAULTS.ack_timeout_ms, +) + + +class FijiDisplayPayload: + """Display payload projection for Fiji stream messages.""" + + @staticmethod + def auto_contrast_value(display_config) -> bool: + if not hasattr(display_config, "auto_contrast"): + return True + return display_config.auto_contrast + + @classmethod + def from_display_config(cls, display_config) -> dict[str, Any]: + return { + "lut": display_config.get_lut_name(), + "auto_contrast": cls.auto_contrast_value(display_config), + } + + +class FijiMessageMetadata: + """Typed access to optional Fiji message metadata.""" + + @staticmethod + def component_names_metadata(message: dict) -> dict: + if "component_names_metadata" in message: + return message["component_names_metadata"] + return {} + + +class FijiRoiPayload: + """ROI payload inspection for Fiji logging.""" + + @staticmethod + def count(item_data: dict) -> int: + if "rois" not in item_data: + raise ValueError("Fiji ROI payload missing required 'rois' field") + return len(item_data["rois"]) class FijiStreamingBackend(StreamingBackend): @@ -60,7 +108,9 @@ def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type) logger.info(f"🔍 FIJI BACKEND: Preparing ROI data for {file_path}") item_data = self._prepare_rois_data(data, file_path) data_type_value = "rois" - logger.info(f"🔍 FIJI BACKEND: ROI data prepared: {len(item_data.get('rois', []))} ROIs") + logger.info( + f"🔍 FIJI BACKEND: ROI data prepared: {FijiRoiPayload.count(item_data)} ROIs" + ) else: logger.info(f"🔍 FIJI BACKEND: Preparing image data for {file_path}") item_data = self._create_shared_memory(data, file_path) @@ -77,36 +127,30 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * if not data_list: return - # Extract kwargs using generic polymorphic names - host = kwargs.get('host', 'localhost') - port = kwargs['port'] - transport_mode = kwargs['transport_mode'] - transport_config = kwargs.get('transport_config') - display_config = kwargs['display_config'] - microscope_handler = kwargs['microscope_handler'] - source = kwargs.get('source', 'unknown_source') # Pre-built source value - images_dir = kwargs.get('images_dir') # Source image subdirectory for ROI mapping - plate_path = kwargs.get('plate_path') - component_metadata = kwargs.get('component_metadata') - logger.info(f"🏷️ FIJI BACKEND: plate_path = {plate_path}") - logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {microscope_handler}") - display_payload_extra = { - "lut": display_config.get_lut_name(), - "auto_contrast": display_config.auto_contrast if hasattr(display_config, "auto_contrast") else True, - } + stream_request = ViewerStreamKwargs.from_kwargs( + kwargs, + FIJI_TRANSPORT_DEFAULTS, + include_images_dir=True, + ) + logger.info(f"🏷️ FIJI BACKEND: plate_path = {stream_request.plate_path}") + logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {stream_request.microscope_handler}") + display_payload_extra = FijiDisplayPayload.from_display_config( + stream_request.display_config + ) message_extra = { - "images_dir": images_dir, + "images_dir": stream_request.images_dir, } message, batch_images, image_ids = self._build_batch_message( data_list, file_paths, - microscope_handler, - source, - display_config, + stream_request.microscope_handler, + stream_request.source, + stream_request.display_config, self._prepare_batch_item, - plate_path=plate_path, - component_metadata=component_metadata, + plate_path=stream_request.plate_path, + component_metadata=stream_request.component_metadata, + component_metadata_by_path=stream_request.component_metadata_by_path, component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, display_payload_extra=display_payload_extra, message_extra=message_extra, @@ -114,7 +158,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * logger.info( "🏷️ FIJI BACKEND: Final component_names_metadata: %s", - message.get("component_names_metadata", {}), + FijiMessageMetadata.component_names_metadata(message), ) for item in batch_images: @@ -128,20 +172,23 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * # Register sent images with queue tracker BEFORE sending # This prevents race condition with IPC mode where acks arrive before registration self._register_with_queue_tracker( - port, + stream_request.port, image_ids, - transport_mode=transport_mode, - transport_config=transport_config, + transport_mode=stream_request.transport_mode, + transport_config=stream_request.transport_config, ) # Create FRESH REQ socket for each send - REQ sockets cannot be reused # This prevents the "Operation cannot be accomplished in current state" error # when multiple streams happen concurrently - transport_config = transport_config or self._transport_config + transport_config = ViewerTransportConfigAuthority.resolve( + stream_request.transport_config, + self._transport_config, + ) url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), + stream_request.port, + host=stream_request.host, + mode=coerce_transport_mode(stream_request.transport_mode), config=transport_config, ) @@ -149,6 +196,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * self._context = zmq.Context() socket = self._context.socket(zmq.REQ) + FIJI_ACK_POLICY.apply_socket_options(socket) socket.connect(url) time.sleep(0.1) # Brief delay for connection to establish @@ -156,13 +204,17 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * # Send with REQ socket (BLOCKING - worker waits for Fiji to acknowledge) # Worker blocks until Fiji receives, copies data from shared memory, and sends ack # This guarantees no messages are lost and shared memory is only closed after Fiji is done - logger.info(f"📤 FIJI BACKEND: Sending batch of {len(batch_images)} images to Fiji on port {port} (REQ/REP - blocking until ack)") + logger.info(f"📤 FIJI BACKEND: Sending batch of {len(batch_images)} images to Fiji on port {stream_request.port} (REQ/REP - blocking until ack)") socket.send_json(message) # Blocking send # Wait for acknowledgment from Fiji (REP socket) # Fiji will only reply after it has copied all data from shared memory - ack_response = socket.recv_json() - logger.info(f"✅ FIJI BACKEND: Received ack from Fiji: {ack_response.get('status', 'unknown')}") + ack_response = FIJI_ACK_POLICY.receive( + socket, + lambda: self._cleanup_shared_memory_blocks(batch_images, unlink=True), + port=stream_request.port, + ) + logger.info(f"✅ FIJI BACKEND: Received ack from Fiji: {FIJI_ACK_POLICY.status(ack_response)}") finally: # Always close the socket - never reuse REQ sockets diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index d762cd6..b5797a4 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -22,9 +22,40 @@ from .constants import Backend, TransportMode from .streaming import StreamingBackend from .roi_converters import NapariROIConverter +from .streaming.viewer_transport import ( + ViewerAckPolicy, + ViewerStreamKwargs, + ViewerTransportConfigAuthority, + ViewerTransportDefaults, +) from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode logger = logging.getLogger(__name__) +NAPARI_TRANSPORT_DEFAULTS = ViewerTransportDefaults() +NAPARI_ACK_POLICY = ViewerAckPolicy( + viewer_name="Napari", + timeout_ms=NAPARI_TRANSPORT_DEFAULTS.ack_timeout_ms, +) + + +class NapariDisplayPayload: + """Display payload projection for Napari stream messages.""" + + @staticmethod + def variable_size_handling_value(display_config): + if not hasattr(display_config, "variable_size_handling"): + return None + variable_size_handling = display_config.variable_size_handling + if variable_size_handling is None: + return None + return variable_size_handling.value + + @classmethod + def from_display_config(cls, display_config) -> dict[str, Any]: + return { + "colormap": display_config.get_colormap_name(), + "variable_size_handling": cls.variable_size_handling_value(display_config), + } class NapariStreamingBackend(StreamingBackend): @@ -75,52 +106,47 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * if not data_list: return - # Extract kwargs using generic polymorphic names - host = kwargs.get('host', 'localhost') - port = kwargs['port'] - transport_mode = kwargs['transport_mode'] - transport_config = kwargs.get('transport_config') - display_config = kwargs['display_config'] - microscope_handler = kwargs['microscope_handler'] - source = kwargs.get('source', 'unknown_source') # Pre-built source value - plate_path = kwargs.get('plate_path') - component_metadata = kwargs.get('component_metadata') - display_payload_extra = { - "colormap": display_config.get_colormap_name(), - "variable_size_handling": display_config.variable_size_handling.value - if hasattr(display_config, "variable_size_handling") and display_config.variable_size_handling - else None, - } + stream_request = ViewerStreamKwargs.from_kwargs( + kwargs, + NAPARI_TRANSPORT_DEFAULTS, + ) + display_payload_extra = NapariDisplayPayload.from_display_config( + stream_request.display_config + ) message, batch_images, image_ids = self._build_batch_message( data_list, file_paths, - microscope_handler, - source, - display_config, + stream_request.microscope_handler, + stream_request.source, + stream_request.display_config, self._prepare_batch_item, - plate_path=plate_path, - component_metadata=component_metadata, + plate_path=stream_request.plate_path, + component_metadata=stream_request.component_metadata, + component_metadata_by_path=stream_request.component_metadata_by_path, display_payload_extra=display_payload_extra, ) # Register sent images with queue tracker BEFORE sending # This prevents race condition with IPC mode where acks arrive before registration self._register_with_queue_tracker( - port, + stream_request.port, image_ids, - transport_mode=transport_mode, - transport_config=transport_config, + transport_mode=stream_request.transport_mode, + transport_config=stream_request.transport_config, ) # Create FRESH REQ socket for each send - REQ sockets cannot be reused # This prevents the "Operation cannot be accomplished in current state" error # when multiple streams happen concurrently - transport_config = transport_config or self._transport_config + transport_config = ViewerTransportConfigAuthority.resolve( + stream_request.transport_config, + self._transport_config, + ) url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), + stream_request.port, + host=stream_request.host, + mode=coerce_transport_mode(stream_request.transport_mode), config=transport_config, ) @@ -128,6 +154,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * self._context = zmq.Context() socket = self._context.socket(zmq.REQ) + NAPARI_ACK_POLICY.apply_socket_options(socket) socket.connect(url) time.sleep(0.1) # Brief delay for connection to establish @@ -135,13 +162,17 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * # Send with REQ socket (BLOCKING - worker waits for Napari to acknowledge) # Worker blocks until Napari receives, copies data from shared memory, and sends ack # This guarantees no messages are lost and shared memory is only closed after Napari is done - logger.info(f"📤 NAPARI BACKEND: Sending batch of {len(batch_images)} images to Napari on port {port} (REQ/REP - blocking until ack)") + logger.info(f"📤 NAPARI BACKEND: Sending batch of {len(batch_images)} images to Napari on port {stream_request.port} (REQ/REP - blocking until ack)") socket.send_json(message) # Blocking send # Wait for acknowledgment from Napari (REP socket) # Napari will only reply after it has copied all data from shared memory - ack_response = socket.recv_json() - logger.info(f"✅ NAPARI BACKEND: Received ack from Napari: {ack_response.get('status', 'unknown')}") + ack_response = NAPARI_ACK_POLICY.receive( + socket, + lambda: self._cleanup_shared_memory_blocks(batch_images, unlink=True), + port=stream_request.port, + ) + logger.info(f"✅ NAPARI BACKEND: Received ack from Napari: {NAPARI_ACK_POLICY.status(ack_response)}") finally: # Always close the socket - never reuse REQ sockets diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 1f3a70a..51a1407 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -9,6 +9,7 @@ import os import time import uuid +from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, List, Mapping, Set, Union @@ -28,6 +29,11 @@ PrepareStreamingItem = Callable[[Any, Union[str, Path], Any], tuple[dict, str]] +ComponentMetadataByPath = ( + Mapping[str, Mapping[str, Any] | None] + | Sequence[Mapping[str, Any] | None] + | None +) @dataclass(frozen=True) @@ -59,6 +65,7 @@ class StreamingBatchRequest: source: str prepare_item: PrepareStreamingItem component_metadata: Mapping[str, Any] | None = None + component_metadata_by_path: ComponentMetadataByPath = None class StreamingPayloadMemoryAuthority: @@ -189,6 +196,30 @@ def _parse_component_metadata( ) return StreamingComponentMetadata(parsed_metadata, source).to_payload() + @staticmethod + def _component_metadata_for_item( + *, + file_path: Union[str, Path], + index: int, + component_metadata: Mapping[str, Any] | None, + component_metadata_by_path: ComponentMetadataByPath, + ) -> Mapping[str, Any] | None: + """Return explicit component metadata for one batch item when provided.""" + if component_metadata_by_path is None: + return component_metadata + + if isinstance(component_metadata_by_path, Mapping): + path = Path(file_path) + for key in (str(file_path), path.as_posix(), path.name): + if key in component_metadata_by_path: + return component_metadata_by_path[key] + return component_metadata + + if index < len(component_metadata_by_path): + return component_metadata_by_path[index] + + return component_metadata + def _detect_data_type(self, data: Any): """ Detect if data is ROI (shapes/points) or image (common for all streaming backends). @@ -334,16 +365,24 @@ def _prepare_batch_items( batch_images = [] image_ids = [] - for data, file_path in zip(request.data_list, request.file_paths): + for index, (data, file_path) in enumerate( + zip(request.data_list, request.file_paths) + ): image_id = str(uuid.uuid4()) image_ids.append(image_id) data_type = self._detect_data_type(data) + explicit_component_metadata = self._component_metadata_for_item( + file_path=file_path, + index=index, + component_metadata=request.component_metadata, + component_metadata_by_path=request.component_metadata_by_path, + ) component_metadata = self._parse_component_metadata( file_path, request.microscope_handler, request.source, - request.component_metadata, + explicit_component_metadata, ) item_data, data_type_value = request.prepare_item(data, file_path, data_type) @@ -369,6 +408,7 @@ def _build_batch_message( plate_path: Union[str, Path, None] = None, component_names_kwargs: dict | None = None, component_metadata: Mapping[str, Any] | None = None, + component_metadata_by_path: ComponentMetadataByPath = None, display_payload_extra: dict | None = None, message_extra: dict | None = None, ) -> tuple[dict, list[dict], list[str]]: @@ -383,6 +423,7 @@ def _build_batch_message( source=source, prepare_item=prepare_item, component_metadata=component_metadata, + component_metadata_by_path=component_metadata_by_path, ) ) diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index ad485af..4abbc90 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,9 +1,28 @@ import logging +from dataclasses import dataclass from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class NapariBatchDisplayRequest: + """Nominal request for one debounced Napari display update.""" + + layer_key: str + items: List[Dict[str, Any]] + display_payload: object + component_names_metadata: Dict[str, Any] + + def dispatch_to(self, napari_server) -> None: + napari_server.display_layer_batch( + layer_key=self.layer_key, + items=self.items, + display_payload=self.display_payload, + component_names_metadata=self.component_names_metadata, + ) + + class NapariBatchProcessor: """ Batch processor for Napari viewer display operations. @@ -43,7 +62,7 @@ def add_items( self, layer_key: str, items: List[Dict[str, Any]], - display_config: Dict[str, Any], + display_payload: object, component_names_metadata: Dict[str, Any], ): """ @@ -52,17 +71,15 @@ def add_items( Args: layer_key: Unique identifier for the layer items: List of items to add (images or ROIs) - display_config: Display configuration dict + display_payload: Viewer-owned display configuration object component_names_metadata: Component name mappings for dimension labels """ - self._process_batch( - items, - { - "display_config": display_config, - "component_names_metadata": component_names_metadata, - "layer_key": layer_key, - }, - ) + NapariBatchDisplayRequest( + layer_key=layer_key, + items=items, + display_payload=display_payload, + component_names_metadata=component_names_metadata, + ).dispatch_to(self.napari_server) logger.debug( "NapariBatchProcessor: Added %d items to batch for layer '%s'", len(items), @@ -71,12 +88,3 @@ def add_items( def flush(self) -> None: """Compatibility no-op; OpenHCS owns the Qt-thread debounce timer.""" - - def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> None: - """Process callback used by shared debounced batch engine.""" - self.napari_server.display_layer_batch( - layer_key=context["layer_key"], - items=items, - display_config=context["display_config"], - component_names_metadata=context["component_names_metadata"], - ) diff --git a/src/polystore/streaming/viewer_transport.py b/src/polystore/streaming/viewer_transport.py new file mode 100644 index 0000000..b906190 --- /dev/null +++ b/src/polystore/streaming/viewer_transport.py @@ -0,0 +1,123 @@ +"""Nominal transport helpers for blocking viewer stream backends.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping + +import zmq + + +@dataclass(frozen=True) +class ViewerTransportDefaults: + """Declared transport defaults shared by viewer streaming backends.""" + + host: str = "localhost" + source: str = "unknown_source" + ack_timeout_ms: int = 30_000 + + +@dataclass(frozen=True) +class ViewerStreamKwargs: + """Typed view of backend kwargs at the viewer streaming boundary.""" + + host: str + port: int + transport_mode: Any + transport_config: Any + display_config: Any + microscope_handler: Any + source: str + plate_path: Any + component_metadata: Any + component_metadata_by_path: Any + images_dir: Any = None + + @classmethod + def from_kwargs( + cls, + kwargs: Mapping[str, Any], + defaults: ViewerTransportDefaults, + *, + include_images_dir: bool = False, + ) -> "ViewerStreamKwargs": + return cls( + host=ViewerKwargAuthority.value_or_default(kwargs, "host", defaults.host), + port=ViewerKwargAuthority.required(kwargs, "port"), + transport_mode=ViewerKwargAuthority.required(kwargs, "transport_mode"), + transport_config=ViewerKwargAuthority.optional(kwargs, "transport_config"), + display_config=ViewerKwargAuthority.required(kwargs, "display_config"), + microscope_handler=ViewerKwargAuthority.required(kwargs, "microscope_handler"), + source=ViewerKwargAuthority.value_or_default(kwargs, "source", defaults.source), + plate_path=ViewerKwargAuthority.optional(kwargs, "plate_path"), + component_metadata=ViewerKwargAuthority.optional(kwargs, "component_metadata"), + component_metadata_by_path=ViewerKwargAuthority.optional( + kwargs, "component_metadata_by_path" + ), + images_dir=( + ViewerKwargAuthority.optional(kwargs, "images_dir") + if include_images_dir + else None + ), + ) + + +class ViewerKwargAuthority: + """Named access policy for viewer backend kwargs.""" + + @staticmethod + def required(kwargs: Mapping[str, Any], name: str) -> Any: + if name not in kwargs: + raise ValueError(f"Viewer streaming kwargs missing required field '{name}'") + return kwargs[name] + + @staticmethod + def optional(kwargs: Mapping[str, Any], name: str) -> Any: + if name in kwargs: + return kwargs[name] + return None + + @staticmethod + def value_or_default(kwargs: Mapping[str, Any], name: str, default: Any) -> Any: + if name in kwargs: + return kwargs[name] + return default + + +class ViewerTransportConfigAuthority: + """Resolve the concrete transport config without implicit truthiness fallback.""" + + @staticmethod + def resolve(transport_config: Any, default_transport_config: Any) -> Any: + if transport_config is None: + return default_transport_config + return transport_config + + +@dataclass(frozen=True) +class ViewerAckPolicy: + """REQ/REP ack contract for a streaming viewer.""" + + viewer_name: str + timeout_ms: int + + def apply_socket_options(self, socket: zmq.Socket) -> None: + socket.setsockopt(zmq.LINGER, 0) + socket.setsockopt(zmq.SNDTIMEO, self.timeout_ms) + socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) + + def receive(self, socket: zmq.Socket, cleanup, *, port: int) -> dict[str, Any]: + try: + return socket.recv_json() + except zmq.Again as exc: + cleanup() + raise TimeoutError( + f"Timed out waiting {self.timeout_ms}ms for {self.viewer_name} ack on port {port}" + ) from exc + + def status(self, ack_response: Mapping[str, Any]) -> str: + if "status" not in ack_response: + raise ValueError( + f"{self.viewer_name} ack response missing required 'status': {ack_response}" + ) + return ack_response["status"] diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index bec8be5..e47c657 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -3,17 +3,133 @@ import logging import json from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, List, Mapping, Optional, Set, Union from fnmatch import fnmatch from .disk import DiskStorageBackend from .metadata_writer import get_metadata_path from .exceptions import StorageResolutionError from .base import ReadOnlyBackend +from .constants import Backend +from .registry import AutoRegisterMeta logger = logging.getLogger(__name__) +class VirtualWorkspaceSourceRefResolver(ABC, metaclass=AutoRegisterMeta): + """Nominal loader family for virtual-workspace source references.""" + + __registry_key__ = "resolver_key" + __skip_if_no_key__ = True + resolver_key: ClassVar[str | None] = None + priority: ClassVar[int] + + @classmethod + def for_ref(cls, source_ref: Any) -> "VirtualWorkspaceSourceRefResolver": + for resolver_type in sorted( + cls.__registry__.values(), + key=lambda registered_type: registered_type.priority, + ): + resolver = resolver_type() + if resolver.accepts(source_ref): + return resolver + raise StorageResolutionError( + f"Unsupported virtual workspace source reference: {source_ref!r}" + ) + + @abstractmethod + def accepts(self, source_ref: Any) -> bool: + """Return whether this resolver owns the reference shape.""" + + @abstractmethod + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + """Return the concrete source path for existence and diagnostics.""" + + @abstractmethod + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + """Load the payload addressed by this source reference.""" + + +class PathSourceRefResolver(VirtualWorkspaceSourceRefResolver): + """Resolve legacy string path mappings.""" + + resolver_key = "path" + priority = 100 + + def accepts(self, source_ref: Any) -> bool: + return isinstance(source_ref, (str, Path)) + + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + path = Path(source_ref) + return path if path.is_absolute() else plate_root / path + + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + return disk_backend.load(self.source_path(plate_root, source_ref), **kwargs) + + +class DiskSourceRefResolver(VirtualWorkspaceSourceRefResolver): + """Resolve structured disk refs, including single-plane TIFF pages.""" + + resolver_key = "disk" + priority = 10 + + def accepts(self, source_ref: Any) -> bool: + return ( + isinstance(source_ref, Mapping) + and source_ref.get("backend", Backend.DISK.value) == Backend.DISK.value + and isinstance(source_ref.get("source_path"), (str, Path)) + ) + + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + path = Path(source_ref["source_path"]) + return path if path.is_absolute() else plate_root / path + + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + payload = disk_backend.load(self.source_path(plate_root, source_ref), **kwargs) + plane_index = source_ref.get("plane_index") + if plane_index is None: + return payload + return _payload_plane(payload, int(plane_index), source_ref) + + +def _payload_plane(payload: Any, plane_index: int, source_ref: Mapping[str, Any]) -> Any: + if not hasattr(payload, "ndim") or not hasattr(payload, "shape"): + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but the loaded " + f"payload has no array shape." + ) + if payload.ndim < 3: + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but loaded " + f"payload shape {payload.shape!r} is not a stack." + ) + if plane_index < 0 or plane_index >= payload.shape[0]: + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but loaded " + f"payload shape is {payload.shape!r}." + ) + return payload[plane_index] + + class VirtualWorkspaceBackend(ReadOnlyBackend): """ Read-only path translation layer for virtual workspace. @@ -53,7 +169,7 @@ def __init__(self, plate_root: Path): """ self.plate_root = Path(plate_root) self.disk_backend = DiskStorageBackend() - self._mapping_cache: Optional[Dict[str, str]] = None + self._mapping_cache: Optional[Dict[str, Any]] = None self._cache_mtime: Optional[float] = None # Load mapping eagerly - fail loud if metadata missing @@ -76,7 +192,7 @@ def _normalize_relative_path(path_str: str) -> str: normalized = path_str.replace('\\', '/') return '' if normalized == '.' else normalized - def _load_mapping(self) -> Dict[str, str]: + def _load_mapping(self) -> Dict[str, Any]: """ Load workspace_mapping from metadata with mtime-based caching. @@ -122,7 +238,7 @@ def _load_mapping(self) -> Dict[str, str]: logger.info(f"Loaded {len(combined_mapping)} mappings for {self.plate_root}") return combined_mapping - def _resolve_path(self, path: Union[str, Path]) -> str: + def _resolve_ref(self, path: Union[str, Path]) -> Any: """ Resolve virtual path to real plate path using plate-relative mapping. @@ -163,20 +279,30 @@ def _resolve_path(self, path: Union[str, Path]) -> str: f"This path must be accessed through the virtual workspace mapping." ) - real_relative = self._mapping_cache[relative_str] - real_absolute = self.plate_root / real_relative - logger.debug(f"Resolved virtual → real: {relative_str} → {real_relative}") - return str(real_absolute) + source_ref = self._mapping_cache[relative_str] + logger.debug("Resolved virtual source ref: %s -> %r", relative_str, source_ref) + return source_ref + + def _resolve_path(self, path: Union[str, Path]) -> str: + """Resolve a virtual path to the concrete source path for diagnostics.""" + source_ref = self._resolve_ref(path) + resolver = VirtualWorkspaceSourceRefResolver.for_ref(source_ref) + return str(resolver.source_path(self.plate_root, source_ref)) def load(self, file_path: Union[str, Path], **kwargs) -> Any: """Load file from virtual workspace.""" - real_path = self._resolve_path(file_path) - return self.disk_backend.load(real_path, **kwargs) + source_ref = self._resolve_ref(file_path) + resolver = VirtualWorkspaceSourceRefResolver.for_ref(source_ref) + return resolver.load( + self.disk_backend, + self.plate_root, + source_ref, + **kwargs, + ) def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: """Load multiple files from virtual workspace.""" - real_paths = [self._resolve_path(fp) for fp in file_paths] - return self.disk_backend.load_batch(real_paths, **kwargs) + return [self.load(file_path, **kwargs) for file_path in file_paths] def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, extensions: Optional[Set[str]] = None, recursive: bool = False, diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py index 95d3b00..767a772 100644 --- a/tests/test_streaming_metadata.py +++ b/tests/test_streaming_metadata.py @@ -46,6 +46,37 @@ def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: ) +def test_streaming_batch_items_accept_per_path_component_metadata() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + batch_images, _image_ids = backend._prepare_batch_items( + StreamingBatchRequest( + data_list=[object()], + file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + microscope_handler=microscope_handler, + source="IdentifyPrimaryObjects", + prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + component_metadata_by_path={ + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip": { + "well": "A01", + "site": 1, + "channel": 1, + }, + }, + ) + ) + + assert batch_images[0]["metadata"] == { + "well": "A01", + "site": 1, + "channel": 1, + "source": "IdentifyPrimaryObjects", + } + + def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( From c7d400b57639d5cc50666ceda8da9d2456fe34eb Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Tue, 16 Jun 2026 00:45:32 -0400 Subject: [PATCH 10/11] Add producer identity for streaming routes Replace ad hoc source metadata with StreamProducerIdentity so streamed images, labels, and artifacts carry a single producer contract. Route viewer windows and Napari layers from the producer identity while keeping display labels separate from hidden route keys. Update streaming receiver tests to cover identity payloads, stable grouping, and component-mode projection. --- src/polystore/fiji_stream.py | 9 +- src/polystore/napari_stream.py | 2 +- src/polystore/streaming/_streaming_backend.py | 26 +-- src/polystore/streaming/base.py | 1 - src/polystore/streaming/identity.py | 203 ++++++++++++++++++ src/polystore/streaming/receivers/__init__.py | 4 +- .../receivers/core/window_projection.py | 43 ++-- .../receivers/fiji/fiji_batch_processor.py | 2 +- .../streaming/receivers/napari/__init__.py | 4 +- .../streaming/receivers/napari/layer_key.py | 17 +- src/polystore/streaming/viewer_transport.py | 9 +- tests/test_streaming_metadata.py | 22 +- tests/test_streaming_receiver_core.py | 130 +++++++++-- 13 files changed, 379 insertions(+), 93 deletions(-) create mode 100644 src/polystore/streaming/identity.py diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index b79faf7..0a2f36f 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -145,7 +145,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * data_list, file_paths, stream_request.microscope_handler, - stream_request.source, + stream_request.producer_identity, stream_request.display_config, self._prepare_batch_item, plate_path=stream_request.plate_path, @@ -167,7 +167,12 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * # Log batch composition data_types = [item['data_type'] for item in batch_images] type_counts = {dt: data_types.count(dt) for dt in set(data_types)} - logger.info(f"📤 FIJI BACKEND: Sending batch message with {len(batch_images)} items to port {port}: {type_counts}") + logger.info( + "📤 FIJI BACKEND: Sending batch message with %d items to port %s: %s", + len(batch_images), + stream_request.port, + type_counts, + ) # Register sent images with queue tracker BEFORE sending # This prevents race condition with IPC mode where acks arrive before registration diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index b5797a4..f80348e 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -118,7 +118,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * data_list, file_paths, stream_request.microscope_handler, - stream_request.source, + stream_request.producer_identity, stream_request.display_config, self._prepare_batch_item, plate_path=stream_request.plate_path, diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 51a1407..8d82d75 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -22,6 +22,7 @@ from ..streaming_constants import StreamingDataType from ..roi import ROI, PointShape from ..zmq_config import POLYSTORE_ZMQ_CONFIG +from .identity import StreamProducerIdentity from zmqruntime.ack_listener import GlobalAckListener from zmqruntime.transport import coerce_transport_mode @@ -41,7 +42,6 @@ class StreamingComponentMetadata: """Message metadata for one streamed item.""" parsed_filename_metadata: Mapping[str, Any] - source: str def to_payload(self) -> dict[str, Any]: if isinstance(self.parsed_filename_metadata, Mapping): @@ -51,7 +51,6 @@ def to_payload(self) -> dict[str, Any]: "Streaming component metadata must be a mapping, " f"got {type(self.parsed_filename_metadata).__name__}." ) - metadata["source"] = self.source return metadata @@ -62,7 +61,7 @@ class StreamingBatchRequest: data_list: List[Any] file_paths: List[Union[str, Path]] microscope_handler: Any - source: str + producer_identity: StreamProducerIdentity prepare_item: PrepareStreamingItem component_metadata: Mapping[str, Any] | None = None component_metadata_by_path: ComponentMetadataByPath = None @@ -169,20 +168,9 @@ def _parse_component_metadata( self, file_path: Union[str, Path], microscope_handler, - source: str, component_metadata: Mapping[str, Any] | None = None, ) -> dict: - """ - Parse component metadata from filename (common for all streaming backends). - - Args: - file_path: Path to parse - microscope_handler: Handler with parser - source: Pre-built source value (step_name during execution, subdir when loading from disk) - - Returns: - Component metadata dict with source added - """ + """Parse real source-plane component metadata for one stream item.""" filename = os.path.basename(str(file_path)) parsed_metadata = ( component_metadata @@ -194,7 +182,7 @@ def _parse_component_metadata( "Streaming component metadata requires explicit component_metadata " f"or a parser-readable filename; got {filename!r}." ) - return StreamingComponentMetadata(parsed_metadata, source).to_payload() + return StreamingComponentMetadata(parsed_metadata).to_payload() @staticmethod def _component_metadata_for_item( @@ -381,7 +369,6 @@ def _prepare_batch_items( component_metadata = self._parse_component_metadata( file_path, request.microscope_handler, - request.source, explicit_component_metadata, ) item_data, data_type_value = request.prepare_item(data, file_path, data_type) @@ -391,6 +378,7 @@ def _prepare_batch_items( **item_data, "data_type": data_type_value, "metadata": component_metadata, + "producer_identity": request.producer_identity.to_payload(), "image_id": image_id, } ) @@ -402,7 +390,7 @@ def _build_batch_message( data_list: List[Any], file_paths: List[Union[str, Path]], microscope_handler, - source: str, + producer_identity: StreamProducerIdentity, display_config, prepare_item: PrepareStreamingItem, plate_path: Union[str, Path, None] = None, @@ -420,7 +408,7 @@ def _build_batch_message( data_list=data_list, file_paths=file_paths, microscope_handler=microscope_handler, - source=source, + producer_identity=producer_identity, prepare_item=prepare_item, component_metadata=component_metadata, component_metadata_by_path=component_metadata_by_path, diff --git a/src/polystore/streaming/base.py b/src/polystore/streaming/base.py index 5cbfec0..c0bfb70 100644 --- a/src/polystore/streaming/base.py +++ b/src/polystore/streaming/base.py @@ -27,7 +27,6 @@ class TypedData(Generic[T]): """ items: List[T] metadata: Dict[str, Any] - source: str class ComponentAccessor(ABC): diff --git a/src/polystore/streaming/identity.py b/src/polystore/streaming/identity.py new file mode 100644 index 0000000..3d7440f --- /dev/null +++ b/src/polystore/streaming/identity.py @@ -0,0 +1,203 @@ +"""Nominal stream identity records shared by viewer streaming backends.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, ClassVar, Mapping + + +@dataclass(frozen=True, slots=True) +class StreamProducerIdentity: + """Producer/output identity for one streamed viewer item.""" + + origin: str + output_kind: str + output_key: str + step_name: str | None = None + pipeline_position: int | None = None + step_scope_id: str | None = None + invocation_key: str | None = None + artifact_kind: str | None = None + PAYLOAD_FIELDS: ClassVar[tuple[str, ...]] = ( + "origin", + "output_kind", + "output_key", + "step_name", + "pipeline_position", + "step_scope_id", + "invocation_key", + "artifact_kind", + ) + + @classmethod + def pipeline_output( + cls, + *, + output_kind: str, + output_key: str, + step_name: str, + pipeline_position: int | None, + step_scope_id: str | None = None, + artifact_kind: str | None = None, + ) -> "StreamProducerIdentity": + """Build identity for one pipeline-produced stream output.""" + return cls( + origin="pipeline", + output_kind=output_kind, + output_key=output_key, + step_name=step_name, + pipeline_position=pipeline_position, + step_scope_id=step_scope_id, + artifact_kind=artifact_kind, + ) + + @classmethod + def manual(cls, output_key: str) -> "StreamProducerIdentity": + """Build identity for one manual viewer action.""" + return cls._fixed_origin_output( + origin="manual", + output_kind="manual", + output_key=output_key, + ) + + @classmethod + def direct(cls, output_key: str) -> "StreamProducerIdentity": + """Build identity for direct in-process display calls.""" + return cls._fixed_origin_output( + origin="direct", + output_kind="direct", + output_key=output_key, + ) + + @classmethod + def _fixed_origin_output( + cls, + *, + origin: str, + output_kind: str, + output_key: str, + ) -> "StreamProducerIdentity": + """Build identity variants whose origin and output kind match.""" + return cls( + origin=origin, + output_kind=output_kind, + output_key=output_key, + ) + + @classmethod + def from_payload(cls, payload: "StreamProducerIdentity | Mapping[str, Any]") -> "StreamProducerIdentity": + if isinstance(payload, cls): + return payload + if not isinstance(payload, Mapping): + raise TypeError( + "Stream producer identity must be a mapping or StreamProducerIdentity, " + f"got {type(payload).__name__}." + ) + missing = [ + field_name + for field_name in ("origin", "output_kind", "output_key") + if payload.get(field_name) in (None, "") + ] + if missing: + raise ValueError(f"Stream producer identity missing required fields: {missing}") + pipeline_position = payload.get("pipeline_position") + return cls( + origin=str(payload["origin"]), + output_kind=str(payload["output_kind"]), + output_key=str(payload["output_key"]), + step_name=_optional_str(payload.get("step_name")), + pipeline_position=( + None if pipeline_position is None else int(pipeline_position) + ), + step_scope_id=_optional_str(payload.get("step_scope_id")), + invocation_key=_optional_str(payload.get("invocation_key")), + artifact_kind=_optional_str(payload.get("artifact_kind")), + ) + + def to_payload(self) -> dict[str, Any]: + return { + field_name: getattr(self, field_name) + for field_name in self.PAYLOAD_FIELDS + } + + def route_parts(self) -> tuple[str, ...]: + parts = [ + f"origin_{self.origin}", + f"kind_{self.output_kind}", + f"out_{self.output_key}", + ] + if self.pipeline_position is not None: + parts.append(f"step_{self.pipeline_position}") + if self.step_scope_id: + parts.append(f"scope_{self.step_scope_id}") + if self.step_name: + parts.append(f"name_{self.step_name}") + if self.invocation_key: + parts.append(f"invocation_{self.invocation_key}") + if self.artifact_kind: + parts.append(f"artifact_{self.artifact_kind}") + return tuple(parts) + + +def _optional_str(value: Any) -> str | None: + if value is None: + return None + text = str(value) + return text or None + + +class StreamProducerDisplayNameAuthority: + """Own user-facing labels derived from stream producer identity.""" + + PIPELINE_DISPLAY_INDEX_BASE: ClassVar[int] = 1 + OUTPUT_KEY_OMITTING_KINDS: ClassVar[frozenset[str]] = frozenset( + {"main", "manual", "direct"} + ) + + @staticmethod + def producer_base(producer: StreamProducerIdentity) -> str: + if producer.step_name: + return producer.step_name + return producer.output_key + + @classmethod + def producer_label(cls, producer: StreamProducerIdentity) -> str: + base = cls.producer_base(producer) + if producer.pipeline_position is None: + return base + return f"{producer.pipeline_position + cls.PIPELINE_DISPLAY_INDEX_BASE}. {base}" + + @classmethod + def output_label(cls, producer: StreamProducerIdentity) -> str: + parts = [cls.producer_label(producer)] + if cls.includes_output_key(producer): + parts.append(producer.output_key) + return " ".join(part for part in parts if part) + + @classmethod + def disambiguation_label(cls, producer: StreamProducerIdentity) -> str: + if producer.pipeline_position is not None: + return f"step {producer.pipeline_position + cls.PIPELINE_DISPLAY_INDEX_BASE}" + return producer.output_key or producer.origin + + @classmethod + def includes_output_key(cls, producer: StreamProducerIdentity) -> bool: + if not producer.output_key: + return False + if producer.output_kind in cls.OUTPUT_KEY_OMITTING_KINDS: + return False + return producer.output_key != cls.producer_base(producer) + + +class StreamRouteKeyAuthority: + """Own stable key-token projection for viewer route keys.""" + + @staticmethod + def token(value: object) -> str: + return str(value).replace("/", "_").replace("\\", "_").replace(" ", "_") + + @classmethod + def join(cls, parts: tuple[object, ...] | list[object]) -> str: + if not parts: + raise ValueError("Cannot build a stream route key with no parts.") + return "_".join(cls.token(part) for part in parts) diff --git a/src/polystore/streaming/receivers/__init__.py b/src/polystore/streaming/receivers/__init__.py index b1876be..5b57d56 100644 --- a/src/polystore/streaming/receivers/__init__.py +++ b/src/polystore/streaming/receivers/__init__.py @@ -15,7 +15,7 @@ from polystore.streaming.receivers.napari import ( NapariBatchProcessor, normalize_component_layout, - build_layer_key, + build_route_key, ) __all__ = [ @@ -27,5 +27,5 @@ "FijiBatchProcessor", "NapariBatchProcessor", "normalize_component_layout", - "build_layer_key", + "build_route_key", ] diff --git a/src/polystore/streaming/receivers/core/window_projection.py b/src/polystore/streaming/receivers/core/window_projection.py index 4987960..6a99995 100644 --- a/src/polystore/streaming/receivers/core/window_projection.py +++ b/src/polystore/streaming/receivers/core/window_projection.py @@ -3,11 +3,13 @@ from __future__ import annotations from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable +from typing import Any - -WindowValueNormalizer = Callable[[str, Any, dict[str, Any], str | None], Any] +from polystore.streaming.identity import ( + StreamProducerDisplayNameAuthority, + StreamProducerIdentity, + StreamRouteKeyAuthority, +) @dataclass(frozen=True) @@ -22,33 +24,12 @@ class GroupedWindowItems: fixed_window_labels: dict[str, list[tuple[str, Any]]] -def _default_normalizer( - component_name: str, - value: Any, - item: dict[str, Any], - images_dir: str | None, -) -> Any: - """Normalize window component values for stable keying across payload types.""" - data_type = item.get("data_type") - if component_name == "source" and images_dir and data_type == "rois": - value_str = str(value) - if "_results" in value_str or "/" in value_str: - return Path(images_dir).name - return value - - def group_items_by_component_modes( items: list[dict[str, Any]], component_modes: dict[str, str], component_order: list[str], - *, - images_dir: str | None = None, - normalizer: WindowValueNormalizer | None = None, ) -> GroupedWindowItems: """Project items into window groups using declared component modes.""" - if normalizer is None: - normalizer = _default_normalizer - result: dict[str, list[str]] = { "window": [], "channel": [], @@ -69,17 +50,20 @@ def group_items_by_component_modes( for item in items: meta = item.get("metadata", {}) - key_parts: list[str] = [] - fixed_labels: list[tuple[str, Any]] = [] + producer = StreamProducerIdentity.from_payload(item.get("producer_identity")) + key_parts: list[str] = list(producer.route_parts()) + fixed_labels: list[tuple[str, Any]] = [ + ("producer", StreamProducerDisplayNameAuthority.output_label(producer)) + ] for comp in window_components: if comp not in meta: continue - value = normalizer(comp, meta[comp], item, images_dir) + value = meta[comp] key_parts.append(f"{comp}_{value}") fixed_labels.append((comp, value)) - window_key = "_".join(key_parts) if key_parts else "default_window" + window_key = StreamRouteKeyAuthority.join(key_parts) windows.setdefault(window_key, []).append(item) if window_key not in fixed_window_labels: fixed_window_labels[window_key] = fixed_labels @@ -92,4 +76,3 @@ def group_items_by_component_modes( windows=windows, fixed_window_labels=fixed_window_labels, ) - diff --git a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py index 60e0a3b..9b37366 100644 --- a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py +++ b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py @@ -65,7 +65,7 @@ def add_items( window_key: Unique identifier for the Fiji window items: List of items to add (images) display_config: Display configuration dict - images_dir: Source image subdirectory + images_dir: Artifact image directory context. component_names_metadata: Component name mappings for dimension labels """ context = { diff --git a/src/polystore/streaming/receivers/napari/__init__.py b/src/polystore/streaming/receivers/napari/__init__.py index 9ece1bf..6472c4c 100644 --- a/src/polystore/streaming/receivers/napari/__init__.py +++ b/src/polystore/streaming/receivers/napari/__init__.py @@ -3,7 +3,7 @@ from polystore.streaming.receivers.napari.napari_batch_processor import NapariBatchProcessor from polystore.streaming.receivers.napari.layer_key import ( normalize_component_layout, - build_layer_key, + build_route_key, ) -__all__ = ["NapariBatchProcessor", "normalize_component_layout", "build_layer_key"] +__all__ = ["NapariBatchProcessor", "normalize_component_layout", "build_route_key"] diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index 51b7d67..e29ab73 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -1,9 +1,10 @@ -"""Canonical napari layer-key construction from component metadata.""" +"""Canonical napari route-key construction.""" from __future__ import annotations from typing import Any +from polystore.streaming.identity import StreamProducerIdentity, StreamRouteKeyAuthority from polystore.streaming_constants import StreamingDataType @@ -17,19 +18,21 @@ def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], lis return display_config.component_modes(), list(display_config.COMPONENT_ORDER) -def build_layer_key( +def build_route_key( + producer_identity: StreamProducerIdentity | dict[str, Any], component_info: dict[str, Any], component_modes: dict[str, str], component_order: list[str], data_type: StreamingDataType, ) -> str: - """Build canonical layer key from slice-mode components and payload type.""" - layer_key_parts: list[str] = [] + """Build hidden route key from producer identity, slice components, and type.""" + producer = StreamProducerIdentity.from_payload(producer_identity) + route_parts: list[str] = list(producer.route_parts()) for component in component_order: mode = component_modes[component] if mode == "slice" and component in component_info: - layer_key_parts.append(f"{component}_{component_info[component]}") + route_parts.append(f"{component}_{component_info[component]}") - layer_key = "_".join(layer_key_parts) if layer_key_parts else "default_layer" + route_key = StreamRouteKeyAuthority.join(route_parts) - return f"{layer_key}{data_type.napari_layer_suffix}" + return f"{route_key}{data_type.napari_layer_suffix}" diff --git a/src/polystore/streaming/viewer_transport.py b/src/polystore/streaming/viewer_transport.py index b906190..6947076 100644 --- a/src/polystore/streaming/viewer_transport.py +++ b/src/polystore/streaming/viewer_transport.py @@ -7,13 +7,14 @@ import zmq +from polystore.streaming.identity import StreamProducerIdentity + @dataclass(frozen=True) class ViewerTransportDefaults: """Declared transport defaults shared by viewer streaming backends.""" host: str = "localhost" - source: str = "unknown_source" ack_timeout_ms: int = 30_000 @@ -27,7 +28,7 @@ class ViewerStreamKwargs: transport_config: Any display_config: Any microscope_handler: Any - source: str + producer_identity: StreamProducerIdentity plate_path: Any component_metadata: Any component_metadata_by_path: Any @@ -48,7 +49,9 @@ def from_kwargs( transport_config=ViewerKwargAuthority.optional(kwargs, "transport_config"), display_config=ViewerKwargAuthority.required(kwargs, "display_config"), microscope_handler=ViewerKwargAuthority.required(kwargs, "microscope_handler"), - source=ViewerKwargAuthority.value_or_default(kwargs, "source", defaults.source), + producer_identity=StreamProducerIdentity.from_payload( + ViewerKwargAuthority.required(kwargs, "producer_identity") + ), plate_path=ViewerKwargAuthority.optional(kwargs, "plate_path"), component_metadata=ViewerKwargAuthority.optional(kwargs, "component_metadata"), component_metadata_by_path=ViewerKwargAuthority.optional( diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py index 767a772..24fbaaa 100644 --- a/tests/test_streaming_metadata.py +++ b/tests/test_streaming_metadata.py @@ -4,6 +4,7 @@ from polystore.streaming._streaming_backend import StreamingBackend from polystore.streaming._streaming_backend import StreamingBatchRequest +from polystore.streaming.identity import StreamProducerIdentity class MetadataProbeStreamingBackend(StreamingBackend): @@ -14,6 +15,14 @@ def save_batch(self, data_list, file_paths, **kwargs): raise NotImplementedError +PRODUCER_IDENTITY = StreamProducerIdentity( + origin="pipeline", + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", +) + + def test_streaming_component_metadata_rejects_unparsed_artifact_filename() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( @@ -24,7 +33,6 @@ def test_streaming_component_metadata_rejects_unparsed_artifact_filename() -> No backend._parse_component_metadata( "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", microscope_handler, - source="IdentifyPrimaryObjects", ) @@ -40,7 +48,7 @@ def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: data_list=[object()], file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], microscope_handler=microscope_handler, - source="IdentifyPrimaryObjects", + producer_identity=PRODUCER_IDENTITY, prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), ) ) @@ -57,7 +65,7 @@ def test_streaming_batch_items_accept_per_path_component_metadata() -> None: data_list=[object()], file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], microscope_handler=microscope_handler, - source="IdentifyPrimaryObjects", + producer_identity=PRODUCER_IDENTITY, prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), component_metadata_by_path={ "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip": { @@ -73,8 +81,8 @@ def test_streaming_batch_items_accept_per_path_component_metadata() -> None: "well": "A01", "site": 1, "channel": 1, - "source": "IdentifyPrimaryObjects", } + assert batch_images[0]["producer_identity"] == PRODUCER_IDENTITY.to_payload() def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: @@ -88,10 +96,9 @@ def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None metadata = backend._parse_component_metadata( "A01_s001_w1_z001_t001.TIF", microscope_handler, - source="Crop", ) - assert metadata == {"well": "A01", "channel": 1, "source": "Crop"} + assert metadata == {"well": "A01", "channel": 1} def test_streaming_component_metadata_prefers_explicit_metadata() -> None: @@ -103,7 +110,6 @@ def test_streaming_component_metadata_prefers_explicit_metadata() -> None: metadata = backend._parse_component_metadata( "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", microscope_handler, - source="IdentifyPrimaryObjects", component_metadata={"well": "A01", "site": 1, "channel": 1}, ) @@ -111,7 +117,6 @@ def test_streaming_component_metadata_prefers_explicit_metadata() -> None: "well": "A01", "site": 1, "channel": 1, - "source": "IdentifyPrimaryObjects", } @@ -125,5 +130,4 @@ def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: backend._parse_component_metadata( "A01_s001_w1_z001_t001.TIF", microscope_handler, - source="Crop", ) diff --git a/tests/test_streaming_receiver_core.py b/tests/test_streaming_receiver_core.py index 6f7cd63..be0184b 100644 --- a/tests/test_streaming_receiver_core.py +++ b/tests/test_streaming_receiver_core.py @@ -4,64 +4,162 @@ import time from polystore.streaming_constants import StreamingDataType +from polystore.streaming.identity import ( + StreamProducerDisplayNameAuthority, + StreamProducerIdentity, +) from polystore.streaming.receivers.core import ( DebouncedBatchEngine, group_items_by_component_modes, ) from polystore.streaming.receivers.napari import ( normalize_component_layout, - build_layer_key, + build_route_key, ) - -def test_group_items_by_component_modes_source_normalization_for_rois() -> None: +class PipelineProducerFixture: + """Nominal producer fixtures for receiver-core tests.""" + + MAIN_KIND = "main" + MAIN_KEY = "main" + ARTIFACT_KIND = "artifact" + + @classmethod + def main_output( + cls, + *, + step_name: str, + pipeline_position: int, + ) -> StreamProducerIdentity: + return StreamProducerIdentity.pipeline_output( + output_kind=cls.MAIN_KIND, + output_key=cls.MAIN_KEY, + step_name=step_name, + pipeline_position=pipeline_position, + ) + + @classmethod + def artifact_output( + cls, + *, + output_key: str, + step_name: str, + pipeline_position: int, + artifact_kind: str | None = None, + ) -> StreamProducerIdentity: + return StreamProducerIdentity.pipeline_output( + output_kind=cls.ARTIFACT_KIND, + output_key=output_key, + step_name=step_name, + pipeline_position=pipeline_position, + artifact_kind=artifact_kind, + ) + + +def test_group_items_by_component_modes_keys_windows_by_producer_identity() -> None: + image_identity = PipelineProducerFixture.main_output( + step_name="RawLoad", + pipeline_position=1, + ) + roi_identity = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + artifact_kind="object_labels", + ) items = [ { "data_type": "rois", - "metadata": {"source": "/tmp/foo_results", "well": "A01", "channel": 1}, + "metadata": {"well": "A01", "channel": 1}, + "producer_identity": roi_identity.to_payload(), }, { "data_type": "image", - "metadata": {"source": "step_1", "well": "A01", "channel": 1}, + "metadata": {"well": "A01", "channel": 1}, + "producer_identity": image_identity.to_payload(), }, ] - component_modes = {"source": "window", "well": "frame", "channel": "channel"} - component_order = ["source", "well", "channel"] + component_modes = {"well": "frame", "channel": "channel"} + component_order = ["well", "channel"] grouped = group_items_by_component_modes( items, component_modes=component_modes, component_order=component_order, - images_dir="/my/plate/images", ) - assert grouped.window_components == ["source"] + assert grouped.window_components == [] assert grouped.channel_components == ["channel"] assert grouped.frame_components == ["well"] - assert "source_images" in grouped.windows - assert "source_step_1" in grouped.windows + assert grouped.slice_components == [] + assert grouped.fixed_window_labels[ + "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels" + ] == [("producer", "3. Segment Nuclei")] + assert set(grouped.windows) == { + "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels", + "origin_pipeline_kind_main_out_main_step_1_name_RawLoad", + } + + +def test_stream_producer_display_name_authority_matches_pipeline_editor_indexing() -> None: + main_output = PipelineProducerFixture.main_output( + step_name="ConvertObjectsToImage", + pipeline_position=8, + ) + artifact_output = PipelineProducerFixture.artifact_output( + output_key="NucleiObjects3D", + step_name="ConvertObjectsToImage", + pipeline_position=8, + artifact_kind="object_labels", + ) + manual_output = StreamProducerIdentity.manual("selected_rois") + + assert ( + StreamProducerDisplayNameAuthority.producer_label(main_output) + == "9. ConvertObjectsToImage" + ) + assert ( + StreamProducerDisplayNameAuthority.output_label(main_output) + == "9. ConvertObjectsToImage" + ) + assert ( + StreamProducerDisplayNameAuthority.output_label(artifact_output) + == "9. ConvertObjectsToImage NucleiObjects3D" + ) + assert StreamProducerDisplayNameAuthority.output_label(manual_output) == "selected_rois" + assert ( + StreamProducerDisplayNameAuthority.disambiguation_label(main_output) + == "step 9" + ) -def test_napari_layer_key_builder_uses_slice_components_and_payload_type() -> None: +def test_napari_route_key_builder_uses_producer_slice_components_and_payload_type() -> None: + producer = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + ) component_modes = {"well": "slice", "channel": "stack", "site": "slice"} component_order = ["well", "channel", "site"] component_info = {"well": "A01", "channel": 2, "site": 3} - key_image = build_layer_key( + key_image = build_route_key( + producer_identity=producer, component_info=component_info, component_modes=component_modes, component_order=component_order, data_type=StreamingDataType.IMAGE, ) - key_shapes = build_layer_key( + key_shapes = build_route_key( + producer_identity=producer, component_info=component_info, component_modes=component_modes, component_order=component_order, data_type=StreamingDataType.SHAPES, ) - assert key_image == "well_A01_site_3" - assert key_shapes == "well_A01_site_3_shapes" + assert key_image == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3" + assert key_shapes == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3_shapes" def test_normalize_component_layout_dict_config() -> None: From 3304f61749c14e7402584c136dbdc6fcdd4f7a86 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Fri, 19 Jun 2026 02:44:02 -0400 Subject: [PATCH 11/11] Refine viewer streaming projection contracts --- src/polystore/fiji_stream.py | 286 +++--- src/polystore/napari_stream.py | 180 ++-- src/polystore/streaming/__init__.py | 31 +- src/polystore/streaming/_streaming_backend.py | 867 ++++++++++++------ src/polystore/streaming/identity.py | 166 ++-- src/polystore/streaming/receivers/__init__.py | 4 + .../streaming/receivers/core/__init__.py | 5 +- .../receivers/core/window_projection.py | 195 +++- .../receivers/fiji/fiji_batch_processor.py | 9 +- .../streaming/receivers/napari/layer_key.py | 85 +- .../napari/napari_batch_processor.py | 28 +- src/polystore/streaming/viewer_transport.py | 392 ++++++-- tests/test_streaming_metadata.py | 280 ++++-- tests/test_streaming_receiver_core.py | 81 +- tests/test_viewer_transport.py | 151 +++ 15 files changed, 1904 insertions(+), 856 deletions(-) create mode 100644 tests/test_viewer_transport.py diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 0a2f36f..2cbeb1c 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -12,30 +12,37 @@ """ import logging -import time -from pathlib import Path -from typing import Any, List, Union +from enum import Enum -import zmq - -from .constants import Backend, TransportMode +from .constants import Backend from .streaming_constants import StreamingDataType -from .streaming import StreamingBackend +from .streaming import ( + FilePath, + RoiStreamPayload, + StreamingBuiltBatch, + StreamingBackend, + StreamingComponentNamesRequest, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +from .streaming.viewer_transport import ViewerStreamRequest from .roi_converters import FijiROIConverter -from .streaming.viewer_transport import ( - ViewerAckPolicy, - ViewerStreamKwargs, - ViewerTransportConfigAuthority, - ViewerTransportDefaults, +from zmqruntime.viewer_protocol import ( + ViewerBatchContextWireField, + ViewerBatchItemWireField, + ViewerBatchWireField, + ViewerWireMapping, + ViewerWireValue, ) -from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode logger = logging.getLogger(__name__) -FIJI_TRANSPORT_DEFAULTS = ViewerTransportDefaults() -FIJI_ACK_POLICY = ViewerAckPolicy( - viewer_name="Fiji", - timeout_ms=FIJI_TRANSPORT_DEFAULTS.ack_timeout_ms, -) + + +class FijiDisplayWireField(str, Enum): + """Fiji-specific display fields inside the shared viewer display payload.""" + + LUT = "lut" + AUTO_CONTRAST = "auto_contrast" class FijiDisplayPayload: @@ -43,15 +50,15 @@ class FijiDisplayPayload: @staticmethod def auto_contrast_value(display_config) -> bool: - if not hasattr(display_config, "auto_contrast"): - return True return display_config.auto_contrast @classmethod - def from_display_config(cls, display_config) -> dict[str, Any]: + def from_display_config(cls, display_config) -> dict[str, ViewerWireValue]: return { - "lut": display_config.get_lut_name(), - "auto_contrast": cls.auto_contrast_value(display_config), + FijiDisplayWireField.LUT.value: display_config.get_lut_name(), + FijiDisplayWireField.AUTO_CONTRAST.value: cls.auto_contrast_value( + display_config + ), } @@ -59,20 +66,18 @@ class FijiMessageMetadata: """Typed access to optional Fiji message metadata.""" @staticmethod - def component_names_metadata(message: dict) -> dict: - if "component_names_metadata" in message: - return message["component_names_metadata"] - return {} + def component_names_metadata(message: ViewerWireMapping) -> ViewerWireValue: + return message[ViewerBatchWireField.COMPONENT_NAMES_METADATA.value] class FijiRoiPayload: """ROI payload inspection for Fiji logging.""" @staticmethod - def count(item_data: dict) -> int: - if "rois" not in item_data: + def count(item_data: ViewerWireMapping) -> int: + if ViewerBatchItemWireField.ROIS.value not in item_data: raise ValueError("Fiji ROI payload missing required 'rois' field") - return len(item_data["rois"]) + return len(item_data[ViewerBatchItemWireField.ROIS.value]) class FijiStreamingBackend(StreamingBackend): @@ -82,7 +87,70 @@ class FijiStreamingBackend(StreamingBackend): VIEWER_TYPE = 'fiji' SHM_PREFIX = 'fiji_' - def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: + def _display_payload_extra( + self, + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return ViewerDisplayPayloadExtra.from_mapping( + FijiDisplayPayload.from_display_config(stream_request.display_config) + ) + + def _message_extra( + self, + stream_request: ViewerStreamRequest, + ) -> dict[str, ViewerWireValue]: + message_extra = stream_request.message_extra_payload() + message_extra[ViewerBatchContextWireField.IMAGES_DIR.value] = ( + stream_request.images_dir + ) + return message_extra + + def _component_names_request( + self, + stream_request: ViewerStreamRequest, + ) -> StreamingComponentNamesRequest: + return StreamingComponentNamesRequest.from_stream_request( + stream_request, + log_prefix="🏷️ FIJI BACKEND", + verbose=True, + ) + + def _after_batch_message_built( + self, + stream_request: ViewerStreamRequest, + built_batch: StreamingBuiltBatch, + ) -> None: + logger.info( + "🏷️ FIJI BACKEND: Final component_names_metadata: %s", + FijiMessageMetadata.component_names_metadata(built_batch.message), + ) + + for item in built_batch.batch_images: + logger.info( + "🔍 FIJI BACKEND: Added %s item to batch", + item[ViewerBatchItemWireField.DATA_TYPE.value], + ) + + data_types = [ + item[ViewerBatchItemWireField.DATA_TYPE.value] + for item in built_batch.batch_images + ] + type_counts = { + data_type: data_types.count(data_type) + for data_type in set(data_types) + } + logger.info( + "📤 FIJI BACKEND: Sending batch message with %d items to port %s: %s", + len(built_batch.batch_images), + stream_request.port, + type_counts, + ) + + def _prepare_rois_data( + self, + data: RoiStreamPayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: """ Prepare ROIs data for transmission. @@ -98,136 +166,44 @@ def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: rois_encoded = FijiROIConverter.encode_rois_for_transmission(roi_bytes_list) return { - 'path': str(file_path), - 'rois': rois_encoded, + ViewerBatchItemWireField.PATH.value: str(file_path), + ViewerBatchItemWireField.ROIS.value: rois_encoded, } - def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - logger.info(f"🔍 FIJI BACKEND: Detected data type: {data_type} for path: {file_path}") - if data_type == StreamingDataType.SHAPES: - logger.info(f"🔍 FIJI BACKEND: Preparing ROI data for {file_path}") - item_data = self._prepare_rois_data(data, file_path) - data_type_value = "rois" + def _prepare_batch_item( + self, + request: StreamingItemPreparationRequest, + ) -> tuple[ViewerWireMapping, str]: + logger.info( + "🔍 FIJI BACKEND: Detected data type: %s for path: %s", + request.data_type, + request.item_path.value, + ) + if request.data_type == StreamingDataType.SHAPES: + logger.info( + "🔍 FIJI BACKEND: Preparing ROI data for %s", + request.item_path.value, + ) + item_data = self._prepare_rois_data( + request.data, + request.item_path.value, + ) + data_type_value = StreamingDataType.ROIS.value logger.info( - f"🔍 FIJI BACKEND: ROI data prepared: {FijiRoiPayload.count(item_data)} ROIs" + "🔍 FIJI BACKEND: ROI data prepared: %d ROIs", + FijiRoiPayload.count(item_data), ) else: - logger.info(f"🔍 FIJI BACKEND: Preparing image data for {file_path}") - item_data = self._create_shared_memory(data, file_path) - data_type_value = "image" - return item_data, data_type_value - - def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], **kwargs) -> None: - """Stream batch of images or ROIs to Fiji via ZMQ.""" - - logger.info(f"📦 FIJI BACKEND: save_batch called with {len(data_list)} items") - - # Filter to only supported file types - data_list, file_paths, skipped = self._filter_streamable_files(data_list, file_paths) - if not data_list: - return - - stream_request = ViewerStreamKwargs.from_kwargs( - kwargs, - FIJI_TRANSPORT_DEFAULTS, - include_images_dir=True, - ) - logger.info(f"🏷️ FIJI BACKEND: plate_path = {stream_request.plate_path}") - logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {stream_request.microscope_handler}") - display_payload_extra = FijiDisplayPayload.from_display_config( - stream_request.display_config - ) - message_extra = { - "images_dir": stream_request.images_dir, - } - - message, batch_images, image_ids = self._build_batch_message( - data_list, - file_paths, - stream_request.microscope_handler, - stream_request.producer_identity, - stream_request.display_config, - self._prepare_batch_item, - plate_path=stream_request.plate_path, - component_metadata=stream_request.component_metadata, - component_metadata_by_path=stream_request.component_metadata_by_path, - component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, - display_payload_extra=display_payload_extra, - message_extra=message_extra, - ) - - logger.info( - "🏷️ FIJI BACKEND: Final component_names_metadata: %s", - FijiMessageMetadata.component_names_metadata(message), - ) - - for item in batch_images: - logger.info(f"🔍 FIJI BACKEND: Added {item['data_type']} item to batch") - - # Log batch composition - data_types = [item['data_type'] for item in batch_images] - type_counts = {dt: data_types.count(dt) for dt in set(data_types)} - logger.info( - "📤 FIJI BACKEND: Sending batch message with %d items to port %s: %s", - len(batch_images), - stream_request.port, - type_counts, - ) - - # Register sent images with queue tracker BEFORE sending - # This prevents race condition with IPC mode where acks arrive before registration - self._register_with_queue_tracker( - stream_request.port, - image_ids, - transport_mode=stream_request.transport_mode, - transport_config=stream_request.transport_config, - ) - - # Create FRESH REQ socket for each send - REQ sockets cannot be reused - # This prevents the "Operation cannot be accomplished in current state" error - # when multiple streams happen concurrently - transport_config = ViewerTransportConfigAuthority.resolve( - stream_request.transport_config, - self._transport_config, - ) - url = get_zmq_transport_url( - stream_request.port, - host=stream_request.host, - mode=coerce_transport_mode(stream_request.transport_mode), - config=transport_config, - ) - - if self._context is None: - self._context = zmq.Context() - - socket = self._context.socket(zmq.REQ) - FIJI_ACK_POLICY.apply_socket_options(socket) - socket.connect(url) - time.sleep(0.1) # Brief delay for connection to establish - - try: - # Send with REQ socket (BLOCKING - worker waits for Fiji to acknowledge) - # Worker blocks until Fiji receives, copies data from shared memory, and sends ack - # This guarantees no messages are lost and shared memory is only closed after Fiji is done - logger.info(f"📤 FIJI BACKEND: Sending batch of {len(batch_images)} images to Fiji on port {stream_request.port} (REQ/REP - blocking until ack)") - socket.send_json(message) # Blocking send - - # Wait for acknowledgment from Fiji (REP socket) - # Fiji will only reply after it has copied all data from shared memory - ack_response = FIJI_ACK_POLICY.receive( - socket, - lambda: self._cleanup_shared_memory_blocks(batch_images, unlink=True), - port=stream_request.port, + logger.info( + "🔍 FIJI BACKEND: Preparing image data for %s", + request.item_path.value, ) - logger.info(f"✅ FIJI BACKEND: Received ack from Fiji: {FIJI_ACK_POLICY.status(ack_response)}") - - finally: - # Always close the socket - never reuse REQ sockets - socket.close() - - # Clean up publisher's handles after successful send - # Receiver will unlink the shared memory after copying the data - self._cleanup_shared_memory_blocks(batch_images, unlink=False) + item_data = self.create_shared_memory_payload( + request.data, + request.item_path.value, + ) + data_type_value = StreamingDataType.IMAGE.value + return item_data, data_type_value # cleanup() now inherited from ABC diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index f80348e..a0940bc 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -13,29 +13,32 @@ """ import logging -import time -from pathlib import Path -from typing import Any, List, Union - -import zmq - -from .constants import Backend, TransportMode -from .streaming import StreamingBackend +from enum import Enum + +from .constants import Backend +from .streaming import ( + FilePath, + RoiStreamPayload, + StreamingBackend, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +from .streaming.viewer_transport import ViewerStreamRequest from .roi_converters import NapariROIConverter -from .streaming.viewer_transport import ( - ViewerAckPolicy, - ViewerStreamKwargs, - ViewerTransportConfigAuthority, - ViewerTransportDefaults, +from zmqruntime.viewer_protocol import ( + ViewerBatchItemWireField, + ViewerWireMapping, + ViewerWireValue, ) -from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode logger = logging.getLogger(__name__) -NAPARI_TRANSPORT_DEFAULTS = ViewerTransportDefaults() -NAPARI_ACK_POLICY = ViewerAckPolicy( - viewer_name="Napari", - timeout_ms=NAPARI_TRANSPORT_DEFAULTS.ack_timeout_ms, -) + + +class NapariDisplayWireField(str, Enum): + """Napari-specific display fields inside the shared viewer display payload.""" + + COLORMAP = "colormap" + VARIABLE_SIZE_HANDLING = "variable_size_handling" class NapariDisplayPayload: @@ -43,18 +46,18 @@ class NapariDisplayPayload: @staticmethod def variable_size_handling_value(display_config): - if not hasattr(display_config, "variable_size_handling"): - return None variable_size_handling = display_config.variable_size_handling if variable_size_handling is None: return None return variable_size_handling.value @classmethod - def from_display_config(cls, display_config) -> dict[str, Any]: + def from_display_config(cls, display_config) -> dict[str, ViewerWireValue]: return { - "colormap": display_config.get_colormap_name(), - "variable_size_handling": cls.variable_size_handling_value(display_config), + NapariDisplayWireField.COLORMAP.value: display_config.get_colormap_name(), + NapariDisplayWireField.VARIABLE_SIZE_HANDLING.value: ( + cls.variable_size_handling_value(display_config) + ), } @@ -65,7 +68,19 @@ class NapariStreamingBackend(StreamingBackend): VIEWER_TYPE = 'napari' SHM_PREFIX = 'napari_' - def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: + def _display_payload_extra( + self, + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return ViewerDisplayPayloadExtra.from_mapping( + NapariDisplayPayload.from_display_config(stream_request.display_config) + ) + + def _prepare_shapes_data( + self, + data: RoiStreamPayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: """ Prepare shapes data for transmission. @@ -79,108 +94,27 @@ def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: shapes_data = NapariROIConverter.rois_to_shapes(data) return { - 'path': str(file_path), - 'shapes': shapes_data, + ViewerBatchItemWireField.PATH.value: str(file_path), + ViewerBatchItemWireField.SHAPES.value: shapes_data, } - def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - if data_type.uses_napari_vector_payload: - item_data = self._prepare_shapes_data(data, file_path) - data_type_value = data_type.value + def _prepare_batch_item( + self, + request: StreamingItemPreparationRequest, + ) -> tuple[ViewerWireMapping, str]: + if request.data_type.uses_napari_vector_payload: + item_data = self._prepare_shapes_data( + request.data, + request.item_path.value, + ) + data_type_value = request.data_type.value else: - item_data = self._create_shared_memory(data, file_path) - data_type_value = data_type.value - return item_data, data_type_value - - def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], **kwargs) -> None: - """ - Stream multiple images or ROIs to napari as a batch. - - Args: - data_list: List of image data or ROI lists - file_paths: List of path identifiers - **kwargs: Additional metadata - """ - # Filter to only supported file types - data_list, file_paths, skipped = self._filter_streamable_files(data_list, file_paths) - if not data_list: - return - - stream_request = ViewerStreamKwargs.from_kwargs( - kwargs, - NAPARI_TRANSPORT_DEFAULTS, - ) - display_payload_extra = NapariDisplayPayload.from_display_config( - stream_request.display_config - ) - - message, batch_images, image_ids = self._build_batch_message( - data_list, - file_paths, - stream_request.microscope_handler, - stream_request.producer_identity, - stream_request.display_config, - self._prepare_batch_item, - plate_path=stream_request.plate_path, - component_metadata=stream_request.component_metadata, - component_metadata_by_path=stream_request.component_metadata_by_path, - display_payload_extra=display_payload_extra, - ) - - # Register sent images with queue tracker BEFORE sending - # This prevents race condition with IPC mode where acks arrive before registration - self._register_with_queue_tracker( - stream_request.port, - image_ids, - transport_mode=stream_request.transport_mode, - transport_config=stream_request.transport_config, - ) - - # Create FRESH REQ socket for each send - REQ sockets cannot be reused - # This prevents the "Operation cannot be accomplished in current state" error - # when multiple streams happen concurrently - transport_config = ViewerTransportConfigAuthority.resolve( - stream_request.transport_config, - self._transport_config, - ) - url = get_zmq_transport_url( - stream_request.port, - host=stream_request.host, - mode=coerce_transport_mode(stream_request.transport_mode), - config=transport_config, - ) - - if self._context is None: - self._context = zmq.Context() - - socket = self._context.socket(zmq.REQ) - NAPARI_ACK_POLICY.apply_socket_options(socket) - socket.connect(url) - time.sleep(0.1) # Brief delay for connection to establish - - try: - # Send with REQ socket (BLOCKING - worker waits for Napari to acknowledge) - # Worker blocks until Napari receives, copies data from shared memory, and sends ack - # This guarantees no messages are lost and shared memory is only closed after Napari is done - logger.info(f"📤 NAPARI BACKEND: Sending batch of {len(batch_images)} images to Napari on port {stream_request.port} (REQ/REP - blocking until ack)") - socket.send_json(message) # Blocking send - - # Wait for acknowledgment from Napari (REP socket) - # Napari will only reply after it has copied all data from shared memory - ack_response = NAPARI_ACK_POLICY.receive( - socket, - lambda: self._cleanup_shared_memory_blocks(batch_images, unlink=True), - port=stream_request.port, + item_data = self.create_shared_memory_payload( + request.data, + request.item_path.value, ) - logger.info(f"✅ NAPARI BACKEND: Received ack from Napari: {NAPARI_ACK_POLICY.status(ack_response)}") - - finally: - # Always close the socket - never reuse REQ sockets - socket.close() - - # Clean up publisher's handles after successful send - # Receiver will unlink the shared memory after copying the data - self._cleanup_shared_memory_blocks(batch_images, unlink=False) + data_type_value = request.data_type.value + return item_data, data_type_value # cleanup() now inherited from ABC diff --git a/src/polystore/streaming/__init__.py b/src/polystore/streaming/__init__.py index 8a8536e..2217b4b 100644 --- a/src/polystore/streaming/__init__.py +++ b/src/polystore/streaming/__init__.py @@ -10,7 +10,32 @@ # This allows both: # from polystore.streaming import StreamingBackend # from polystore.streaming.receivers import FijiBatchProcessor -from polystore.streaming._streaming_backend import StreamingBackend - -__all__ = ["StreamingBackend"] +from polystore.streaming._streaming_backend import ( + FilePath, + RoiStreamPayload, + StreamablePayload, + StreamingBatchItemPreparationAuthority, + StreamingBatchMessageBuilder, + StreamingBatchMessageRequest, + StreamingBuiltBatch, + StreamingPreparedBatchItems, + StreamingBackend, + StreamingComponentNamesRequest, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +__all__ = [ + "FilePath", + "RoiStreamPayload", + "StreamablePayload", + "StreamingBatchItemPreparationAuthority", + "StreamingBatchMessageBuilder", + "StreamingBatchMessageRequest", + "StreamingBuiltBatch", + "StreamingPreparedBatchItems", + "StreamingBackend", + "StreamingComponentNamesRequest", + "StreamingItemPreparationRequest", + "ViewerDisplayPayloadExtra", +] diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 8d82d75..1f4976f 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -5,73 +5,262 @@ data to external systems without persistent storage capabilities. """ +from __future__ import annotations + import logging -import os import time import uuid -from collections.abc import Sequence -from dataclasses import dataclass +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from multiprocessing import resource_tracker, shared_memory from pathlib import Path -from typing import Any, Callable, List, Mapping, Set, Union +from types import MappingProxyType +from typing import TypeAlias import numpy as np +import zmq from arraybridge import convert_memory, detect_memory_type from arraybridge.types import MemoryType as ArrayBridgeMemoryType from ..base import DataSink -from ..constants import TransportMode from ..streaming_constants import StreamingDataType from ..roi import ROI, PointShape from ..zmq_config import POLYSTORE_ZMQ_CONFIG -from .identity import StreamProducerIdentity +from .viewer_transport import ( + ViewerDisplayConfigABC, + ViewerMicroscopeHandlerABC, + ViewerStreamBackendKwargs, + ViewerStreamRequest, + ViewerTransportDefaults, +) from zmqruntime.ack_listener import GlobalAckListener -from zmqruntime.transport import coerce_transport_mode +from zmqruntime.config import ZMQConfig +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerBatchItemPayload, + ViewerBatchItemWireField, + ViewerBatchMessagePayload, + ViewerComponentMetadataPayload, + ViewerDisplayConfigWireField, + ViewerTransportEndpoint, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) -PrepareStreamingItem = Callable[[Any, Union[str, Path], Any], tuple[dict, str]] -ComponentMetadataByPath = ( - Mapping[str, Mapping[str, Any] | None] - | Sequence[Mapping[str, Any] | None] - | None -) +FilePath: TypeAlias = str | Path +RoiStreamPayload: TypeAlias = Sequence[ROI] +StreamablePayload: TypeAlias = np.ndarray | Sequence[ViewerWireValue] | RoiStreamPayload +ComponentValue = str | int | float | bool | tuple | None +ViewerDisplayPayloadExtraValues: TypeAlias = Mapping[ + str | ViewerDisplayConfigWireField, + ViewerWireValue, +] +STREAMING_TRANSPORT_DEFAULTS = ViewerTransportDefaults() @dataclass(frozen=True) -class StreamingComponentMetadata: - """Message metadata for one streamed item.""" +class ViewerDisplayPayloadExtra: + """Nominal viewer-specific display payload extension.""" + + values: ViewerDisplayPayloadExtraValues = field( + default_factory=lambda: MappingProxyType({}) + ) + + @classmethod + def from_mapping( + cls, + values: ViewerDisplayPayloadExtraValues, + ) -> "ViewerDisplayPayloadExtra": + return cls(values) + + def to_wire_mapping(self) -> dict[str, ViewerWireValue]: + return dict(self.values) + - parsed_filename_metadata: Mapping[str, Any] +EMPTY_DISPLAY_PAYLOAD_EXTRA = ViewerDisplayPayloadExtra() - def to_payload(self) -> dict[str, Any]: - if isinstance(self.parsed_filename_metadata, Mapping): - metadata = dict(self.parsed_filename_metadata) - else: + +class StreamingComponentValueDomainAuthority: + """Build batch-level component value domains from stream item metadata.""" + + @staticmethod + def wire_payload( + stream_request: ViewerStreamRequest, + batch_images: Sequence[ViewerWireMapping], + ) -> dict[str, ViewerWireValue]: + component_order = tuple( + str(component) + for component in stream_request.display_config.COMPONENT_ORDER + ) + values_by_component: dict[str, list[ComponentValue]] = { + component: [] for component in component_order + } + for image_payload in batch_images: + metadata = StreamingComponentValueDomainAuthority._metadata(image_payload) + for component in component_order: + if component not in metadata: + continue + value = StreamingComponentValueDomainAuthority._component_value( + metadata[component] + ) + if value not in values_by_component[component]: + values_by_component[component].append(value) + return { + component: values + for component, values in values_by_component.items() + if values + } + + @staticmethod + def _metadata(image_payload: ViewerWireMapping) -> ViewerWireMapping: + metadata = image_payload[ViewerBatchItemWireField.METADATA.value] + if not isinstance(metadata, Mapping): raise TypeError( - "Streaming component metadata must be a mapping, " - f"got {type(self.parsed_filename_metadata).__name__}." + "Streaming batch item metadata must be a mapping, " + f"got {type(metadata).__name__}." ) - return metadata + return dict(metadata) + + @staticmethod + def _component_value(value: ViewerWireValue) -> ComponentValue: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, tuple): + return value + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return tuple(value) + raise TypeError( + "Streaming component values must be JSON scalar or tuple-like, " + f"got {type(value).__name__}." + ) + +@dataclass(frozen=True) +class StreamingComponentNamesRequest: + """Component-label metadata requested for one viewer batch.""" + + component_names: Sequence[str] + log_prefix: str | None = None + verbose: bool = False + + @classmethod + def from_stream_request( + cls, + stream_request: ViewerStreamRequest, + log_prefix: str | None = None, + verbose: bool = False, + ) -> "StreamingComponentNamesRequest": + return cls( + component_names=tuple( + str(component) + for component in stream_request.display_config.COMPONENT_ORDER + ), + log_prefix=log_prefix, + verbose=verbose, + ) + + +@dataclass(frozen=True) +class StreamingBatchMessageRequest: + """Inputs for building one viewer batch message.""" + + data_list: list[StreamablePayload] + file_paths: list[FilePath] + stream_request: ViewerStreamRequest + component_names_request: StreamingComponentNamesRequest | None = None + display_payload_extra: ViewerDisplayPayloadExtra = field( + default_factory=ViewerDisplayPayloadExtra + ) + + def resolved_component_names_request(self) -> StreamingComponentNamesRequest: + if self.component_names_request is not None: + return self.component_names_request + return StreamingComponentNamesRequest.from_stream_request( + self.stream_request + ) + + +@dataclass(frozen=True) +class StreamingPreparedBatchItems: + """Prepared per-item viewer payloads before batch-level metadata is attached.""" + + batch_images: list[dict[str, ViewerWireValue]] + image_ids: list[str] + + +@dataclass(frozen=True) +class StreamingBuiltBatch(StreamingPreparedBatchItems): + """Prepared viewer message and per-item transmission bookkeeping.""" + + message: dict[str, ViewerWireValue] + + +@dataclass(frozen=True) +class StreamingItemPath: + """Nominal path identity for one item in a viewer stream batch.""" + + value: FilePath + + @property + def wire_value(self) -> str: + return str(self.value) + + +@dataclass(frozen=True) +class StreamingPayloadFileRequest: + """Shared payload/file identity for viewer item preparation requests.""" + + data: StreamablePayload + item_path: StreamingItemPath + + +@dataclass(frozen=True) +class StreamingItemPreparationRequest(StreamingPayloadFileRequest): + """Inputs needed to prepare one payload for a viewer batch item.""" + + data_type: StreamingDataType + + +@dataclass(frozen=True) +class StreamingSharedMemoryRequest(StreamingPayloadFileRequest): + """Inputs needed to allocate one image payload into shared memory.""" + + shm_prefix: str + + +@dataclass(frozen=True) +class StreamingSharedMemoryPayload: + """Wire payload describing a shared-memory image allocation.""" + + item_path: StreamingItemPath + shape: tuple[int, ...] + dtype: str + shm_name: str + + def to_wire_mapping(self) -> dict[str, ViewerWireValue]: + return { + ViewerBatchItemWireField.PATH.value: self.item_path.wire_value, + ViewerBatchItemWireField.SHAPE.value: self.shape, + ViewerBatchItemWireField.DTYPE.value: self.dtype, + ViewerBatchItemWireField.SHM_NAME.value: self.shm_name, + } @dataclass(frozen=True) -class StreamingBatchRequest: - """Shared provenance for one streaming batch.""" +class StreamingSharedMemoryBlock: + """Allocated shared memory and the wire payload that names it.""" - data_list: List[Any] - file_paths: List[Union[str, Path]] - microscope_handler: Any - producer_identity: StreamProducerIdentity - prepare_item: PrepareStreamingItem - component_metadata: Mapping[str, Any] | None = None - component_metadata_by_path: ComponentMetadataByPath = None + shared_memory: shared_memory.SharedMemory + payload: StreamingSharedMemoryPayload class StreamingPayloadMemoryAuthority: """Memory conversion authority for streamable image payloads.""" @staticmethod - def to_numpy(data: Any) -> np.ndarray: + def to_numpy(data: StreamablePayload) -> np.ndarray: if isinstance(data, np.ndarray): return data if isinstance(data, (list, tuple)): @@ -84,6 +273,240 @@ def to_numpy(data: Any) -> np.ndarray: ) +class StreamingSharedMemoryAuthority: + """Allocate image payloads for viewer transfer through shared memory.""" + + @classmethod + def create( + cls, + request: StreamingSharedMemoryRequest, + ) -> StreamingSharedMemoryBlock: + np_data = StreamingPayloadMemoryAuthority.to_numpy(request.data) + shm_name = cls._name(request.shm_prefix) + shm = shared_memory.SharedMemory( + create=True, + size=np_data.nbytes, + name=shm_name, + ) + resource_tracker.unregister(shm._name, "shared_memory") + + shm_array = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf) + shm_array[:] = np_data[:] + + return StreamingSharedMemoryBlock( + shared_memory=shm, + payload=StreamingSharedMemoryPayload( + item_path=request.item_path, + shape=tuple(int(dimension) for dimension in np_data.shape), + dtype=str(np_data.dtype), + shm_name=shm_name, + ), + ) + + @staticmethod + def _name(shm_prefix: str) -> str: + return f"{shm_prefix}{uuid.uuid4().hex[:12]}" + + +class StreamingDataTypeAuthority: + """Detect the viewer payload kind for one streamed object.""" + + @staticmethod + def detect(data: StreamablePayload) -> StreamingDataType: + is_roi = isinstance(data, list) and len(data) > 0 and isinstance(data[0], ROI) + + if not is_roi: + return StreamingDataType.IMAGE + + all_points = all( + roi.shapes and all(isinstance(shape, PointShape) for shape in roi.shapes) + for roi in data + ) + + return StreamingDataType.POINTS if all_points else StreamingDataType.SHAPES + + +class StreamingComponentNamesMetadataCollector: + """Collect viewer component-label metadata for one batch.""" + + @staticmethod + def collect( + plate_path: FilePath | None, + microscope_handler: ViewerMicroscopeHandlerABC, + request: StreamingComponentNamesRequest, + ) -> dict[str, ViewerWireValue]: + component_names_metadata = {} + + if plate_path is None: + if request.verbose and request.log_prefix: + logger.warning("%s: No plate_path in kwargs", request.log_prefix) + return component_names_metadata + + for component_name in request.component_names: + metadata = microscope_handler.metadata_handler.get_component_values( + plate_path, + component_name, + ) + if request.verbose and request.log_prefix: + logger.info( + "%s: Got %s metadata: %s", + request.log_prefix, + component_name, + metadata, + ) + if metadata: + component_names_metadata[component_name] = metadata + + return component_names_metadata + + +class StreamingDisplayPayloadBuilder: + """Build the shared viewer display-config payload.""" + + @staticmethod + def build( + stream_request: ViewerStreamRequest, + display_payload_extra: ViewerDisplayPayloadExtra, + ) -> ViewerBatchDisplayPayload: + return ViewerBatchDisplayPayload( + component_modes={ + str(component): str(mode.value if isinstance(mode, Enum) else mode) + for component, mode in stream_request.display_config.component_modes().items() + }, + component_order=tuple( + str(component) + for component in stream_request.display_config.COMPONENT_ORDER + ), + extra=display_payload_extra.to_wire_mapping(), + ) + + +class StreamingBatchItemPreparationAuthority: + """Prepare per-item viewer payloads and transmission bookkeeping.""" + + @staticmethod + def prepare( + backend: "StreamingBackend", + request: StreamingBatchMessageRequest, + ) -> StreamingPreparedBatchItems: + batch_images = [] + image_ids = [] + + for index, (data, file_path) in enumerate( + zip(request.data_list, request.file_paths) + ): + item_path = StreamingItemPath(file_path) + image_id = str(uuid.uuid4()) + image_ids.append(image_id) + + data_type = StreamingDataTypeAuthority.detect(data) + explicit_component_metadata = ( + request.stream_request.source.metadata.component_metadata_for_item( + item_path.value, + index, + ) + ) + item_data, data_type_value = backend._prepare_batch_item( + StreamingItemPreparationRequest( + data=data, + item_path=item_path, + data_type=data_type, + ) + ) + + batch_images.append( + ViewerBatchItemPayload.from_parts( + item_payload=item_data, + data_type=data_type_value, + metadata=explicit_component_metadata, + producer_identity=( + request.stream_request.producer_identity.to_payload() + ), + image_id=image_id, + ).to_wire_mapping() + ) + + return StreamingPreparedBatchItems( + batch_images=batch_images, + image_ids=image_ids, + ) + + +class StreamingComponentMetadataPayloadAuthority: + """Resolve the component metadata payload for one viewer batch.""" + + @staticmethod + def payload( + request: StreamingBatchMessageRequest, + prepared_items: StreamingPreparedBatchItems, + ) -> ViewerComponentMetadataPayload: + declared = ViewerComponentMetadataPayload.from_optional_wire_mapping( + request.stream_request.message_extra_payload() + ) + if declared is not None: + return declared + return ViewerComponentMetadataPayload( + component_names_metadata=( + StreamingComponentNamesMetadataCollector.collect( + request.stream_request.source.identity.plate_path, + request.stream_request.source.identity.microscope_handler, + request.resolved_component_names_request(), + ) + ), + component_value_domain=( + StreamingComponentValueDomainAuthority.wire_payload( + request.stream_request, + prepared_items.batch_images, + ) + ), + ) + + +class StreamingBatchMessageBuilder: + """Build complete viewer batch messages from prepared items.""" + + @classmethod + def build( + cls, + backend: "StreamingBackend", + request: StreamingBatchMessageRequest, + ) -> StreamingBuiltBatch: + if len(request.data_list) != len(request.file_paths): + raise ValueError("data_list and file_paths must have the same length") + + prepared_items = StreamingBatchItemPreparationAuthority.prepare( + backend, + request, + ) + + component_metadata_payload = ( + StreamingComponentMetadataPayloadAuthority.payload( + request, + prepared_items, + ) + ) + + display_payload = StreamingDisplayPayloadBuilder.build( + request.stream_request, + request.display_payload_extra, + ) + message = ViewerBatchMessagePayload.from_parts( + images=prepared_items.batch_images, + display_payload=display_payload, + component_metadata=component_metadata_payload, + timestamp=time.time(), + extra=ViewerComponentMetadataPayload.strip_component_metadata( + backend._message_extra(request.stream_request) + ), + ).to_wire_mapping() + + return StreamingBuiltBatch( + message=message, + batch_images=prepared_items.batch_images, + image_ids=prepared_items.image_ids, + ) + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -102,8 +525,8 @@ class StreamingBackend(DataSink): """ # Abstract class attributes that subclasses must define - VIEWER_TYPE: str = None - SHM_PREFIX: str = None + VIEWER_TYPE: str + SHM_PREFIX: str # Class attribute: streaming backends only support image array data and ROIs supports_arbitrary_files: bool = False @@ -119,9 +542,9 @@ def requires_filesystem_validation(self) -> bool: def _filter_streamable_files( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - ) -> tuple[List[Any], List[Union[str, Path]], List[Union[str, Path]]]: + data_list: list[StreamablePayload], + file_paths: list[FilePath], + ) -> tuple[list[StreamablePayload], list[FilePath], list[FilePath]]: """ Filter data to only include files with supported extensions. @@ -157,129 +580,33 @@ def _filter_streamable_files( return filtered_data, filtered_paths, skipped_paths - def __init__(self, transport_config=None): + def __init__(self, transport_config: ZMQConfig = POLYSTORE_ZMQ_CONFIG): """Initialize ZeroMQ and shared memory infrastructure.""" self._publishers = {} self._context = None self._shared_memory_blocks = {} - self._transport_config = transport_config or POLYSTORE_ZMQ_CONFIG + self._transport_config = transport_config - def _parse_component_metadata( + def create_shared_memory_payload( self, - file_path: Union[str, Path], - microscope_handler, - component_metadata: Mapping[str, Any] | None = None, - ) -> dict: - """Parse real source-plane component metadata for one stream item.""" - filename = os.path.basename(str(file_path)) - parsed_metadata = ( - component_metadata - if component_metadata is not None - else microscope_handler.parser.parse_filename(filename) - ) - if parsed_metadata is None: - raise ValueError( - "Streaming component metadata requires explicit component_metadata " - f"or a parser-readable filename; got {filename!r}." + data: StreamablePayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: + block = StreamingSharedMemoryAuthority.create( + StreamingSharedMemoryRequest( + data=data, + item_path=StreamingItemPath(file_path), + shm_prefix=self.SHM_PREFIX, ) - return StreamingComponentMetadata(parsed_metadata).to_payload() - - @staticmethod - def _component_metadata_for_item( - *, - file_path: Union[str, Path], - index: int, - component_metadata: Mapping[str, Any] | None, - component_metadata_by_path: ComponentMetadataByPath, - ) -> Mapping[str, Any] | None: - """Return explicit component metadata for one batch item when provided.""" - if component_metadata_by_path is None: - return component_metadata - - if isinstance(component_metadata_by_path, Mapping): - path = Path(file_path) - for key in (str(file_path), path.as_posix(), path.name): - if key in component_metadata_by_path: - return component_metadata_by_path[key] - return component_metadata - - if index < len(component_metadata_by_path): - return component_metadata_by_path[index] - - return component_metadata - - def _detect_data_type(self, data: Any): - """ - Detect if data is ROI (shapes/points) or image (common for all streaming backends). - - Args: - data: Data to check - - Returns: - StreamingDataType enum value (IMAGE, SHAPES, or POINTS) - """ - is_roi = isinstance(data, list) and len(data) > 0 and isinstance(data[0], ROI) - - if not is_roi: - return StreamingDataType.IMAGE - - # Check if all ROIs contain only PointShape objects (for points layer) - all_points = all( - roi.shapes and all(isinstance(shape, PointShape) for shape in roi.shapes) - for roi in data ) - - return StreamingDataType.POINTS if all_points else StreamingDataType.SHAPES - - def _create_shared_memory(self, data: Any, file_path: Union[str, Path]) -> dict: - """ - Create shared memory for image data (common for all streaming backends). - - Args: - data: Image data to put in shared memory - file_path: Path identifier - - Returns: - Dict with shared memory metadata - """ - np_data = StreamingPayloadMemoryAuthority.to_numpy(data) - - # Create shared memory with hash-based naming to avoid "File name too long" errors - # Hash the timestamp and object ID to create a short, unique name - from multiprocessing import shared_memory, resource_tracker - import hashlib - timestamp = time.time_ns() - obj_id = id(data) - hash_input = f"{obj_id}_{timestamp}" - hash_suffix = hashlib.md5(hash_input.encode()).hexdigest()[:8] - shm_name = f"{self.SHM_PREFIX}{hash_suffix}" - shm = shared_memory.SharedMemory(create=True, size=np_data.nbytes, name=shm_name) - - # Unregister from resource tracker - we manage cleanup manually - # This prevents resource tracker warnings when worker processes exit - # before the viewer has unlinked the shared memory - try: - resource_tracker.unregister(shm._name, "shared_memory") - except Exception: - pass # Ignore errors if already unregistered - - shm_array = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf) - shm_array[:] = np_data[:] - self._shared_memory_blocks[shm_name] = shm - - return { - 'path': str(file_path), - 'shape': np_data.shape, - 'dtype': str(np_data.dtype), - 'shm_name': shm_name, - } + self._shared_memory_blocks[block.payload.shm_name] = block.shared_memory + return block.payload.to_wire_mapping() def _register_with_queue_tracker( self, - port: int, - image_ids: List[str], - transport_mode: TransportMode | None = None, - transport_config=None, + transport_endpoint: ViewerTransportEndpoint, + image_ids: list[str], + transport_config: ZMQConfig, ) -> None: """ Register sent images with queue tracker (common for all streaming backends). @@ -289,169 +616,143 @@ def _register_with_queue_tracker( image_ids: List of image IDs to register """ listener = GlobalAckListener() - transport_config = transport_config or self._transport_config listener.start( port=transport_config.shared_ack_port, - transport_mode=coerce_transport_mode(transport_mode), + transport_mode=transport_endpoint.resolved_transport_mode(), config=transport_config, ) from zmqruntime.queue_tracker import GlobalQueueTrackerRegistry registry = GlobalQueueTrackerRegistry() - tracker = registry.get_or_create_tracker(port, self.VIEWER_TYPE) + tracker = registry.get_or_create_tracker( + transport_endpoint.port, + self.VIEWER_TYPE, + ) for image_id in image_ids: tracker.register_sent(image_id) - def _build_component_modes(self, display_config) -> dict: - return display_config.component_modes() - - def _build_display_config_base(self, display_config, component_modes: dict) -> dict: - return { - "component_modes": component_modes, - "component_order": display_config.COMPONENT_ORDER, - } + def _cleanup_shared_memory_blocks(self, batch_images, unlink: bool = False) -> None: + for img in batch_images: + shm_name = img.get(ViewerBatchItemWireField.SHM_NAME.value) + if shm_name and shm_name in self._shared_memory_blocks: + try: + shm = self._shared_memory_blocks.pop(shm_name) + shm.close() + if unlink: + shm.unlink() + except Exception as e: + logger.warning(f"Failed to cleanup shared memory {shm_name}: {e}") - def _collect_component_names_metadata( + def _prepare_batch_item( self, - plate_path, - microscope_handler, - component_names: List[str] | None = None, - log_prefix: str | None = None, - verbose: bool = False, - ) -> dict: - component_names = component_names or ["channel", "well", "site"] - component_names_metadata = {} - - if not plate_path or not microscope_handler: - if verbose and log_prefix: - if not plate_path: - logger.warning(f"{log_prefix}: No plate_path in kwargs") - if not microscope_handler: - logger.warning(f"{log_prefix}: No microscope_handler") - return component_names_metadata - - try: - for comp_name in component_names: - metadata = microscope_handler.metadata_handler.get_component_values( - plate_path, - comp_name, - ) - if verbose and log_prefix: - logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") - if metadata: - component_names_metadata[comp_name] = metadata - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get component metadata: {e}", exc_info=True) + request: StreamingItemPreparationRequest, + ) -> tuple[ViewerWireMapping, str]: + raise NotImplementedError - return component_names_metadata - - def _prepare_batch_items( + def _display_payload_extra( self, - request: StreamingBatchRequest, - ) -> tuple[list[dict], list[str]]: - batch_images = [] - image_ids = [] + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return EMPTY_DISPLAY_PAYLOAD_EXTRA - for index, (data, file_path) in enumerate( - zip(request.data_list, request.file_paths) - ): - image_id = str(uuid.uuid4()) - image_ids.append(image_id) - - data_type = self._detect_data_type(data) - explicit_component_metadata = self._component_metadata_for_item( - file_path=file_path, - index=index, - component_metadata=request.component_metadata, - component_metadata_by_path=request.component_metadata_by_path, - ) - component_metadata = self._parse_component_metadata( - file_path, - request.microscope_handler, - explicit_component_metadata, - ) - item_data, data_type_value = request.prepare_item(data, file_path, data_type) + def _message_extra( + self, + stream_request: ViewerStreamRequest, + ) -> dict[str, ViewerWireValue]: + return stream_request.message_extra_payload() - batch_images.append( - { - **item_data, - "data_type": data_type_value, - "metadata": component_metadata, - "producer_identity": request.producer_identity.to_payload(), - "image_id": image_id, - } - ) + def _component_names_request( + self, + stream_request: ViewerStreamRequest, + ) -> StreamingComponentNamesRequest: + return StreamingComponentNamesRequest.from_stream_request(stream_request) - return batch_images, image_ids + def _after_batch_message_built( + self, + stream_request: ViewerStreamRequest, + built_batch: StreamingBuiltBatch, + ) -> None: + pass - def _build_batch_message( + def save_batch( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - producer_identity: StreamProducerIdentity, - display_config, - prepare_item: PrepareStreamingItem, - plate_path: Union[str, Path, None] = None, - component_names_kwargs: dict | None = None, - component_metadata: Mapping[str, Any] | None = None, - component_metadata_by_path: ComponentMetadataByPath = None, - display_payload_extra: dict | None = None, - message_extra: dict | None = None, - ) -> tuple[dict, list[dict], list[str]]: - if len(data_list) != len(file_paths): - raise ValueError("data_list and file_paths must have the same length") + data_list: list[StreamablePayload], + file_paths: list[FilePath], + **kwargs, + ) -> None: + """Stream a batch of image or ROI payloads to this viewer.""" + data_list, file_paths, _skipped_paths = self._filter_streamable_files( + data_list, + file_paths, + ) + if not data_list: + return - batch_images, image_ids = self._prepare_batch_items( - StreamingBatchRequest( + stream_request = ViewerStreamBackendKwargs.from_kwargs(kwargs).stream_request + built_batch = StreamingBatchMessageBuilder.build( + self, + StreamingBatchMessageRequest( data_list=data_list, file_paths=file_paths, - microscope_handler=microscope_handler, - producer_identity=producer_identity, - prepare_item=prepare_item, - component_metadata=component_metadata, - component_metadata_by_path=component_metadata_by_path, - ) + stream_request=stream_request, + component_names_request=self._component_names_request(stream_request), + display_payload_extra=self._display_payload_extra(stream_request), + ), ) + self._after_batch_message_built(stream_request, built_batch) - component_modes = self._build_component_modes(display_config) - - component_names_metadata = self._collect_component_names_metadata( - plate_path, - microscope_handler, - **(component_names_kwargs or {}), + transport_config = stream_request.transport_config.resolve( + self._transport_config ) + transport_endpoint = stream_request.viewer_transport + self._register_with_queue_tracker( + transport_endpoint, + built_batch.image_ids, + transport_config=transport_config, + ) + url = transport_endpoint.data_url(transport_config) - display_payload = self._build_display_config_base(display_config, component_modes) - if display_payload_extra: - display_payload.update(display_payload_extra) + if self._context is None: + self._context = zmq.Context() - message = { - "type": "batch", - "images": batch_images, - "display_config": display_payload, - "component_names_metadata": component_names_metadata, - "timestamp": time.time(), - } - if message_extra: - message.update(message_extra) + viewer_name = str(self.VIEWER_TYPE).title() + viewer_label = viewer_name.upper() + ack_policy = STREAMING_TRANSPORT_DEFAULTS.ack_policy(viewer_name) + socket = self._context.socket(zmq.REQ) + ack_policy.apply_socket_options(socket) + socket.connect(url) + time.sleep(0.1) - return message, batch_images, image_ids + try: + logger.info( + "📤 %s BACKEND: Sending batch of %d images to %s on port %s " + "(REQ/REP - blocking until ack)", + viewer_label, + len(built_batch.batch_images), + viewer_name, + transport_endpoint.port, + ) + socket.send_json(built_batch.message) + ack_response = ack_policy.receive( + socket, + lambda: self._cleanup_shared_memory_blocks( + built_batch.batch_images, + unlink=True, + ), + port=transport_endpoint.port, + ) + logger.info( + "✅ %s BACKEND: Received ack from %s: %s", + viewer_label, + viewer_name, + ack_policy.status(ack_response), + ) + finally: + socket.close() - def _cleanup_shared_memory_blocks(self, batch_images, unlink: bool = False) -> None: - for img in batch_images: - shm_name = img.get("shm_name") - if shm_name and shm_name in self._shared_memory_blocks: - try: - shm = self._shared_memory_blocks.pop(shm_name) - shm.close() - if unlink: - shm.unlink() - except Exception as e: - logger.warning(f"Failed to cleanup shared memory {shm_name}: {e}") + self._cleanup_shared_memory_blocks(built_batch.batch_images, unlink=False) - def save(self, data: Any, file_path: Union[str, Path], **kwargs) -> None: + def save(self, data: StreamablePayload | str, file_path: FilePath, **kwargs) -> None: """ Stream single item (common for all streaming backends). diff --git a/src/polystore/streaming/identity.py b/src/polystore/streaming/identity.py index 3d7440f..1c543a6 100644 --- a/src/polystore/streaming/identity.py +++ b/src/polystore/streaming/identity.py @@ -3,7 +3,48 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, ClassVar, Mapping +from enum import Enum +from typing import ClassVar, Mapping, Sequence, TypeAlias + + +StreamProducerPayloadValue: TypeAlias = str | int | None +StreamProducerPayloadMapping: TypeAlias = Mapping[str, StreamProducerPayloadValue] +RouteKeyPart: TypeAlias = str | int | float | bool | None + + +class StreamProducerOrigin(str, Enum): + """Nominal stream producer origin values.""" + + PIPELINE = "pipeline" + MANUAL = "manual" + DIRECT = "direct" + + +class FixedStreamProducerIdentityKind(str, Enum): + """Producer identities whose origin and output kind are intentionally equal.""" + + MANUAL = StreamProducerOrigin.MANUAL.value + DIRECT = StreamProducerOrigin.DIRECT.value + + +class StreamProducerIdentityPayload(dict[str, StreamProducerPayloadValue]): + """Wire payload for one stream producer identity.""" + + @classmethod + def from_identity( + cls, + identity: "StreamProducerIdentity", + ) -> "StreamProducerIdentityPayload": + return cls( + origin=identity.origin, + output_kind=identity.output_kind, + output_key=identity.output_key, + step_name=identity.step_name, + pipeline_position=identity.pipeline_position, + step_scope_id=identity.step_scope_id, + invocation_key=identity.invocation_key, + artifact_kind=identity.artifact_kind, + ) @dataclass(frozen=True, slots=True) @@ -18,16 +59,6 @@ class StreamProducerIdentity: step_scope_id: str | None = None invocation_key: str | None = None artifact_kind: str | None = None - PAYLOAD_FIELDS: ClassVar[tuple[str, ...]] = ( - "origin", - "output_kind", - "output_key", - "step_name", - "pipeline_position", - "step_scope_id", - "invocation_key", - "artifact_kind", - ) @classmethod def pipeline_output( @@ -42,7 +73,7 @@ def pipeline_output( ) -> "StreamProducerIdentity": """Build identity for one pipeline-produced stream output.""" return cls( - origin="pipeline", + origin=StreamProducerOrigin.PIPELINE.value, output_kind=output_kind, output_key=output_key, step_name=step_name, @@ -52,40 +83,23 @@ def pipeline_output( ) @classmethod - def manual(cls, output_key: str) -> "StreamProducerIdentity": - """Build identity for one manual viewer action.""" - return cls._fixed_origin_output( - origin="manual", - output_kind="manual", - output_key=output_key, - ) - - @classmethod - def direct(cls, output_key: str) -> "StreamProducerIdentity": - """Build identity for direct in-process display calls.""" - return cls._fixed_origin_output( - origin="direct", - output_kind="direct", - output_key=output_key, - ) - - @classmethod - def _fixed_origin_output( + def fixed_output( cls, - *, - origin: str, - output_kind: str, + kind: FixedStreamProducerIdentityKind, output_key: str, ) -> "StreamProducerIdentity": - """Build identity variants whose origin and output kind match.""" + """Build identity for producer kinds whose origin owns the output kind.""" return cls( - origin=origin, - output_kind=output_kind, + origin=kind.value, + output_kind=kind.value, output_key=output_key, ) @classmethod - def from_payload(cls, payload: "StreamProducerIdentity | Mapping[str, Any]") -> "StreamProducerIdentity": + def from_payload( + cls, + payload: "StreamProducerIdentity | StreamProducerPayloadMapping", + ) -> "StreamProducerIdentity": if isinstance(payload, cls): return payload if not isinstance(payload, Mapping): @@ -93,32 +107,19 @@ def from_payload(cls, payload: "StreamProducerIdentity | Mapping[str, Any]") -> "Stream producer identity must be a mapping or StreamProducerIdentity, " f"got {type(payload).__name__}." ) - missing = [ - field_name - for field_name in ("origin", "output_kind", "output_key") - if payload.get(field_name) in (None, "") - ] - if missing: - raise ValueError(f"Stream producer identity missing required fields: {missing}") - pipeline_position = payload.get("pipeline_position") return cls( - origin=str(payload["origin"]), - output_kind=str(payload["output_kind"]), - output_key=str(payload["output_key"]), - step_name=_optional_str(payload.get("step_name")), - pipeline_position=( - None if pipeline_position is None else int(pipeline_position) - ), - step_scope_id=_optional_str(payload.get("step_scope_id")), - invocation_key=_optional_str(payload.get("invocation_key")), - artifact_kind=_optional_str(payload.get("artifact_kind")), + origin=_required_payload_str(payload, "origin"), + output_kind=_required_payload_str(payload, "output_kind"), + output_key=_required_payload_str(payload, "output_key"), + step_name=_optional_payload_str(payload, "step_name"), + pipeline_position=_optional_payload_int(payload, "pipeline_position"), + step_scope_id=_optional_payload_str(payload, "step_scope_id"), + invocation_key=_optional_payload_str(payload, "invocation_key"), + artifact_kind=_optional_payload_str(payload, "artifact_kind"), ) - def to_payload(self) -> dict[str, Any]: - return { - field_name: getattr(self, field_name) - for field_name in self.PAYLOAD_FIELDS - } + def to_payload(self) -> StreamProducerIdentityPayload: + return StreamProducerIdentityPayload.from_identity(self) def route_parts(self) -> tuple[str, ...]: parts = [ @@ -139,7 +140,44 @@ def route_parts(self) -> tuple[str, ...]: return tuple(parts) -def _optional_str(value: Any) -> str | None: +def _required_payload_str( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> str: + if field_name not in payload: + raise ValueError( + f"Stream producer identity missing required field: {field_name}" + ) + value = payload[field_name] + if value in (None, ""): + raise ValueError( + f"Stream producer identity missing required field: {field_name}" + ) + return str(value) + + +def _optional_payload_str( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> str | None: + if field_name not in payload: + return None + return _optional_str(payload[field_name]) + + +def _optional_payload_int( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> int | None: + if field_name not in payload: + return None + value = payload[field_name] + if value is None: + return None + return int(value) + + +def _optional_str(value: StreamProducerPayloadValue) -> str | None: if value is None: return None text = str(value) @@ -193,11 +231,11 @@ class StreamRouteKeyAuthority: """Own stable key-token projection for viewer route keys.""" @staticmethod - def token(value: object) -> str: + def token(value: RouteKeyPart) -> str: return str(value).replace("/", "_").replace("\\", "_").replace(" ", "_") @classmethod - def join(cls, parts: tuple[object, ...] | list[object]) -> str: + def join(cls, parts: Sequence[RouteKeyPart]) -> str: if not parts: raise ValueError("Cannot build a stream route key with no parts.") return "_".join(cls.token(part) for part in parts) diff --git a/src/polystore/streaming/receivers/__init__.py b/src/polystore/streaming/receivers/__init__.py index 5b57d56..c6d58db 100644 --- a/src/polystore/streaming/receivers/__init__.py +++ b/src/polystore/streaming/receivers/__init__.py @@ -9,6 +9,8 @@ WindowProjectionABC, DebouncedBatchEngine, GroupedWindowItems, + WindowProjectionPayloadProvider, + WindowProjectionSource, group_items_by_component_modes, ) from polystore.streaming.receivers.fiji.fiji_batch_processor import FijiBatchProcessor @@ -23,6 +25,8 @@ "WindowProjectionABC", "DebouncedBatchEngine", "GroupedWindowItems", + "WindowProjectionPayloadProvider", + "WindowProjectionSource", "group_items_by_component_modes", "FijiBatchProcessor", "NapariBatchProcessor", diff --git a/src/polystore/streaming/receivers/core/__init__.py b/src/polystore/streaming/receivers/core/__init__.py index 084d686..786f526 100644 --- a/src/polystore/streaming/receivers/core/__init__.py +++ b/src/polystore/streaming/receivers/core/__init__.py @@ -7,6 +7,8 @@ ) from polystore.streaming.receivers.core.window_projection import ( GroupedWindowItems, + WindowProjectionPayloadProvider, + WindowProjectionSource, group_items_by_component_modes, ) @@ -15,6 +17,7 @@ "WindowProjectionABC", "DebouncedBatchEngine", "GroupedWindowItems", + "WindowProjectionPayloadProvider", + "WindowProjectionSource", "group_items_by_component_modes", ] - diff --git a/src/polystore/streaming/receivers/core/window_projection.py b/src/polystore/streaming/receivers/core/window_projection.py index 6a99995..308df65 100644 --- a/src/polystore/streaming/receivers/core/window_projection.py +++ b/src/polystore/streaming/receivers/core/window_projection.py @@ -2,71 +2,194 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from typing import Any +from typing import Generic, TypeVar from polystore.streaming.identity import ( StreamProducerDisplayNameAuthority, StreamProducerIdentity, StreamRouteKeyAuthority, ) +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerBatchItemWireField, + ViewerComponentMode, + ViewerWireMapping, + ViewerWireValue, +) + + +WINDOW_COMPONENT_MODES = ( + ViewerComponentMode.WINDOW, + ViewerComponentMode.CHANNEL, + ViewerComponentMode.SLICE, + ViewerComponentMode.FRAME, +) +WindowLabel = tuple[str, ViewerWireValue] +WindowProjectionItemT = TypeVar("WindowProjectionItemT") +WindowProjectionProviderT = TypeVar( + "WindowProjectionProviderT", + bound="WindowProjectionPayloadProvider", +) + + +class WindowProjectionPayloadProvider(ABC): + """Item that can expose its viewer wire payload for window projection.""" + @abstractmethod + def window_projection_payload(self) -> Mapping[str, ViewerWireValue]: + """Return the wire payload used for component/window projection.""" -@dataclass(frozen=True) -class GroupedWindowItems: + +class WindowItemPayload(dict[str, ViewerWireValue]): + """Normalized wire payload retained for one projected window item.""" + + @classmethod + def from_mapping( + cls, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowItemPayload": + return cls(dict(payload)) + + +@dataclass(frozen=True, slots=True) +class GroupedWindowItems(Generic[WindowProjectionItemT]): """Projection result for a single batch.""" window_components: list[str] channel_components: list[str] slice_components: list[str] frame_components: list[str] - windows: dict[str, list[dict[str, Any]]] - fixed_window_labels: dict[str, list[tuple[str, Any]]] + windows: dict[str, list[WindowProjectionItemT]] + fixed_window_labels: dict[str, tuple[WindowLabel, ...]] + + +@dataclass(frozen=True, slots=True) +class WindowProjectionSource(Generic[WindowProjectionItemT]): + """Validated receiver item source used by window projection.""" + + item: WindowProjectionItemT + payload: Mapping[str, ViewerWireValue] + metadata: ViewerWireMapping + producer: StreamProducerIdentity + + @classmethod + def from_wire_payload( + cls, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowProjectionSource[WindowItemPayload]": + window_payload = WindowItemPayload.from_mapping(payload) + return cls.from_item(window_payload, window_payload) + + @classmethod + def from_wire_payloads( + cls, + payloads: Sequence[Mapping[str, ViewerWireValue]], + ) -> list["WindowProjectionSource[WindowItemPayload]"]: + return [cls.from_wire_payload(payload) for payload in payloads] + + @classmethod + def from_payload_provider( + cls, + item: WindowProjectionProviderT, + ) -> "WindowProjectionSource[WindowProjectionProviderT]": + return cls.from_item(item, item.window_projection_payload()) + + @classmethod + def from_payload_providers( + cls, + items: Sequence[WindowProjectionProviderT], + ) -> list["WindowProjectionSource[WindowProjectionProviderT]"]: + return [cls.from_payload_provider(item) for item in items] + + @classmethod + def from_item( + cls, + item: WindowProjectionItemT, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowProjectionSource[WindowProjectionItemT]": + metadata = cls._required_mapping( + payload, + ViewerBatchItemWireField.METADATA.value, + ) + producer_identity = cls._required_mapping( + payload, + ViewerBatchItemWireField.PRODUCER_IDENTITY.value, + ) + return cls( + item=item, + payload=payload, + metadata=metadata, + producer=StreamProducerIdentity.from_payload(producer_identity), + ) + + @staticmethod + def _required_mapping( + payload: Mapping[str, ViewerWireValue], + field_name: str, + ) -> ViewerWireMapping: + if field_name not in payload: + raise ValueError( + f"Viewer window projection item missing required field {field_name!r}." + ) + value = payload[field_name] + if not isinstance(value, Mapping): + raise TypeError( + f"Viewer window projection item field {field_name!r} must be a mapping, " + f"got {type(value).__name__}." + ) + return dict(value) def group_items_by_component_modes( - items: list[dict[str, Any]], - component_modes: dict[str, str], - component_order: list[str], -) -> GroupedWindowItems: + items: Sequence[WindowProjectionSource[WindowProjectionItemT]], + display_layout: ViewerBatchDisplayPayload, +) -> GroupedWindowItems[WindowProjectionItemT]: """Project items into window groups using declared component modes.""" - result: dict[str, list[str]] = { - "window": [], - "channel": [], - "slice": [], - "frame": [], - } - for comp_name in component_order: - mode = component_modes[comp_name] - result[mode].append(comp_name) - - window_components = result["window"] - channel_components = result["channel"] - slice_components = result["slice"] - frame_components = result["frame"] - - windows: dict[str, list[dict[str, Any]]] = {} - fixed_window_labels: dict[str, list[tuple[str, Any]]] = {} + mode_groups = display_layout.component_mode_groups(WINDOW_COMPONENT_MODES) + mode_groups.require_all_supported("window projection") + + window_components = list( + mode_groups.components_for_mode(ViewerComponentMode.WINDOW) + ) + channel_components = list( + mode_groups.components_for_mode(ViewerComponentMode.CHANNEL) + ) + slice_components = list( + mode_groups.components_for_mode(ViewerComponentMode.SLICE) + ) + frame_components = list( + mode_groups.components_for_mode(ViewerComponentMode.FRAME) + ) + + windows: dict[str, list[WindowProjectionItemT]] = {} + fixed_window_labels: dict[str, tuple[WindowLabel, ...]] = {} for item in items: - meta = item.get("metadata", {}) - producer = StreamProducerIdentity.from_payload(item.get("producer_identity")) - key_parts: list[str] = list(producer.route_parts()) - fixed_labels: list[tuple[str, Any]] = [ - ("producer", StreamProducerDisplayNameAuthority.output_label(producer)) + key_parts: list[str] = list(item.producer.route_parts()) + fixed_labels: list[WindowLabel] = [ + ( + "producer", + StreamProducerDisplayNameAuthority.output_label(item.producer), + ) ] for comp in window_components: - if comp not in meta: - continue - value = meta[comp] + if comp not in item.metadata: + raise ValueError( + f"Viewer window projection item missing window component {comp!r}." + ) + value = item.metadata[comp] key_parts.append(f"{comp}_{value}") fixed_labels.append((comp, value)) window_key = StreamRouteKeyAuthority.join(key_parts) - windows.setdefault(window_key, []).append(item) if window_key not in fixed_window_labels: - fixed_window_labels[window_key] = fixed_labels + windows[window_key] = [] + fixed_window_labels[window_key] = tuple(fixed_labels) + windows[window_key].append(item.item) return GroupedWindowItems( window_components=window_components, diff --git a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py index 9b37366..b1a95f8 100644 --- a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py +++ b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py @@ -57,6 +57,7 @@ def add_items( display_config: Dict[str, Any], images_dir: str, component_names_metadata: Dict[str, Any], + component_value_domain: Dict[str, Any], ): """ Add items to the batch for processing. @@ -67,11 +68,13 @@ def add_items( display_config: Display configuration dict images_dir: Artifact image directory context. component_names_metadata: Component name mappings for dimension labels + component_value_domain: Component value domains for axis cardinality """ context = { "display_config": display_config, "images_dir": images_dir, "component_names_metadata": component_names_metadata, + "component_value_domain": component_value_domain, "window_key": window_key, } self._engine.enqueue(items=items, context=context) @@ -90,15 +93,17 @@ def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) - display_config = context["display_config"] images_dir = context["images_dir"] component_names_metadata = context["component_names_metadata"] + component_value_domain = context["component_value_domain"] window_key = context["window_key"] logger.info( "FijiBatchProcessor: Processing batch of %d items for window '%s'", len(items), window_key, ) - self.fiji_server._process_items_from_batch( + self.fiji_server.batch_processor.process_wire_items( items=items, - display_config_dict=display_config, + display_config=display_config, images_dir=images_dir, component_names_metadata=component_names_metadata, + component_value_domain=component_value_domain, ) diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index e29ab73..c382853 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -2,37 +2,92 @@ from __future__ import annotations -from typing import Any +from collections.abc import Mapping from polystore.streaming.identity import StreamProducerIdentity, StreamRouteKeyAuthority from polystore.streaming_constants import StreamingDataType +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerComponentMode, + ViewerDisplayConfigWireField, + ViewerWireMapping, + ViewerWireValue, +) -def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], list[str]]: - """Return canonical (component_modes, component_order) from display config.""" +def normalize_component_layout( + display_config: ViewerBatchDisplayPayload | ViewerWireMapping, +) -> ViewerBatchDisplayPayload: + """Return canonical display layout from a viewer display-config payload.""" + if isinstance(display_config, ViewerBatchDisplayPayload): + return display_config if isinstance(display_config, dict): - component_modes = display_config["component_modes"] - component_order = display_config["component_order"] - return component_modes, component_order + return ViewerBatchDisplayPayload( + component_modes=_required_mapping( + display_config, + ViewerDisplayConfigWireField.COMPONENT_MODES.value, + ), + component_order=_required_sequence( + display_config, + ViewerDisplayConfigWireField.COMPONENT_ORDER.value, + ), + ) - return display_config.component_modes(), list(display_config.COMPONENT_ORDER) + raise TypeError( + "Napari component layout requires ViewerBatchDisplayPayload or mapping, " + f"got {type(display_config).__name__}." + ) def build_route_key( - producer_identity: StreamProducerIdentity | dict[str, Any], - component_info: dict[str, Any], - component_modes: dict[str, str], - component_order: list[str], + producer_identity: StreamProducerIdentity | Mapping[str, ViewerWireValue], + component_info: Mapping[str, ViewerWireValue], + display_layout: ViewerBatchDisplayPayload, data_type: StreamingDataType, ) -> str: """Build hidden route key from producer identity, slice components, and type.""" producer = StreamProducerIdentity.from_payload(producer_identity) route_parts: list[str] = list(producer.route_parts()) - for component in component_order: - mode = component_modes[component] - if mode == "slice" and component in component_info: - route_parts.append(f"{component}_{component_info[component]}") + for component in display_layout.components_for_mode(ViewerComponentMode.SLICE): + if component not in component_info: + raise ValueError( + f"Napari route key missing slice component {component!r}." + ) + route_parts.append(f"{component}_{component_info[component]}") route_key = StreamRouteKeyAuthority.join(route_parts) return f"{route_key}{data_type.napari_layer_suffix}" + + +def _required_mapping( + payload: Mapping[str, ViewerWireValue], + field_name: str, +) -> dict[str, str]: + if field_name not in payload: + raise ValueError(f"Display config missing required field {field_name!r}.") + value = payload[field_name] + if not isinstance(value, Mapping): + raise TypeError( + f"Display config field {field_name!r} must be a mapping, " + f"got {type(value).__name__}." + ) + return { + str(component): str(mode) + for component, mode in value.items() + } + + +def _required_sequence( + payload: Mapping[str, ViewerWireValue], + field_name: str, +) -> list[str]: + if field_name not in payload: + raise ValueError(f"Display config missing required field {field_name!r}.") + value = payload[field_name] + if isinstance(value, str) or not isinstance(value, list | tuple): + raise TypeError( + f"Display config field {field_name!r} must be a sequence, " + f"got {type(value).__name__}." + ) + return [str(component) for component in value] diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index 4abbc90..e6e80d5 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,18 +1,28 @@ import logging from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from collections.abc import Sequence +from typing import Generic, Optional, TypeVar logger = logging.getLogger(__name__) +NapariBatchItemT = TypeVar("NapariBatchItemT") +NapariDisplayPayloadT = TypeVar("NapariDisplayPayloadT") +NapariComponentNamesMetadataT = TypeVar("NapariComponentNamesMetadataT") @dataclass(frozen=True) -class NapariBatchDisplayRequest: +class NapariBatchDisplayRequest( + Generic[ + NapariBatchItemT, + NapariDisplayPayloadT, + NapariComponentNamesMetadataT, + ] +): """Nominal request for one debounced Napari display update.""" layer_key: str - items: List[Dict[str, Any]] - display_payload: object - component_names_metadata: Dict[str, Any] + items: Sequence[NapariBatchItemT] + display_payload: NapariDisplayPayloadT + component_names_metadata: NapariComponentNamesMetadataT def dispatch_to(self, napari_server) -> None: napari_server.display_layer_batch( @@ -61,9 +71,9 @@ def __init__( def add_items( self, layer_key: str, - items: List[Dict[str, Any]], - display_payload: object, - component_names_metadata: Dict[str, Any], + items: Sequence[NapariBatchItemT], + display_payload: NapariDisplayPayloadT, + component_names_metadata: NapariComponentNamesMetadataT, ): """ Display items already released by the Qt-thread debounce. @@ -71,7 +81,7 @@ def add_items( Args: layer_key: Unique identifier for the layer items: List of items to add (images or ROIs) - display_payload: Viewer-owned display configuration object + display_payload: Viewer-owned display payload object component_names_metadata: Component name mappings for dimension labels """ NapariBatchDisplayRequest( diff --git a/src/polystore/streaming/viewer_transport.py b/src/polystore/streaming/viewer_transport.py index 6947076..5d2207d 100644 --- a/src/polystore/streaming/viewer_transport.py +++ b/src/polystore/streaming/viewer_transport.py @@ -2,125 +2,331 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Mapping - -import zmq +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import ( + ClassVar, + TypeAlias, +) +from polystore.registry import AutoRegisterMeta from polystore.streaming.identity import StreamProducerIdentity +from zmqruntime.config import ZMQConfig +from zmqruntime.viewer_protocol import ( + ViewerAckPolicy, + ViewerTransportEndpoint, + ViewerTransportMode, + ViewerWireMapping, + ViewerWireValue, +) + + +DisplayComponentToken: TypeAlias = str | Enum +DisplayModeToken: TypeAlias = str | Enum | None +ViewerComponentMetadataByPath: TypeAlias = ( + Mapping[str, ViewerWireMapping] + | Sequence[ViewerWireMapping] + | None +) + + +class ViewerDisplayConfigABC(ABC): + """Display-config surface required by viewer streaming backends.""" + + COMPONENT_ORDER: Sequence[DisplayComponentToken] + + @abstractmethod + def component_modes(self) -> Mapping[DisplayComponentToken, DisplayModeToken]: + """Return mode assignments by display component.""" + + +class ViewerFilenameParserABC(ABC): + """Filename parser surface needed by viewer streaming metadata.""" + + @abstractmethod + def parse_filename(self, filename: str) -> ViewerWireMapping | None: + """Return component metadata parsed from a filename.""" + + +class ViewerMetadataHandlerABC(ABC): + """Metadata-handler surface needed by viewer component labels.""" + + @abstractmethod + def get_component_values( + self, + plate_path: str | Path | None, + component_name: str, + ) -> ViewerWireValue: + """Return display-name metadata for one component.""" + + +class ViewerMicroscopeHandlerABC(ABC): + """Microscope-handler surface used by viewer streaming.""" + + parser: ViewerFilenameParserABC + metadata_handler: ViewerMetadataHandlerABC + + +class ViewerTransportConfigSelection(ABC, metaclass=AutoRegisterMeta): + """Nominal selection of the transport config used for one stream request.""" + + __registry_key__ = "registry_key" + registry_key: ClassVar[str | None] = None + + @classmethod + def select(cls, value) -> "ViewerTransportConfigSelection": + if isinstance(value, cls): + return value + for selection_type in cls.__registry__.values(): + if selection_type.accepts(value): + return selection_type.from_raw(value) + raise TypeError( + "transport_config must be a ZMQConfig, " + "ViewerTransportConfigSelection, or None." + ) + + @classmethod + @abstractmethod + def accepts(cls, value) -> bool: + """Return whether this registered selection can adapt the raw value.""" + + @classmethod + @abstractmethod + def from_raw(cls, value) -> "ViewerTransportConfigSelection": + """Adapt the raw value into a concrete transport-config selection.""" + + @abstractmethod + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + """Return the concrete config for this request.""" + + +@dataclass(frozen=True) +class DefaultViewerTransportConfig(ViewerTransportConfigSelection): + """Use the backend's configured transport settings.""" + + registry_key: ClassVar[str] = "default" + + @classmethod + def accepts(cls, value) -> bool: + return value is None + + @classmethod + def from_raw(cls, value) -> "DefaultViewerTransportConfig": + return cls() + + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + return default_transport_config + + +@dataclass(frozen=True) +class ExplicitViewerTransportConfig(ViewerTransportConfigSelection): + """Use a caller-supplied transport config for this request.""" + + registry_key: ClassVar[str] = "explicit" + + config: ZMQConfig + + @classmethod + def accepts(cls, value) -> bool: + return isinstance(value, ZMQConfig) + + @classmethod + def from_raw(cls, value) -> "ExplicitViewerTransportConfig": + return cls(value) + + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + return self.config @dataclass(frozen=True) class ViewerTransportDefaults: """Declared transport defaults shared by viewer streaming backends.""" - host: str = "localhost" ack_timeout_ms: int = 30_000 + def ack_policy(self, viewer_name: str) -> ViewerAckPolicy: + return ViewerAckPolicy( + viewer_name=viewer_name, + timeout_ms=self.ack_timeout_ms, + ) + @dataclass(frozen=True) -class ViewerStreamKwargs: - """Typed view of backend kwargs at the viewer streaming boundary.""" +class ViewerStreamSourceMetadata: + """Component metadata authority for streamed source items.""" - host: str - port: int - transport_mode: Any - transport_config: Any - display_config: Any - microscope_handler: Any - producer_identity: StreamProducerIdentity - plate_path: Any - component_metadata: Any - component_metadata_by_path: Any - images_dir: Any = None + component_metadata: ViewerWireMapping | None = None + component_metadata_by_path: ViewerComponentMetadataByPath = None - @classmethod - def from_kwargs( - cls, - kwargs: Mapping[str, Any], - defaults: ViewerTransportDefaults, - *, - include_images_dir: bool = False, - ) -> "ViewerStreamKwargs": - return cls( - host=ViewerKwargAuthority.value_or_default(kwargs, "host", defaults.host), - port=ViewerKwargAuthority.required(kwargs, "port"), - transport_mode=ViewerKwargAuthority.required(kwargs, "transport_mode"), - transport_config=ViewerKwargAuthority.optional(kwargs, "transport_config"), - display_config=ViewerKwargAuthority.required(kwargs, "display_config"), - microscope_handler=ViewerKwargAuthority.required(kwargs, "microscope_handler"), - producer_identity=StreamProducerIdentity.from_payload( - ViewerKwargAuthority.required(kwargs, "producer_identity") - ), - plate_path=ViewerKwargAuthority.optional(kwargs, "plate_path"), - component_metadata=ViewerKwargAuthority.optional(kwargs, "component_metadata"), - component_metadata_by_path=ViewerKwargAuthority.optional( - kwargs, "component_metadata_by_path" - ), - images_dir=( - ViewerKwargAuthority.optional(kwargs, "images_dir") - if include_images_dir - else None - ), + def __post_init__(self) -> None: + if ( + self.component_metadata is not None + and self.component_metadata_by_path is not None + ): + raise ValueError( + "Viewer stream source context accepts either component_metadata " + "or component_metadata_by_path, not both." + ) + + def component_metadata_for_item( + self, + file_path: str | Path, + index: int, + ) -> dict[str, ViewerWireValue]: + """Return explicit component metadata for one batch item.""" + if self.component_metadata_by_path is None: + return self._batch_component_metadata(file_path) + + if isinstance(self.component_metadata_by_path, Mapping): + return self._mapping_component_metadata( + file_path, + self.component_metadata_by_path, + ) + + if index < len(self.component_metadata_by_path): + return self.component_metadata_by_path[index] + + raise IndexError( + "Viewer stream component_metadata_by_path has no entry for " + f"item {index} at {file_path!r}." ) + def _batch_component_metadata( + self, + file_path: str | Path, + ) -> dict[str, ViewerWireValue]: + if self.component_metadata is None: + raise ValueError( + "Viewer stream item requires explicit component_metadata or " + f"component_metadata_by_path; got no metadata for {file_path!r}." + ) + return self._metadata_payload( + self.component_metadata, + f"batch metadata for {file_path!r}", + ) -class ViewerKwargAuthority: - """Named access policy for viewer backend kwargs.""" + def _mapping_component_metadata( + self, + file_path: str | Path, + metadata_by_path: Mapping[str, ViewerWireMapping], + ) -> dict[str, ViewerWireValue]: + path = Path(file_path) + for key in (str(file_path), path.as_posix(), path.name): + if key in metadata_by_path: + return self._metadata_payload( + metadata_by_path[key], + f"path metadata for {file_path!r}", + ) + raise KeyError( + "Viewer stream component_metadata_by_path has no entry for " + f"{file_path!r}." + ) @staticmethod - def required(kwargs: Mapping[str, Any], name: str) -> Any: - if name not in kwargs: - raise ValueError(f"Viewer streaming kwargs missing required field '{name}'") - return kwargs[name] + def _metadata_payload( + value: ViewerWireMapping, + source_label: str, + ) -> dict[str, ViewerWireValue]: + if not isinstance(value, Mapping): + raise TypeError( + "Viewer stream component metadata must be a mapping " + f"for {source_label}; got {type(value).__name__}." + ) + return dict(value) - @staticmethod - def optional(kwargs: Mapping[str, Any], name: str) -> Any: - if name in kwargs: - return kwargs[name] - return None - @staticmethod - def value_or_default(kwargs: Mapping[str, Any], name: str, default: Any) -> Any: - if name in kwargs: - return kwargs[name] - return default +@dataclass(frozen=True) +class ViewerStreamSourceIdentity: + """Stable source identity shared by all stream batches for one plate.""" + microscope_handler: ViewerMicroscopeHandlerABC + plate_path: str | Path | None = None -class ViewerTransportConfigAuthority: - """Resolve the concrete transport config without implicit truthiness fallback.""" - @staticmethod - def resolve(transport_config: Any, default_transport_config: Any) -> Any: - if transport_config is None: - return default_transport_config - return transport_config +class ViewerStreamKwarg(str, Enum): + """Raw kwarg names accepted at the top-level viewer stream boundary.""" + + STREAM_REQUEST = "stream_request" + + +@dataclass(frozen=True) +class ViewerStreamSource: + """Source provenance and metadata authority for one viewer stream.""" + + identity: ViewerStreamSourceIdentity + metadata: ViewerStreamSourceMetadata = field( + default_factory=ViewerStreamSourceMetadata + ) + + +@dataclass(frozen=True) +class ViewerStreamRequest: + """Typed view of backend kwargs at the viewer streaming boundary.""" + + viewer_transport: ViewerTransportEndpoint + display_config: ViewerDisplayConfigABC + source: ViewerStreamSource + producer_identity: StreamProducerIdentity + transport_config: ViewerTransportConfigSelection = DefaultViewerTransportConfig() + message_extra: ViewerWireMapping | None = None + images_dir: str | None = None + + @property + def host(self) -> str: + return self.viewer_transport.host + + @property + def port(self) -> int: + return self.viewer_transport.port + + @property + def transport_mode(self) -> ViewerTransportMode: + return self.viewer_transport.transport_mode + + def message_extra_payload(self) -> dict[str, ViewerWireValue]: + return ViewerMessageExtraAuthority.payload(self.message_extra) + + +ViewerStreamKwargPayloadMapping: TypeAlias = Mapping[ + str, + "ViewerStreamRequest", +] @dataclass(frozen=True) -class ViewerAckPolicy: - """REQ/REP ack contract for a streaming viewer.""" - - viewer_name: str - timeout_ms: int - - def apply_socket_options(self, socket: zmq.Socket) -> None: - socket.setsockopt(zmq.LINGER, 0) - socket.setsockopt(zmq.SNDTIMEO, self.timeout_ms) - socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) - - def receive(self, socket: zmq.Socket, cleanup, *, port: int) -> dict[str, Any]: - try: - return socket.recv_json() - except zmq.Again as exc: - cleanup() - raise TimeoutError( - f"Timed out waiting {self.timeout_ms}ms for {self.viewer_name} ack on port {port}" - ) from exc - - def status(self, ack_response: Mapping[str, Any]) -> str: - if "status" not in ack_response: +class ViewerStreamBackendKwargs: + """The only accepted FileManager kwarg payload for viewer stream backends.""" + + stream_request: ViewerStreamRequest + + @classmethod + def from_kwargs( + cls, + kwargs: ViewerStreamKwargPayloadMapping, + ) -> "ViewerStreamBackendKwargs": + expected = frozenset((ViewerStreamKwarg.STREAM_REQUEST.value,)) + actual = frozenset(kwargs) + if actual != expected: raise ValueError( - f"{self.viewer_name} ack response missing required 'status': {ack_response}" + "Viewer stream backends require exactly one kwarg: stream_request" ) - return ack_response["status"] + value = kwargs[ViewerStreamKwarg.STREAM_REQUEST.value] + if not isinstance(value, ViewerStreamRequest): + raise TypeError("stream_request must be a ViewerStreamRequest instance") + return cls(value) + + def to_kwargs(self) -> dict[str, ViewerStreamRequest]: + return {ViewerStreamKwarg.STREAM_REQUEST.value: self.stream_request} + + +class ViewerMessageExtraAuthority: + """Formal boundary for absent caller-supplied viewer message extras.""" + + @staticmethod + def payload(message_extra: Mapping[str, ViewerWireValue] | None) -> dict[str, ViewerWireValue]: + if message_extra is None: + return {} + return dict(message_extra) diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py index 24fbaaa..a90d013 100644 --- a/tests/test_streaming_metadata.py +++ b/tests/test_streaming_metadata.py @@ -3,18 +3,45 @@ import pytest from polystore.streaming._streaming_backend import StreamingBackend -from polystore.streaming._streaming_backend import StreamingBatchRequest +from polystore.streaming._streaming_backend import StreamingBatchItemPreparationAuthority +from polystore.streaming._streaming_backend import StreamingBatchMessageBuilder +from polystore.streaming._streaming_backend import StreamingBatchMessageRequest +from polystore.streaming._streaming_backend import StreamingItemPath +from polystore.streaming._streaming_backend import StreamingItemPreparationRequest from polystore.streaming.identity import StreamProducerIdentity +from polystore.streaming.viewer_transport import ViewerDisplayConfigABC +from polystore.streaming.viewer_transport import ViewerMicroscopeHandlerABC +from polystore.streaming.viewer_transport import ViewerStreamRequest +from polystore.streaming.viewer_transport import ViewerStreamSource +from polystore.streaming.viewer_transport import ViewerStreamSourceIdentity +from polystore.streaming.viewer_transport import ViewerStreamSourceMetadata +from zmqruntime.config import TransportMode +from zmqruntime.viewer_protocol import ViewerAckPolicy +from zmqruntime.viewer_protocol import ViewerTransportEndpoint class MetadataProbeStreamingBackend(StreamingBackend): VIEWER_TYPE = "probe" SHM_PREFIX = "probe_" + def _prepare_batch_item(self, request: StreamingItemPreparationRequest): + return {"path": request.item_path.wire_value, "payload": "ok"}, "image" + def save_batch(self, data_list, file_paths, **kwargs): raise NotImplementedError +class DisplayConfigStub(ViewerDisplayConfigABC): + COMPONENT_ORDER = ("well", "site", "channel") + + def component_modes(self): + return { + "well": "stack", + "site": "stack", + "channel": "stack", + } + + PRODUCER_IDENTITY = StreamProducerIdentity( origin="pipeline", output_kind="main", @@ -23,111 +50,244 @@ def save_batch(self, data_list, file_paths, **kwargs): ) -def test_streaming_component_metadata_rejects_unparsed_artifact_filename() -> None: - backend = MetadataProbeStreamingBackend() - microscope_handler = SimpleNamespace( - parser=SimpleNamespace(parse_filename=lambda _filename: None) +EMPTY_SOURCE_METADATA = ViewerStreamSourceMetadata() + + +def stream_request( + microscope_handler, + source_metadata=EMPTY_SOURCE_METADATA, + *, + plate_path=None, + message_extra=None, +): + return ViewerStreamRequest( + viewer_transport=ViewerTransportEndpoint( + host="127.0.0.1", + port=5555, + transport_mode=TransportMode.TCP, + ), + display_config=DisplayConfigStub(), + source=ViewerStreamSource( + identity=ViewerStreamSourceIdentity( + microscope_handler=microscope_handler, + plate_path=plate_path, + ), + metadata=source_metadata, + ), + producer_identity=PRODUCER_IDENTITY, + message_extra=message_extra, + ) + + +def batch_message_request(data_list, file_paths, viewer_request): + return StreamingBatchMessageRequest( + data_list=data_list, + file_paths=file_paths, + stream_request=viewer_request, + ) + + +def microscope_handler_with_parser(parser): + class MicroscopeHandlerStub(ViewerMicroscopeHandlerABC): + pass + + microscope_handler = MicroscopeHandlerStub() + microscope_handler.parser = parser + microscope_handler.metadata_handler = SimpleNamespace( + get_component_values=lambda _plate_path, _component_name: None ) + return microscope_handler + +def test_streaming_source_metadata_rejects_missing_component_metadata() -> None: with pytest.raises(ValueError, match="explicit component_metadata"): - backend._parse_component_metadata( + ViewerStreamSourceMetadata().component_metadata_for_item( "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", - microscope_handler, + 0, ) def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: backend = MetadataProbeStreamingBackend() - microscope_handler = SimpleNamespace( - parser=SimpleNamespace(parse_filename=lambda _filename: None) + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) ) with pytest.raises(ValueError, match="explicit component_metadata"): - backend._prepare_batch_items( - StreamingBatchRequest( - data_list=[object()], - file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], - microscope_handler=microscope_handler, - producer_identity=PRODUCER_IDENTITY, - prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + StreamingBatchItemPreparationAuthority.prepare( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + stream_request(microscope_handler), ) ) def test_streaming_batch_items_accept_per_path_component_metadata() -> None: backend = MetadataProbeStreamingBackend() - microscope_handler = SimpleNamespace( - parser=SimpleNamespace(parse_filename=lambda _filename: None) + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) ) - batch_images, _image_ids = backend._prepare_batch_items( - StreamingBatchRequest( - data_list=[object()], - file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], - microscope_handler=microscope_handler, - producer_identity=PRODUCER_IDENTITY, - prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), - component_metadata_by_path={ - "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip": { - "well": "A01", - "site": 1, - "channel": 1, - }, - }, + prepared_items = StreamingBatchItemPreparationAuthority.prepare( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata_by_path={ + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip": { + "well": "A01", + "site": 1, + "channel": 1, + }, + } + ), + ), ) ) - assert batch_images[0]["metadata"] == { + assert prepared_items.batch_images[0]["metadata"] == { + "well": "A01", + "site": 1, + "channel": 1, + } + assert ( + prepared_items.batch_images[0]["producer_identity"] + == PRODUCER_IDENTITY.to_payload() + ) + + +def test_streaming_item_component_metadata_preserves_explicit_fields() -> None: + metadata = ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ).component_metadata_for_item( + StreamingItemPath("A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip").value, + 0, + ) + + assert metadata == { "well": "A01", "site": 1, "channel": 1, } - assert batch_images[0]["producer_identity"] == PRODUCER_IDENTITY.to_payload() -def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: +def test_streaming_batch_message_declares_component_value_domain() -> None: backend = MetadataProbeStreamingBackend() - microscope_handler = SimpleNamespace( - parser=SimpleNamespace( - parse_filename=lambda _filename: {"well": "A01", "channel": 1} - ) + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) ) - metadata = backend._parse_component_metadata( - "A01_s001_w1_z001_t001.TIF", - microscope_handler, + built_batch = StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object(), object()], + ["A01_s001_w1_z001_t001.tif", "A01_s002_w2_z001_t001.tif"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata_by_path=( + {"well": "A01", "site": 1, "channel": 1}, + {"well": "A01", "site": 2, "channel": 2}, + ), + ), + ), + ), ) - assert metadata == {"well": "A01", "channel": 1} + assert built_batch.message["component_value_domain"] == { + "well": ["A01"], + "site": [1, 2], + "channel": [1, 2], + } -def test_streaming_component_metadata_prefers_explicit_metadata() -> None: +def test_streaming_batch_message_honors_declared_component_metadata_payload() -> None: backend = MetadataProbeStreamingBackend() - microscope_handler = SimpleNamespace( - parser=SimpleNamespace(parse_filename=lambda _filename: None) + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) ) - metadata = backend._parse_component_metadata( - "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", - microscope_handler, - component_metadata={"well": "A01", "site": 1, "channel": 1}, + built_batch = StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001.tif"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ), + message_extra={ + "component_value_domain": {"well": ["A01", "B01"]}, + "component_names_metadata": {"well": {"A01": "control"}}, + }, + ), + ), ) - assert metadata == { - "well": "A01", - "site": 1, - "channel": 1, + assert built_batch.message["component_value_domain"] == {"well": ["A01", "B01"]} + assert built_batch.message["component_names_metadata"] == { + "well": {"A01": "control"} } -def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: +def test_streaming_batch_message_rejects_partial_declared_component_metadata_payload() -> None: backend = MetadataProbeStreamingBackend() - microscope_handler = SimpleNamespace( - parser=SimpleNamespace(parse_filename=lambda _filename: ["not", "metadata"]) + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) ) + with pytest.raises(ValueError, match="component_names_metadata"): + StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001.tif"], + stream_request( + microscope_handler, + ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ), + message_extra={"component_value_domain": {"well": ["A01"]}}, + ), + ), + ) + + +def test_streaming_component_metadata_rejects_invalid_explicit_metadata() -> None: with pytest.raises(TypeError, match="must be a mapping"): - backend._parse_component_metadata( - "A01_s001_w1_z001_t001.TIF", - microscope_handler, + ViewerStreamSourceMetadata( + component_metadata=["not", "metadata"], + ).component_metadata_for_item( + StreamingItemPath("A01_s001_w1_z001_t001.TIF").value, + 0, + ) + + +class ViewerAckSocketStub: + def __init__(self, response): + self.response = response + + def recv_json(self): + return self.response + + +def test_viewer_ack_policy_rejects_error_status_and_cleans_up() -> None: + cleanup_calls = [] + policy = ViewerAckPolicy(viewer_name="Napari", timeout_ms=30_000) + + with pytest.raises(RuntimeError, match="Napari rejected stream batch"): + policy.receive( + ViewerAckSocketStub( + {"status": "error", "message": "missing component_value_domain"} + ), + lambda: cleanup_calls.append("cleanup"), + port=5555, ) + + assert cleanup_calls == ["cleanup"] diff --git a/tests/test_streaming_receiver_core.py b/tests/test_streaming_receiver_core.py index be0184b..eb7a76b 100644 --- a/tests/test_streaming_receiver_core.py +++ b/tests/test_streaming_receiver_core.py @@ -5,17 +5,20 @@ from polystore.streaming_constants import StreamingDataType from polystore.streaming.identity import ( + FixedStreamProducerIdentityKind, StreamProducerDisplayNameAuthority, StreamProducerIdentity, ) from polystore.streaming.receivers.core import ( DebouncedBatchEngine, + WindowProjectionSource, group_items_by_component_modes, ) from polystore.streaming.receivers.napari import ( normalize_component_layout, build_route_key, ) +from zmqruntime.viewer_protocol import ViewerBatchDisplayPayload class PipelineProducerFixture: """Nominal producer fixtures for receiver-core tests.""" @@ -83,9 +86,11 @@ def test_group_items_by_component_modes_keys_windows_by_producer_identity() -> N component_order = ["well", "channel"] grouped = group_items_by_component_modes( - items, - component_modes=component_modes, - component_order=component_order, + WindowProjectionSource.from_wire_payloads(items), + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), ) assert grouped.window_components == [] @@ -94,13 +99,35 @@ def test_group_items_by_component_modes_keys_windows_by_producer_identity() -> N assert grouped.slice_components == [] assert grouped.fixed_window_labels[ "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels" - ] == [("producer", "3. Segment Nuclei")] + ] == (("producer", "3. Segment Nuclei"),) assert set(grouped.windows) == { "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels", "origin_pipeline_kind_main_out_main_step_1_name_RawLoad", } +def test_group_items_by_component_modes_rejects_missing_metadata() -> None: + producer = PipelineProducerFixture.main_output( + step_name="RawLoad", + pipeline_position=1, + ) + + try: + group_items_by_component_modes( + WindowProjectionSource.from_wire_payloads( + [{"producer_identity": producer.to_payload()}] + ), + display_layout=ViewerBatchDisplayPayload( + component_modes={"well": "window"}, + component_order=["well"], + ), + ) + except ValueError as error: + assert "metadata" in str(error) + else: + raise AssertionError("missing metadata must fail loudly") + + def test_stream_producer_display_name_authority_matches_pipeline_editor_indexing() -> None: main_output = PipelineProducerFixture.main_output( step_name="ConvertObjectsToImage", @@ -112,7 +139,10 @@ def test_stream_producer_display_name_authority_matches_pipeline_editor_indexing pipeline_position=8, artifact_kind="object_labels", ) - manual_output = StreamProducerIdentity.manual("selected_rois") + manual_output = StreamProducerIdentity.fixed_output( + FixedStreamProducerIdentityKind.MANUAL, + "selected_rois", + ) assert ( StreamProducerDisplayNameAuthority.producer_label(main_output) @@ -146,15 +176,19 @@ def test_napari_route_key_builder_uses_producer_slice_components_and_payload_typ key_image = build_route_key( producer_identity=producer, component_info=component_info, - component_modes=component_modes, - component_order=component_order, + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), data_type=StreamingDataType.IMAGE, ) key_shapes = build_route_key( producer_identity=producer, component_info=component_info, - component_modes=component_modes, - component_order=component_order, + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), data_type=StreamingDataType.SHAPES, ) @@ -162,15 +196,38 @@ def test_napari_route_key_builder_uses_producer_slice_components_and_payload_typ assert key_shapes == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3_shapes" +def test_napari_route_key_builder_rejects_missing_slice_component() -> None: + producer = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + ) + + try: + build_route_key( + producer_identity=producer, + component_info={"well": "A01"}, + display_layout=ViewerBatchDisplayPayload( + component_modes={"well": "slice", "site": "slice"}, + component_order=["well", "site"], + ), + data_type=StreamingDataType.IMAGE, + ) + except ValueError as error: + assert "site" in str(error) + else: + raise AssertionError("missing slice component must fail loudly") + + def test_normalize_component_layout_dict_config() -> None: - component_modes, component_order = normalize_component_layout( + display_layout = normalize_component_layout( { "component_modes": {"well": "slice", "channel": "stack"}, "component_order": ["well", "channel"], } ) - assert component_order == ["well", "channel"] - assert component_modes["well"] == "slice" + assert list(display_layout.component_order) == ["well", "channel"] + assert display_layout.component_modes["well"] == "slice" def test_debounced_batch_engine_flush_processes_pending_once() -> None: diff --git a/tests/test_viewer_transport.py b/tests/test_viewer_transport.py new file mode 100644 index 0000000..753de6f --- /dev/null +++ b/tests/test_viewer_transport.py @@ -0,0 +1,151 @@ +import pytest + +from polystore.streaming.identity import StreamProducerIdentity +from polystore.streaming.viewer_transport import ExplicitViewerTransportConfig +from polystore.streaming.viewer_transport import ViewerStreamBackendKwargs +from polystore.streaming.viewer_transport import ViewerStreamKwarg +from polystore.streaming.viewer_transport import ViewerDisplayConfigABC +from polystore.streaming.viewer_transport import ViewerFilenameParserABC +from polystore.streaming.viewer_transport import ViewerMetadataHandlerABC +from polystore.streaming.viewer_transport import ViewerMicroscopeHandlerABC +from polystore.streaming.viewer_transport import ViewerStreamRequest +from polystore.streaming.viewer_transport import ViewerStreamSource +from polystore.streaming.viewer_transport import ViewerStreamSourceIdentity +from polystore.streaming.viewer_transport import ViewerStreamSourceMetadata +from zmqruntime.config import TransportMode, ZMQConfig +from zmqruntime.viewer_protocol import ViewerTransportEndpoint + + +class DisplayConfigFixture(ViewerDisplayConfigABC): + COMPONENT_ORDER = ("well", "site", "channel") + + def component_modes(self): + return { + "well": "stack", + "site": "slice", + "channel": "channel", + } + + +class FilenameParserFixture(ViewerFilenameParserABC): + def parse_filename(self, filename): + return {"filename": filename} + + +class MetadataHandlerFixture(ViewerMetadataHandlerABC): + def get_component_values(self, plate_path, component_name): + return f"{plate_path}:{component_name}" + + +class MicroscopeHandlerFixture(ViewerMicroscopeHandlerABC): + parser = FilenameParserFixture() + metadata_handler = MetadataHandlerFixture() + + +EMPTY_SOURCE_METADATA = ViewerStreamSourceMetadata() + + +def stream_source( + source_metadata=EMPTY_SOURCE_METADATA, + *, + plate_path="/tmp/plate", +): + return ViewerStreamSource( + identity=ViewerStreamSourceIdentity( + microscope_handler=MicroscopeHandlerFixture(), + plate_path=plate_path, + ), + metadata=source_metadata, + ) + + +def required_stream_request(**kwargs): + values = { + "viewer_transport": ViewerTransportEndpoint( + host="127.0.0.1", + port=5555, + transport_mode=TransportMode.TCP, + ), + "display_config": DisplayConfigFixture(), + "source": stream_source(), + "producer_identity": StreamProducerIdentity.pipeline_output( + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", + pipeline_position=2, + ), + } + values.update(kwargs) + return ViewerStreamRequest(**values) + + +def test_viewer_stream_kwargs_declares_explicit_backend_request() -> None: + stream_kwargs = required_stream_request( + source=stream_source( + ViewerStreamSourceMetadata( + component_metadata_by_path=( + {"well": "A01", "site": 1}, + {"well": "A01", "site": 2}, + ), + ), + plate_path="/tmp/plate", + ), + message_extra={"component_value_domain": {"well": ["A01"]}}, + images_dir="/tmp/images", + ) + + backend_kwargs = ViewerStreamBackendKwargs(stream_kwargs).to_kwargs() + + assert backend_kwargs == {ViewerStreamKwarg.STREAM_REQUEST.value: stream_kwargs} + assert ViewerStreamBackendKwargs.from_kwargs(backend_kwargs).stream_request is stream_kwargs + assert stream_kwargs.host == "127.0.0.1" + assert stream_kwargs.port == 5555 + assert stream_kwargs.transport_mode is TransportMode.TCP + assert stream_kwargs.producer_identity == StreamProducerIdentity.pipeline_output( + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", + pipeline_position=2, + ) + assert stream_kwargs.source.metadata.component_metadata is None + assert stream_kwargs.source.metadata.component_metadata_by_path == ( + {"well": "A01", "site": 1}, + {"well": "A01", "site": 2}, + ) + default_config = ZMQConfig(default_port=9001) + assert stream_kwargs.transport_config.resolve(default_config) is default_config + + +def test_viewer_stream_kwargs_rejects_mixed_source_metadata() -> None: + with pytest.raises(ValueError, match="either component_metadata"): + ViewerStreamSourceMetadata( + component_metadata={"well": "A01", "site": 1}, + component_metadata_by_path=( + {"well": "A01", "site": 1}, + ), + ) + + +def test_viewer_stream_backend_rejects_flat_kwargs() -> None: + with pytest.raises(ValueError, match="stream_request"): + ViewerStreamBackendKwargs.from_kwargs( + {"display_config": DisplayConfigFixture()} + ) + + +def test_viewer_stream_kwargs_preserves_explicit_transport_config() -> None: + explicit_config = ZMQConfig(shared_ack_port=8111) + default_config = ZMQConfig(shared_ack_port=8222) + + stream_kwargs = required_stream_request( + transport_config=ExplicitViewerTransportConfig(explicit_config) + ) + + assert stream_kwargs.transport_config.resolve(default_config) is explicit_config + + +def test_viewer_stream_backend_rejects_non_request_payload() -> None: + with pytest.raises(TypeError, match="ViewerStreamRequest"): + ViewerStreamBackendKwargs.from_kwargs( + {ViewerStreamKwarg.STREAM_REQUEST.value: DisplayConfigFixture()} + )