From f8487e065ca4eb67b792ee8b76ec8c47923f8bba Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 11:07:56 -0600 Subject: [PATCH 1/9] Add scheduled tasks feature and delayed entity signals Add a recurring schedule feature (durabletask.scheduled) for parity with durabletask-dotnet, built on durable entities and a helper orchestrator. Enable it on a worker with configure_scheduled_tasks(worker) and manage schedules from the client via ScheduledTaskClient / ScheduleClient (create, describe, list, update, pause, resume, delete). To support the schedule entity's self-rearming, add an optional signal_time parameter to entity, orchestration-context, and client signal_entity methods, and make the in-memory backend honor delayed (scheduled) entity signals. Includes unit tests, in-memory E2E tests, and live DTS E2E tests for both the schedule feature and delayed signals, plus an example and changelog entries. --- CHANGELOG.md | 15 +- durabletask/client.py | 10 +- durabletask/entities/durable_entity.py | 11 +- durabletask/entities/entity_context.py | 16 +- durabletask/internal/client_helpers.py | 6 +- durabletask/internal/helpers.py | 9 +- durabletask/scheduled/__init__.py | 38 +++ durabletask/scheduled/client.py | 158 +++++++++ durabletask/scheduled/exceptions.py | 36 ++ durabletask/scheduled/models.py | 309 ++++++++++++++++++ durabletask/scheduled/orchestrator.py | 36 ++ durabletask/scheduled/registration.py | 21 ++ durabletask/scheduled/schedule_entity.py | 264 +++++++++++++++ durabletask/scheduled/schedule_status.py | 17 + durabletask/scheduled/transitions.py | 47 +++ durabletask/task.py | 7 +- durabletask/testing/in_memory_backend.py | 64 +++- durabletask/worker.py | 10 +- examples/scheduled_tasks.py | 79 +++++ .../entities/test_dts_delayed_signals_e2e.py | 108 ++++++ .../scheduled/__init__.py | 2 + .../scheduled/test_dts_scheduled_e2e.py | 192 +++++++++++ .../entities/test_delayed_signals_e2e.py | 121 +++++++ tests/durabletask/scheduled/__init__.py | 2 + 
tests/durabletask/scheduled/test_models.py | 139 ++++++++ .../scheduled/test_schedule_entity.py | 176 ++++++++++ .../scheduled/test_scheduled_e2e.py | 221 +++++++++++++ .../durabletask/scheduled/test_transitions.py | 58 ++++ 28 files changed, 2146 insertions(+), 26 deletions(-) create mode 100644 durabletask/scheduled/__init__.py create mode 100644 durabletask/scheduled/client.py create mode 100644 durabletask/scheduled/exceptions.py create mode 100644 durabletask/scheduled/models.py create mode 100644 durabletask/scheduled/orchestrator.py create mode 100644 durabletask/scheduled/registration.py create mode 100644 durabletask/scheduled/schedule_entity.py create mode 100644 durabletask/scheduled/schedule_status.py create mode 100644 durabletask/scheduled/transitions.py create mode 100644 examples/scheduled_tasks.py create mode 100644 tests/durabletask-azuremanaged/entities/test_dts_delayed_signals_e2e.py create mode 100644 tests/durabletask-azuremanaged/scheduled/__init__.py create mode 100644 tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py create mode 100644 tests/durabletask/entities/test_delayed_signals_e2e.py create mode 100644 tests/durabletask/scheduled/__init__.py create mode 100644 tests/durabletask/scheduled/test_models.py create mode 100644 tests/durabletask/scheduled/test_schedule_entity.py create mode 100644 tests/durabletask/scheduled/test_scheduled_e2e.py create mode 100644 tests/durabletask/scheduled/test_transitions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a786ae1c..27492085 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,20 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased -N/A +ADDED + +- Added `durabletask.scheduled`, a recurring schedule feature built on durable + entities. Use `configure_scheduled_tasks(worker)` to enable it on a worker, + then manage schedules from the client via `ScheduledTaskClient` (and the + per-schedule `ScheduleClient`). 
Supports creating, describing, listing, + updating, pausing, resuming, and deleting schedules with configurable + `interval`, `start_at`, `end_at`, and `start_immediately_if_late` options. +- Added an optional `signal_time` parameter to `EntityContext.signal_entity` + and `DurableEntity.signal_entity`, allowing an entity signal to be scheduled + for future delivery. +- Added an optional `signal_time` parameter to `OrchestrationContext.signal_entity` + and to the client `signal_entity` methods (sync and async), allowing entity + signals to be scheduled for future delivery from orchestrations and clients. ## v1.6.0 diff --git a/durabletask/client.py b/durabletask/client.py index 47e711c8..6b920d65 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -719,8 +719,9 @@ def purge_orchestrations_by(self, def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, - input: Any | None = None) -> None: - req = build_signal_entity_req(entity_instance_id, operation_name, input) + input: Any | None = None, + signal_time: datetime | None = None) -> None: + req = build_signal_entity_req(entity_instance_id, operation_name, input, signal_time) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") if self._payload_store is not None: payload_helpers.externalize_payloads( @@ -1199,8 +1200,9 @@ async def purge_orchestrations_by(self, async def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, - input: Any | None = None) -> None: - req = build_signal_entity_req(entity_instance_id, operation_name, input) + input: Any | None = None, + signal_time: datetime | None = None) -> None: + req = build_signal_entity_req(entity_instance_id, operation_name, input, signal_time) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") if self._payload_store is not None: await payload_helpers.externalize_payloads_async( diff --git 
a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index b93b21af..cfb85a8d 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from typing import Any, TypeVar, overload +from datetime import datetime from durabletask.entities.entity_context import EntityContext from durabletask.entities.entity_instance_id import EntityInstanceId @@ -52,7 +53,9 @@ def set_state(self, state: Any): """ self.entity_context.set_state(state) - def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Any | None = None) -> None: + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, + input: Any | None = None, + signal_time: datetime | None = None) -> None: """Signal another entity to perform an operation. Parameters @@ -63,8 +66,12 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in The operation to perform on the entity. input : Any, optional The input to provide to the entity for the operation. + signal_time : datetime, optional + The time at which the signal should be delivered. If None, the signal is + delivered as soon as possible. Use this to schedule a future operation, + for example to have an entity wake itself up at a later time. """ - self.entity_context.signal_entity(entity_instance_id, operation, input) + self.entity_context.signal_entity(entity_instance_id, operation, input, signal_time) def schedule_new_orchestration(self, orchestration_name: str, input: Any | None = None, instance_id: str | None = None) -> str: """Schedule a new orchestration instance. diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index 03ece715..a00d3df3 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from datetime import datetime from typing import Any, TypeVar, overload import uuid +from google.protobuf import timestamp_pb2 from durabletask.entities.entity_instance_id import EntityInstanceId from durabletask.internal import helpers, shared from durabletask.internal.entity_state_shim import StateShim @@ -83,7 +85,9 @@ def set_state(self, new_state: Any) -> None: """ self._state.set_state(new_state) - def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Any | None = None) -> None: + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, + input: Any | None = None, + signal_time: datetime | None = None) -> None: """Signal another entity to perform an operation. Parameters @@ -94,15 +98,23 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in The operation to perform on the entity. input : Any, optional The input to provide to the entity for the operation. + signal_time : datetime, optional + The time at which the signal should be delivered. If None, the signal is + delivered as soon as possible. Use this to schedule a future operation, + for example to have an entity wake itself up at a later time. 
""" encoded_input: str | None = shared.to_json(input) if input is not None else None + scheduled_time: timestamp_pb2.Timestamp | None = None + if signal_time is not None: + scheduled_time = timestamp_pb2.Timestamp() + scheduled_time.FromDatetime(signal_time) self._state.add_operation_action( pb.OperationAction( sendSignal=pb.SendSignalAction( instanceId=str(entity_instance_id), name=operation, input=helpers.get_string_value(encoded_input), - scheduledTime=None, + scheduledTime=scheduled_time, requestTime=None, parentTraceContext=None, ) diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py index ef27c50c..e500785c 100644 --- a/durabletask/internal/client_helpers.py +++ b/durabletask/internal/client_helpers.py @@ -192,14 +192,16 @@ def build_terminate_req( def build_signal_entity_req( entity_instance_id: EntityInstanceId, operation_name: str, - input: Any | None = None) -> pb.SignalEntityRequest: + input: Any | None = None, + signal_time: datetime | None = None) -> pb.SignalEntityRequest: """Build a SignalEntityRequest for signaling an entity.""" + scheduled_time = helpers.new_timestamp(signal_time) if signal_time is not None else None return pb.SignalEntityRequest( instanceId=str(entity_instance_id), name=operation_name, input=helpers.get_string_value(shared.to_json(input) if input is not None else None), requestId=str(uuid.uuid4()), - scheduledTime=None, + scheduledTime=scheduled_time, parentTraceContext=None, requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) ) diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 2342afdd..273096d6 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -255,11 +255,16 @@ def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: str | None, - request_id: str) -> pb.OrchestratorAction: + request_id: str, + scheduled_time: datetime | None = None) -> pb.OrchestratorAction: + 
scheduled_timestamp: timestamp_pb2.Timestamp | None = None + if scheduled_time is not None: + scheduled_timestamp = timestamp_pb2.Timestamp() + scheduled_timestamp.FromDatetime(scheduled_time) return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( requestId=request_id, operation=operation, - scheduledTime=None, + scheduledTime=scheduled_timestamp, input=get_string_value(encoded_input), targetInstanceId=get_string_value(str(entity_id)), ))) diff --git a/durabletask/scheduled/__init__.py b/durabletask/scheduled/__init__.py new file mode 100644 index 00000000..36d215a7 --- /dev/null +++ b/durabletask/scheduled/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Scheduled tasks support for the Durable Task SDK. + +This package provides a recurring schedule feature built on top of durable +entities and a helper orchestrator. Register the entity and orchestrator with a +worker via :func:`configure_scheduled_tasks`, then manage schedules from the +client via :class:`ScheduledTaskClient`. 
+""" + +from durabletask.scheduled.client import ScheduleClient, ScheduledTaskClient +from durabletask.scheduled.exceptions import (ScheduleClientValidationError, + ScheduleError, + ScheduleInvalidTransitionError, + ScheduleNotFoundError) +from durabletask.scheduled.models import (ScheduleCreationOptions, + ScheduleDescription, ScheduleQuery, + ScheduleUpdateOptions) +from durabletask.scheduled.registration import configure_scheduled_tasks +from durabletask.scheduled.schedule_status import ScheduleStatus + +__all__ = [ + "ScheduledTaskClient", + "ScheduleClient", + "ScheduleCreationOptions", + "ScheduleUpdateOptions", + "ScheduleDescription", + "ScheduleQuery", + "ScheduleStatus", + "ScheduleError", + "ScheduleNotFoundError", + "ScheduleClientValidationError", + "ScheduleInvalidTransitionError", + "configure_scheduled_tasks", +] + +PACKAGE_NAME = "durabletask.scheduled" diff --git a/durabletask/scheduled/client.py b/durabletask/scheduled/client.py new file mode 100644 index 00000000..a3474565 --- /dev/null +++ b/durabletask/scheduled/client.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import logging +from dataclasses import asdict +from typing import Any + +from durabletask.client import (EntityQuery, OrchestrationStatus, + TaskHubGrpcClient) +from durabletask.entities import EntityInstanceId +from durabletask.internal import shared +from durabletask.scheduled import transitions +from durabletask.scheduled.exceptions import ScheduleNotFoundError +from durabletask.scheduled.models import (ScheduleCreationOptions, + ScheduleDescription, ScheduleQuery, + ScheduleState, ScheduleUpdateOptions) +from durabletask.scheduled.orchestrator import ( + ScheduleOperationRequest, execute_schedule_operation_orchestrator) +from durabletask.scheduled.schedule_entity import (DELETE_OPERATION, + ENTITY_NAME) + +logger = logging.getLogger("durabletask.scheduled") + + +def _parse_state(serialized_state: Any) -> ScheduleState | None: + if serialized_state is None: + return None + data = serialized_state + if isinstance(data, str): + if not data.strip(): + # A deleted (or never-initialized) entity reports empty state. 
+ return None + data = shared.from_json(data) + if isinstance(data, dict): + return ScheduleState.from_dict(data) + return None + + +class ScheduleClient: + """Client for managing a single schedule instance.""" + + def __init__(self, client: TaskHubGrpcClient, schedule_id: str, + *, operation_timeout: float = 60): + if not schedule_id: + raise ValueError("schedule_id cannot be empty.") + self._client = client + self._schedule_id = schedule_id + self._entity_id = EntityInstanceId(ENTITY_NAME, schedule_id) + self._operation_timeout = operation_timeout + + @property + def schedule_id(self) -> str: + """Gets the ID of this schedule.""" + return self._schedule_id + + def _run_operation(self, operation_name: str, input: Any | None = None) -> None: + request = ScheduleOperationRequest( + entity_id=str(self._entity_id), + operation_name=operation_name, + input=input, + ) + instance_id = self._client.schedule_new_orchestration( + execute_schedule_operation_orchestrator, input=asdict(request)) + state = self._client.wait_for_orchestration_completion( + instance_id, timeout=self._operation_timeout) + if state is None or state.runtime_status != OrchestrationStatus.COMPLETED: + failure = state.failure_details if state else None + message = failure.message if failure else "unknown error" + raise RuntimeError( + f"Failed to '{operation_name}' schedule '{self._schedule_id}': {message}") + + def create(self, options: ScheduleCreationOptions) -> None: + """Create or update this schedule with the given configuration.""" + self._run_operation(transitions.CREATE_SCHEDULE, options.to_dict()) + + def update(self, options: ScheduleUpdateOptions) -> None: + """Update this schedule's configuration.""" + self._run_operation(transitions.UPDATE_SCHEDULE, options.to_dict()) + + def pause(self) -> None: + """Pause this schedule.""" + self._run_operation(transitions.PAUSE_SCHEDULE) + + def resume(self) -> None: + """Resume this schedule.""" + self._run_operation(transitions.RESUME_SCHEDULE) + + 
def delete(self) -> None: + """Delete this schedule.""" + self._run_operation(DELETE_OPERATION) + + def describe(self) -> ScheduleDescription: + """Retrieve the current details of this schedule.""" + metadata = self._client.get_entity(self._entity_id, include_state=True) + if metadata is None: + raise ScheduleNotFoundError(self._schedule_id) + state = _parse_state(metadata.get_state()) + if state is None: + raise ScheduleNotFoundError(self._schedule_id) + return state.to_description() + + +class ScheduledTaskClient: + """Client for managing scheduled tasks in a Durable Task application.""" + + def __init__(self, client: TaskHubGrpcClient, *, operation_timeout: float = 60): + self._client = client + self._operation_timeout = operation_timeout + + def get_schedule_client(self, schedule_id: str) -> ScheduleClient: + """Get a handle to manage a specific schedule.""" + return ScheduleClient(self._client, schedule_id, + operation_timeout=self._operation_timeout) + + def create_schedule(self, options: ScheduleCreationOptions) -> ScheduleClient: + """Create a new schedule and return a client for managing it.""" + schedule_client = self.get_schedule_client(options.schedule_id) + schedule_client.create(options) + return schedule_client + + def get_schedule(self, schedule_id: str) -> ScheduleDescription | None: + """Get a schedule description by ID, or None if it does not exist.""" + try: + return self.get_schedule_client(schedule_id).describe() + except ScheduleNotFoundError: + return None + + def list_schedules(self, filter: ScheduleQuery | None = None) -> list[ScheduleDescription]: + """List schedules matching the given filter criteria.""" + prefix = filter.schedule_id_prefix if filter and filter.schedule_id_prefix else "" + page_size = filter.page_size if filter and filter.page_size else ScheduleQuery.DEFAULT_PAGE_SIZE + query = EntityQuery( + instance_id_starts_with=f"@{ENTITY_NAME}@{prefix}", + include_state=True, + page_size=page_size, + ) + results: 
list[ScheduleDescription] = [] + for metadata in self._client.get_all_entities(query): + state = _parse_state(metadata.get_state()) + if state is None or state.schedule_configuration is None: + continue + if not self._matches_filter(state, filter): + continue + results.append(state.to_description()) + return results + + @staticmethod + def _matches_filter(state: ScheduleState, filter: ScheduleQuery | None) -> bool: + if filter is None: + return True + if filter.status is not None and state.status != filter.status: + return False + created_at = state.schedule_created_at + if filter.created_from is not None and not (created_at and created_at > filter.created_from): + return False + if filter.created_to is not None and not (created_at and created_at < filter.created_to): + return False + return True diff --git a/durabletask/scheduled/exceptions.py b/durabletask/scheduled/exceptions.py new file mode 100644 index 00000000..f984056e --- /dev/null +++ b/durabletask/scheduled/exceptions.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +class ScheduleError(Exception): + """Base class for schedule-related errors.""" + + +class ScheduleNotFoundError(ScheduleError): + """Raised when a requested schedule does not exist.""" + + def __init__(self, schedule_id: str): + self.schedule_id = schedule_id + super().__init__(f"Schedule with ID '{schedule_id}' was not found.") + + +class ScheduleClientValidationError(ScheduleError): + """Raised when a schedule operation fails client-side validation.""" + + def __init__(self, schedule_id: str, message: str): + self.schedule_id = schedule_id + super().__init__(f"Validation failed for schedule '{schedule_id}': {message}") + + +class ScheduleInvalidTransitionError(ScheduleError): + """Raised when an operation is not valid for the schedule's current status.""" + + def __init__(self, schedule_id: str, from_status: object, to_status: object, operation_name: str): + self.schedule_id = schedule_id + self.from_status = from_status + self.to_status = to_status + self.operation_name = operation_name + super().__init__( + f"Invalid state transition for schedule '{schedule_id}': operation " + f"'{operation_name}' cannot transition from '{from_status}' to '{to_status}'." + ) diff --git a/durabletask/scheduled/models.py b/durabletask/scheduled/models.py new file mode 100644 index 00000000..2f03416d --- /dev/null +++ b/durabletask/scheduled/models.py @@ -0,0 +1,309 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from durabletask.scheduled.schedule_status import ScheduleStatus + +MINIMUM_INTERVAL = timedelta(seconds=1) + + +def _validate_interval(interval: timedelta) -> timedelta: + if interval <= timedelta(0): + raise ValueError("Interval must be positive.") + if interval < MINIMUM_INTERVAL: + raise ValueError("Interval must be at least 1 second.") + return interval + + +def _to_iso(value: datetime | None) -> str | None: + return value.isoformat() if value is not None else None + + +def _from_iso(value: str | None) -> datetime | None: + return datetime.fromisoformat(value) if value else None + + +def _interval_to_seconds(value: timedelta | None) -> float | None: + return value.total_seconds() if value is not None else None + + +def _interval_from_seconds(value: float | None) -> timedelta | None: + return timedelta(seconds=value) if value is not None else None + + +@dataclass +class ScheduleCreationOptions: + """Options for creating a new schedule.""" + + schedule_id: str + orchestration_name: str + interval: timedelta + orchestration_input: Any | None = None + orchestration_instance_id: str | None = None + start_at: datetime | None = None + end_at: datetime | None = None + start_immediately_if_late: bool = False + + def __post_init__(self): + if not self.schedule_id: + raise ValueError("schedule_id cannot be empty.") + if not self.orchestration_name: + raise ValueError("orchestration_name cannot be empty.") + _validate_interval(self.interval) + + def to_dict(self) -> dict[str, Any]: + return { + "schedule_id": self.schedule_id, + "orchestration_name": self.orchestration_name, + "interval_seconds": self.interval.total_seconds(), + "orchestration_input": self.orchestration_input, + "orchestration_instance_id": self.orchestration_instance_id, + "start_at": _to_iso(self.start_at), + "end_at": _to_iso(self.end_at), + "start_immediately_if_late": 
self.start_immediately_if_late, + } + + @staticmethod + def from_dict(data: dict[str, Any]) -> "ScheduleCreationOptions": + return ScheduleCreationOptions( + schedule_id=data["schedule_id"], + orchestration_name=data["orchestration_name"], + interval=timedelta(seconds=data["interval_seconds"]), + orchestration_input=data.get("orchestration_input"), + orchestration_instance_id=data.get("orchestration_instance_id"), + start_at=_from_iso(data.get("start_at")), + end_at=_from_iso(data.get("end_at")), + start_immediately_if_late=bool(data.get("start_immediately_if_late", False)), + ) + + +@dataclass +class ScheduleUpdateOptions: + """Options for updating an existing schedule. Only set fields are applied.""" + + orchestration_name: str | None = None + orchestration_input: Any | None = None + orchestration_instance_id: str | None = None + start_at: datetime | None = None + end_at: datetime | None = None + interval: timedelta | None = None + start_immediately_if_late: bool | None = None + + def __post_init__(self): + if self.interval is not None: + _validate_interval(self.interval) + + def to_dict(self) -> dict[str, Any]: + return { + "orchestration_name": self.orchestration_name, + "orchestration_input": self.orchestration_input, + "orchestration_instance_id": self.orchestration_instance_id, + "start_at": _to_iso(self.start_at), + "end_at": _to_iso(self.end_at), + "interval_seconds": _interval_to_seconds(self.interval), + "start_immediately_if_late": self.start_immediately_if_late, + } + + @staticmethod + def from_dict(data: dict[str, Any]) -> "ScheduleUpdateOptions": + return ScheduleUpdateOptions( + orchestration_name=data.get("orchestration_name"), + orchestration_input=data.get("orchestration_input"), + orchestration_instance_id=data.get("orchestration_instance_id"), + start_at=_from_iso(data.get("start_at")), + end_at=_from_iso(data.get("end_at")), + interval=_interval_from_seconds(data.get("interval_seconds")), + 
start_immediately_if_late=data.get("start_immediately_if_late"), + ) + + +@dataclass +class ScheduleQuery: + """Query parameters for filtering schedules.""" + + DEFAULT_PAGE_SIZE = 100 + + status: ScheduleStatus | None = None + schedule_id_prefix: str | None = None + created_from: datetime | None = None + created_to: datetime | None = None + page_size: int | None = None + + +@dataclass +class ScheduleDescription: + """A read-only snapshot of a schedule's configuration and runtime state.""" + + schedule_id: str + orchestration_name: str | None = None + orchestration_input: Any | None = None + orchestration_instance_id: str | None = None + start_at: datetime | None = None + end_at: datetime | None = None + interval: timedelta | None = None + start_immediately_if_late: bool | None = None + status: ScheduleStatus = ScheduleStatus.UNINITIALIZED + execution_token: str = "" + last_run_at: datetime | None = None + next_run_at: datetime | None = None + + +class ScheduleConfiguration: + """Internal configuration for a scheduled task. 
Persisted as part of the entity state.""" + + def __init__(self, schedule_id: str, orchestration_name: str, interval: timedelta): + if not schedule_id: + raise ValueError("schedule_id cannot be empty.") + if not orchestration_name: + raise ValueError("orchestration_name cannot be empty.") + self.schedule_id = schedule_id + self.orchestration_name = orchestration_name + self.interval = _validate_interval(interval) + self.orchestration_input: Any | None = None + self.orchestration_instance_id: str | None = None + self.start_at: datetime | None = None + self.end_at: datetime | None = None + self.start_immediately_if_late: bool = False + + @staticmethod + def from_create_options(options: ScheduleCreationOptions) -> "ScheduleConfiguration": + config = ScheduleConfiguration(options.schedule_id, options.orchestration_name, options.interval) + config.orchestration_input = options.orchestration_input + config.orchestration_instance_id = options.orchestration_instance_id + config.start_at = options.start_at + config.end_at = options.end_at + config.start_immediately_if_late = options.start_immediately_if_late + config._validate() + return config + + def update(self, options: ScheduleUpdateOptions) -> set[str]: + """Apply the update options and return the set of changed field names.""" + updated: set[str] = set() + + if options.orchestration_name and options.orchestration_name != self.orchestration_name: + self.orchestration_name = options.orchestration_name + updated.add("orchestration_name") + + if options.orchestration_input is not None and options.orchestration_input != self.orchestration_input: + self.orchestration_input = options.orchestration_input + updated.add("orchestration_input") + + if options.orchestration_instance_id and options.orchestration_instance_id != self.orchestration_instance_id: + self.orchestration_instance_id = options.orchestration_instance_id + updated.add("orchestration_instance_id") + + if options.start_at is not None and options.start_at != 
self.start_at: + self.start_at = options.start_at + updated.add("start_at") + + if options.end_at is not None and options.end_at != self.end_at: + self.end_at = options.end_at + updated.add("end_at") + + if options.interval is not None and options.interval != self.interval: + self.interval = _validate_interval(options.interval) + updated.add("interval") + + if options.start_immediately_if_late is not None \ + and options.start_immediately_if_late != self.start_immediately_if_late: + self.start_immediately_if_late = options.start_immediately_if_late + updated.add("start_immediately_if_late") + + self._validate() + return updated + + def _validate(self): + if self.start_at is not None and self.end_at is not None and self.start_at > self.end_at: + raise ValueError("start_at cannot be later than end_at.") + + def to_dict(self) -> dict[str, Any]: + return { + "schedule_id": self.schedule_id, + "orchestration_name": self.orchestration_name, + "interval_seconds": self.interval.total_seconds(), + "orchestration_input": self.orchestration_input, + "orchestration_instance_id": self.orchestration_instance_id, + "start_at": _to_iso(self.start_at), + "end_at": _to_iso(self.end_at), + "start_immediately_if_late": self.start_immediately_if_late, + } + + @staticmethod + def from_dict(data: dict[str, Any]) -> "ScheduleConfiguration": + config = ScheduleConfiguration( + data["schedule_id"], + data["orchestration_name"], + timedelta(seconds=data["interval_seconds"]), + ) + config.orchestration_input = data.get("orchestration_input") + config.orchestration_instance_id = data.get("orchestration_instance_id") + config.start_at = _from_iso(data.get("start_at")) + config.end_at = _from_iso(data.get("end_at")) + config.start_immediately_if_late = bool(data.get("start_immediately_if_late", False)) + return config + + +class ScheduleState: + """Internal runtime state for a schedule. 
Persisted as the entity state.""" + + def __init__(self): + self.status: ScheduleStatus = ScheduleStatus.UNINITIALIZED + self.execution_token: str = _new_token() + self.last_run_at: datetime | None = None + self.next_run_at: datetime | None = None + self.schedule_created_at: datetime | None = None + self.schedule_last_modified_at: datetime | None = None + self.schedule_configuration: ScheduleConfiguration | None = None + + def refresh_execution_token(self): + self.execution_token = _new_token() + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "execution_token": self.execution_token, + "last_run_at": _to_iso(self.last_run_at), + "next_run_at": _to_iso(self.next_run_at), + "schedule_created_at": _to_iso(self.schedule_created_at), + "schedule_last_modified_at": _to_iso(self.schedule_last_modified_at), + "schedule_configuration": + self.schedule_configuration.to_dict() if self.schedule_configuration else None, + } + + @staticmethod + def from_dict(data: dict[str, Any]) -> "ScheduleState": + state = ScheduleState() + state.status = ScheduleStatus(data["status"]) + state.execution_token = data["execution_token"] + state.last_run_at = _from_iso(data.get("last_run_at")) + state.next_run_at = _from_iso(data.get("next_run_at")) + state.schedule_created_at = _from_iso(data.get("schedule_created_at")) + state.schedule_last_modified_at = _from_iso(data.get("schedule_last_modified_at")) + config_data = data.get("schedule_configuration") + state.schedule_configuration = ScheduleConfiguration.from_dict(config_data) if config_data else None + return state + + def to_description(self) -> ScheduleDescription: + config = self.schedule_configuration + return ScheduleDescription( + schedule_id=config.schedule_id if config else "", + orchestration_name=config.orchestration_name if config else None, + orchestration_input=config.orchestration_input if config else None, + orchestration_instance_id=config.orchestration_instance_id if config else None, + 
def execute_schedule_operation_orchestrator(
        ctx: task.OrchestrationContext, request: Any) -> Generator[task.Task[Any], Any, Any]:
    """Run one operation against a schedule entity and return its result.

    Client-side write operations go through this orchestrator so that callers can
    wait for the entity operation to finish and observe its failure, if any.

    Parameters
    ----------
    ctx : task.OrchestrationContext
        The orchestration context used to call the entity.
    request : Any
        A ``ScheduleOperationRequest``, or the ``dict`` / ``SimpleNamespace`` form
        it takes after being round-tripped as orchestration input.

    Returns
    -------
    Any
        Whatever the entity operation returned.
    """
    # Normalise the round-tripped input back into the request dataclass.
    payload = vars(request) if isinstance(request, SimpleNamespace) else request
    operation = ScheduleOperationRequest(**payload) if isinstance(payload, dict) else payload

    target = EntityInstanceId.parse(operation.entity_id)
    outcome = yield ctx.call_entity(target, operation.operation_name, operation.input)
    return outcome
def _coerce_options(input: Any, cls: type) -> Any:
    """Turn a round-tripped options payload back into an instance of ``cls``.

    ``None`` and values that are already ``cls`` instances are returned unchanged.
    A ``SimpleNamespace`` is first converted to its attribute dict; any dict is
    then passed to ``cls.from_dict``. Anything else is returned as-is.
    """
    if input is None or isinstance(input, cls):
        return input

    payload = vars(input) if isinstance(input, SimpleNamespace) else input
    return cls.from_dict(payload) if isinstance(payload, dict) else payload
+ """ + + def _load_state(self) -> ScheduleState: + raw = self.get_state() + if raw is None: + return ScheduleState() + if isinstance(raw, SimpleNamespace): + raw = vars(raw) + if isinstance(raw, dict): + return ScheduleState.from_dict(raw) + raise TypeError(f"Unexpected schedule state type: {type(raw).__name__}") + + def _save_state(self, state: ScheduleState) -> None: + self.set_state(state.to_dict()) + + def _entity_id(self, schedule_id: str) -> EntityInstanceId: + return EntityInstanceId(ENTITY_NAME, schedule_id) + + def _can_transition_to(self, state: ScheduleState, operation_name: str, + target_status: ScheduleStatus) -> bool: + return transitions.is_valid_transition(operation_name, state.status, target_status) + + def create_schedule(self, options: ScheduleCreationOptions) -> None: + """Create a new schedule. If one already exists, update it in place.""" + options = _coerce_options(options, ScheduleCreationOptions) + state = self._load_state() + + if not self._can_transition_to(state, transitions.CREATE_SCHEDULE, ScheduleStatus.ACTIVE): + raise ScheduleInvalidTransitionError( + options.schedule_id if options else "", state.status, ScheduleStatus.ACTIVE, + transitions.CREATE_SCHEDULE) + + already_exists = state.schedule_created_at is not None + state.schedule_configuration = ScheduleConfiguration.from_create_options(options) + + if already_exists: + state.schedule_last_modified_at = _now() + state.refresh_execution_token() + state.next_run_at = None + else: + state.status = ScheduleStatus.ACTIVE + state.schedule_created_at = state.schedule_last_modified_at = _now() + + logger.info(f"Created schedule '{state.schedule_configuration.schedule_id}'.") + self._save_state(state) + + # Signal to run the schedule and let run_schedule decide whether to run now or later. 
+ self.signal_entity( + self._entity_id(state.schedule_configuration.schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + ) + + def update_schedule(self, options: ScheduleUpdateOptions) -> None: + """Update an existing schedule's configuration.""" + options = _coerce_options(options, ScheduleUpdateOptions) + state = self._load_state() + + if not self._can_transition_to(state, transitions.UPDATE_SCHEDULE, state.status): + raise ScheduleInvalidTransitionError( + state.schedule_configuration.schedule_id if state.schedule_configuration else "", + state.status, state.status, transitions.UPDATE_SCHEDULE) + + if state.schedule_configuration is None: + raise ValueError("Schedule configuration is missing.") + + updated_fields = state.schedule_configuration.update(options) + if not updated_fields: + logger.debug("Schedule configuration is already up to date.") + self._save_state(state) + return + + state.schedule_last_modified_at = _now() + + if updated_fields & {"start_at", "interval", "start_immediately_if_late"}: + state.next_run_at = None + + state.refresh_execution_token() + logger.info(f"Updated schedule '{state.schedule_configuration.schedule_id}'.") + self._save_state(state) + + if state.status == ScheduleStatus.ACTIVE: + self.signal_entity( + self._entity_id(state.schedule_configuration.schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + ) + + def pause_schedule(self, _: Any = None) -> None: + """Pause the schedule.""" + state = self._load_state() + schedule_id = state.schedule_configuration.schedule_id if state.schedule_configuration else "" + + if not self._can_transition_to(state, transitions.PAUSE_SCHEDULE, ScheduleStatus.PAUSED): + raise ScheduleInvalidTransitionError( + schedule_id, state.status, ScheduleStatus.PAUSED, transitions.PAUSE_SCHEDULE) + + if state.schedule_configuration is None: + raise ValueError("Schedule configuration is missing.") + + state.status = ScheduleStatus.PAUSED + state.next_run_at = None + 
state.refresh_execution_token() + logger.info(f"Paused schedule '{schedule_id}'.") + self._save_state(state) + + def resume_schedule(self, _: Any = None) -> None: + """Resume a paused schedule.""" + state = self._load_state() + schedule_id = state.schedule_configuration.schedule_id if state.schedule_configuration else "" + + if not self._can_transition_to(state, transitions.RESUME_SCHEDULE, ScheduleStatus.ACTIVE): + raise ScheduleInvalidTransitionError( + schedule_id, state.status, ScheduleStatus.ACTIVE, transitions.RESUME_SCHEDULE) + + if state.schedule_configuration is None: + raise ValueError("Schedule configuration is missing.") + + state.status = ScheduleStatus.ACTIVE + state.next_run_at = None + logger.info(f"Resumed schedule '{schedule_id}'.") + self._save_state(state) + + self.signal_entity( + self._entity_id(schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + ) + + def run_schedule(self, execution_token: str) -> None: + """Heartbeat operation: starts the target orchestration when due and re-arms itself.""" + state = self._load_state() + + if state.status == ScheduleStatus.UNINITIALIZED: + # This signal is no longer useful since the schedule was deleted. 
+ self.set_state(None) + return + + config = state.schedule_configuration + if config is None: + raise ValueError("Schedule configuration is missing.") + + if execution_token != state.execution_token: + logger.debug(f"Ignoring stale run signal for schedule '{config.schedule_id}'.") + return + + if state.status != ScheduleStatus.ACTIVE: + raise ValueError("Schedule must be in Active status to run.") + + end_at = _ensure_aware(config.end_at) + if end_at is not None and _now() > end_at: + logger.info(f"Schedule '{config.schedule_id}' has passed its end time; deleting.") + state.next_run_at = None + self._save_state(state) + self.signal_entity(self._entity_id(config.schedule_id), DELETE_OPERATION) + return + + state.next_run_at = self._determine_next_run_time(state, config) + + if state.next_run_at <= _now(): + self._start_orchestration(config, state.next_run_at) + state.last_run_at = state.next_run_at + state.next_run_at = None + state.next_run_at = self._determine_next_run_time(state, config) + + self._save_state(state) + + self.signal_entity( + self._entity_id(config.schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + signal_time=state.next_run_at, + ) + + def delete(self, _: Any = None) -> None: + """Delete the schedule entity.""" + self.set_state(None) + + def _start_orchestration(self, config: ScheduleConfiguration, scheduled_run_time: datetime) -> None: + instance_id = config.orchestration_instance_id + if not instance_id: + instance_id = f"{config.schedule_id}-{scheduled_run_time.isoformat()}" + + logger.info( + f"Starting orchestration '{config.orchestration_name}' with instance ID '{instance_id}' " + f"for schedule '{config.schedule_id}'.") + self.schedule_new_orchestration( + config.orchestration_name, + config.orchestration_input, + instance_id=instance_id, + ) + + def _determine_next_run_time(self, state: ScheduleState, config: ScheduleConfiguration) -> datetime: + if state.next_run_at is not None: + return _ensure_aware(state.next_run_at) # 
type: ignore[return-value] + + now = _now() + start_time = _ensure_aware(config.start_at) or _ensure_aware(state.schedule_created_at) or now + time_since_start = now - start_time + + # Next run is in the future relative to the start time. + if time_since_start.total_seconds() < 0: + return start_time + + is_first_run = state.last_run_at is None + if is_first_run and config.start_immediately_if_late: + return now + + intervals_elapsed = int(time_since_start.total_seconds() // config.interval.total_seconds()) + return start_time + config.interval * (intervals_elapsed + 1) diff --git a/durabletask/scheduled/schedule_status.py b/durabletask/scheduled/schedule_status.py new file mode 100644 index 00000000..6a21afd2 --- /dev/null +++ b/durabletask/scheduled/schedule_status.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from enum import Enum + + +class ScheduleStatus(str, Enum): + """Represents the current status of a schedule.""" + + UNINITIALIZED = "Uninitialized" + """Schedule has not been created.""" + + ACTIVE = "Active" + """Schedule is active and running.""" + + PAUSED = "Paused" + """Schedule is paused.""" diff --git a/durabletask/scheduled/transitions.py b/durabletask/scheduled/transitions.py new file mode 100644 index 00000000..ad2f21ec --- /dev/null +++ b/durabletask/scheduled/transitions.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from durabletask.scheduled.schedule_status import ScheduleStatus + +# Operation names used by the Schedule entity. These must match the entity +# method names so that the transition table can be keyed by operation. 
def is_valid_transition(operation_name: str, from_status: ScheduleStatus,
                        target_status: ScheduleStatus) -> bool:
    """Tell whether an operation may move a schedule between two statuses.

    Parameters
    ----------
    operation_name : str
        The name of the operation being performed.
    from_status : ScheduleStatus
        The current schedule status.
    target_status : ScheduleStatus
        The status the schedule would transition to.

    Returns
    -------
    bool
        True if the (current, target) pair is permitted for the operation;
        False otherwise, including for unknown operation names.
    """
    active = ScheduleStatus.ACTIVE
    paused = ScheduleStatus.PAUSED
    uninitialized = ScheduleStatus.UNINITIALIZED

    # Each rule lists the (current, target) pairs permitted for one operation.
    # Tuples are used (not sets) so that matching relies on ``==`` only.
    rules = (
        (CREATE_SCHEDULE, ((uninitialized, active), (active, active), (paused, active))),
        (UPDATE_SCHEDULE, ((active, active), (paused, paused))),
        (PAUSE_SCHEDULE, ((active, paused),)),
        (RESUME_SCHEDULE, ((paused, active),)),
    )
    for operation, permitted in rules:
        if operation_name == operation:
            return (from_status, target_status) in permitted
    return False
input: TInput | None The optional JSON-serializable input to pass to the entity function. + signal_time: datetime | None + The optional time at which the signal should be delivered. If None, the + signal is delivered as soon as possible. Use this to schedule a future + operation on the entity. """ pass diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index 0eb761c7..d112ef25 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -919,9 +919,14 @@ def CompleteEntityTask(self, request: pb.EntityBatchResult, context: grpc.Servic try: if action.HasField("sendSignal"): signal = action.sendSignal + scheduled_time = ( + signal.scheduledTime.ToDatetime(tzinfo=timezone.utc) + if signal.HasField("scheduledTime") else None + ) self._signal_entity_internal( signal.instanceId, signal.name, - signal.input.value if signal.input else None + signal.input.value if signal.input else None, + scheduled_time=scheduled_time, ) elif action.HasField("startNewOrchestration"): start_orch = action.startNewOrchestration @@ -945,6 +950,10 @@ def CompleteEntityTask(self, request: pb.EntityBatchResult, context: grpc.Servic def SignalEntity(self, request: pb.SignalEntityRequest, context: grpc.ServicerContext) -> pb.SignalEntityResponse: """Signals an entity, queueing an operation for processing.""" + scheduled_time = ( + request.scheduledTime.ToDatetime(tzinfo=timezone.utc) + if request.HasField("scheduledTime") else None + ) with self._lock: entity_id = request.instanceId entity = self._entities.get(entity_id) @@ -967,8 +976,11 @@ def SignalEntity(self, request: pb.SignalEntityRequest, context: grpc.ServicerCo targetInstanceId=wrappers_pb2.StringValue(value=entity_id), ) ) - entity.pending_operations.append(event) - self._enqueue_entity(entity_id) + if scheduled_time is not None and scheduled_time > datetime.now(timezone.utc): + self._schedule_delayed_entity_operation(entity_id, event, scheduled_time) 
+ else: + entity.pending_operations.append(event) + self._enqueue_entity(entity_id) self._logger.info(f"Signaled entity '{entity_id}' operation '{request.name}'") return pb.SignalEntityResponse() @@ -1604,11 +1616,19 @@ def _process_send_entity_message_action(self, instance: OrchestrationInstance, instance.history.append(history_event) if target_id: - self._queue_entity_operation(target_id, pb.HistoryEvent( + operation_event = pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), entityOperationSignaled=signaled, - )) + ) + scheduled_time = ( + signaled.scheduledTime.ToDatetime(tzinfo=timezone.utc) + if signaled.HasField("scheduledTime") else None + ) + if scheduled_time is not None and scheduled_time > datetime.now(timezone.utc): + self._schedule_delayed_entity_operation(target_id, operation_event, scheduled_time) + else: + self._queue_entity_operation(target_id, operation_event) elif msg.HasField("entityOperationCalled"): called = msg.entityOperationCalled @@ -1748,8 +1768,13 @@ def _queue_entity_operation(self, entity_id: str, event: pb.HistoryEvent): self._enqueue_entity(entity_id) def _signal_entity_internal(self, entity_id: str, operation: str, - input_value: str | None = None): - """Internal method to signal an entity (from entity side-effect actions).""" + input_value: str | None = None, + scheduled_time: datetime | None = None): + """Internal method to signal an entity (from entity side-effect actions). + + If ``scheduled_time`` is set and in the future, the operation is delayed + until that time (mirroring delayed-signal delivery on a real backend). 
+ """ event = pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -1760,7 +1785,30 @@ def _signal_entity_internal(self, entity_id: str, operation: str, targetInstanceId=wrappers_pb2.StringValue(value=entity_id), ) ) - self._queue_entity_operation(entity_id, event) + if scheduled_time is not None and scheduled_time > datetime.now(timezone.utc): + self._schedule_delayed_entity_operation(entity_id, event, scheduled_time) + else: + self._queue_entity_operation(entity_id, event) + + def _schedule_delayed_entity_operation(self, entity_id: str, event: pb.HistoryEvent, + fire_at: datetime): + """Schedules an entity operation to be enqueued at a future time. + + Uses a background timer thread, mirroring the timer-firing mechanism used + for orchestration timers, so that delayed entity signals are honored by the + in-memory backend. + """ + delay = max(0.0, (fire_at - datetime.now(timezone.utc)).total_seconds()) + + def fire() -> None: + time.sleep(delay) + with self._lock: + if self._shutdown_event.is_set(): + return + self._queue_entity_operation(entity_id, event) + + timer_thread = threading.Thread(target=fire, daemon=True) + timer_thread.start() def _enqueue_entity(self, entity_id: str): """Enqueues an entity for processing.""" diff --git a/durabletask/worker.py b/durabletask/worker.py index aff8f4f1..72f3194b 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1677,12 +1677,13 @@ def signal_entity( self, entity_id: EntityInstanceId, operation_name: str, - input: Any = None + input: Any = None, + signal_time: datetime | None = None ) -> None: id = self.next_sequence_number() self.signal_entity_function_helper( - id, entity_id, operation_name, input + id, entity_id, operation_name, input, signal_time ) def lock_entities(self, entities: list[EntityInstanceId]) -> task.CompletableTask[EntityLock]: @@ -1822,7 +1823,8 @@ def signal_entity_function_helper( id: int | None, entity_id: EntityInstanceId, operation: str, - input: Any = None + input: 
def greet_orchestrator(ctx: task.OrchestrationContext, name: str) -> Generator[task.Task[Any], Any, Any]:
    """Target orchestration for the sample schedule.

    Waits on a one-second durable timer and then returns a greeting for ``name``.
    """
    pause = timedelta(seconds=1)
    yield ctx.create_timer(pause)
    greeting = f"Hello, {name}!"
    return greeting
def _wait_until(predicate, timeout: float = 30, interval: float = 0.5) -> bool:
    """Poll ``predicate`` until it is truthy or ``timeout`` seconds have passed.

    Parameters
    ----------
    predicate : callable
        Zero-argument callable evaluated on each poll.
    timeout : float
        Maximum time to keep polling, in seconds.
    interval : float
        Pause between polls, in seconds.

    Returns
    -------
    bool
        True as soon as ``predicate`` returns a truthy value; otherwise the result
        of one final evaluation made after the deadline.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    # The condition may have become true during the last sleep; check once more
    # instead of reporting a spurious timeout.
    return bool(predicate())
def test_orchestration_delayed_signal_is_deferred():
    """A signal sent from an orchestration with ``signal_time`` is not delivered early.

    The orchestrator signals the ``Recorder`` entity with a delivery time four
    seconds after its own current time. The test then checks that nothing has been
    recorded one second after the orchestration completed, and that the first
    recorded ping arrives at least three seconds after the test began.
    """
    def signaling_orchestrator(ctx: task.OrchestrationContext, key: str):
        entity_id = entities.EntityInstanceId("Recorder", key)
        # The delay is anchored to the orchestration's own clock.
        ctx.signal_entity(entity_id, "ping", signal_time=ctx.current_utc_datetime + timedelta(seconds=4))

    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(signaling_orchestrator)
        w.add_entity(Recorder)
        w.start()

        c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel,
                                       taskhub=taskhub_name, token_credential=None)
        # A unique key gives each test run its own entity instance.
        key = f"orch-delay-{uuid.uuid4().hex[:8]}"
        # Taken before the orchestration is scheduled, so it precedes the signal time.
        sent_at = datetime.now(timezone.utc)
        instance_id = c.schedule_new_orchestration(signaling_orchestrator, input=key)
        c.wait_for_orchestration_completion(instance_id, timeout=30)

        # The orchestration has finished, but the delayed signal must still be pending.
        time.sleep(1)
        assert recorder.count == 0, "delayed signal fired too early"

        assert _wait_until(lambda: recorder.count >= 1)
        # One second of slack under the four-second delay.
        elapsed = (recorder.times[0] - sent_at).total_seconds()
        assert elapsed >= 3, f"signal fired too early ({elapsed}s)"
+++ b/tests/durabletask-azuremanaged/scheduled/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py b/tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py new file mode 100644 index 00000000..06d3d76a --- /dev/null +++ b/tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end tests for scheduled tasks against a live Durable Task Scheduler. + +These tests assume a sidecar/emulator is running. Example command: + docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +""" + +import threading +import time +import uuid +from datetime import datetime, timedelta, timezone + +import pytest + +from durabletask import task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.scheduled import (ScheduledTaskClient, ScheduleCreationOptions, + ScheduleQuery, ScheduleStatus, + ScheduleUpdateOptions, + configure_scheduled_tasks) + +import os + +# NOTE: These tests assume a sidecar process is running. 
def _wait_until(predicate, timeout: float = 30, interval: float = 0.5) -> bool:
    """Poll ``predicate`` until it is truthy or ``timeout`` seconds have passed.

    Parameters
    ----------
    predicate : callable
        Zero-argument callable evaluated on each poll.
    timeout : float
        Maximum time to keep polling, in seconds.
    interval : float
        Pause between polls, in seconds.

    Returns
    -------
    bool
        True as soon as ``predicate`` returns a truthy value; otherwise the result
        of one final evaluation made after the deadline.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    # The condition may have become true during the last sleep; check once more
    # instead of reporting a spurious timeout.
    return bool(predicate())
def test_list_and_delete():
    """Schedules sharing an ID prefix are listed together and disappear once deleted."""
    prefix = f"grp-{uuid.uuid4().hex[:8]}-"
    with _make_worker() as w:
        w.start()
        c = _make_client()
        scheduled = ScheduledTaskClient(c)
        created = []

        # Creation happens inside the try block so that a failure part-way
        # through still cleans up the schedules that were already created.
        try:
            for i in range(3):
                schedule_id = f"{prefix}{i}"
                scheduled.create_schedule(ScheduleCreationOptions(
                    schedule_id=schedule_id,
                    orchestration_name=task.get_name(target_orchestrator),
                    interval=timedelta(seconds=30),
                    # Start far in the future so the schedules never actually run.
                    start_at=datetime.now(timezone.utc) + timedelta(hours=1),
                ))
                created.append(schedule_id)

            listed = scheduled.list_schedules(ScheduleQuery(schedule_id_prefix=prefix))
            assert {s.schedule_id for s in listed} == set(created)
        finally:
            for schedule_id in created:
                scheduled.get_schedule_client(schedule_id).delete()

        # Verify every schedule is gone, not just the first one.
        assert _wait_until(
            lambda: all(scheduled.get_schedule(schedule_id) is None for schedule_id in created))
+ +"""E2E tests for delayed (scheduled) entity signals via the client and orchestrator.""" + +import threading +import time +from datetime import datetime, timedelta, timezone + +import pytest + +from durabletask import client, entities, task, worker +from durabletask.testing import create_test_backend + +from tests.durabletask._port_utils import find_free_port + +PORT = find_free_port() +HOST = f"localhost:{PORT}" + + +@pytest.fixture(autouse=True) +def backend(): + b = create_test_backend(port=PORT) + yield b + b.stop() + b.reset() + + +class _Recorder: + """Thread-safe recorder of operation invocation times.""" + + def __init__(self): + self._lock = threading.Lock() + self.times: list[datetime] = [] + + def record(self) -> None: + with self._lock: + self.times.append(datetime.now(timezone.utc)) + + @property + def count(self) -> int: + with self._lock: + return len(self.times) + + +recorder = _Recorder() + + +class Recorder(entities.DurableEntity): + def ping(self, _=None): + recorder.record() + + +def setup_function(_): + recorder.times.clear() + + +def _wait_until(predicate, timeout: float = 10, interval: float = 0.1) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + if predicate(): + return True + time.sleep(interval) + return False + + +def test_client_delayed_signal_is_deferred(): + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_entity(Recorder) + w.start() + + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("Recorder", "client-delay") + signal_at = datetime.now(timezone.utc) + timedelta(seconds=2) + sent_at = datetime.now(timezone.utc) + c.signal_entity(entity_id, "ping", signal_time=signal_at) + + # Should not have fired immediately. + time.sleep(0.5) + assert recorder.count == 0, "delayed signal fired too early" + + # Should fire around the scheduled time. 
+ assert _wait_until(lambda: recorder.count >= 1) + elapsed = (recorder.times[0] - sent_at).total_seconds() + assert elapsed >= 1.5, f"signal fired too early ({elapsed}s)" + + +def test_client_immediate_signal_still_works(): + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_entity(Recorder) + w.start() + + with client.TaskHubGrpcClient(host_address=HOST) as c: + entity_id = entities.EntityInstanceId("Recorder", "client-now") + c.signal_entity(entity_id, "ping") + assert _wait_until(lambda: recorder.count >= 1) + + +def test_orchestration_delayed_signal_is_deferred(): + def signaling_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("Recorder", "orch-delay") + signal_at = ctx.current_utc_datetime + timedelta(seconds=2) + ctx.signal_entity(entity_id, "ping", signal_time=signal_at) + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(signaling_orchestrator) + w.add_entity(Recorder) + w.start() + + with client.TaskHubGrpcClient(host_address=HOST) as c: + sent_at = datetime.now(timezone.utc) + instance_id = c.schedule_new_orchestration(signaling_orchestrator) + c.wait_for_orchestration_completion(instance_id, timeout=30) + + # Orchestration completes immediately, but the signal should be deferred. + time.sleep(0.5) + assert recorder.count == 0, "delayed signal fired too early" + + assert _wait_until(lambda: recorder.count >= 1) + elapsed = (recorder.times[0] - sent_at).total_seconds() + assert elapsed >= 1.5, f"signal fired too early ({elapsed}s)" diff --git a/tests/durabletask/scheduled/__init__.py b/tests/durabletask/scheduled/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/tests/durabletask/scheduled/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
diff --git a/tests/durabletask/scheduled/test_models.py b/tests/durabletask/scheduled/test_models.py new file mode 100644 index 00000000..97f3dafc --- /dev/null +++ b/tests/durabletask/scheduled/test_models.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for scheduled tasks models, validation, and serialization.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from durabletask.internal import shared +from durabletask.scheduled.models import (ScheduleConfiguration, + ScheduleCreationOptions, + ScheduleState, ScheduleUpdateOptions) +from durabletask.scheduled.schedule_status import ScheduleStatus + + +class TestCreationOptionsValidation: + def test_requires_schedule_id(self): + with pytest.raises(ValueError): + ScheduleCreationOptions(schedule_id="", orchestration_name="orch", + interval=timedelta(seconds=5)) + + def test_requires_orchestration_name(self): + with pytest.raises(ValueError): + ScheduleCreationOptions(schedule_id="s1", orchestration_name="", + interval=timedelta(seconds=5)) + + def test_interval_must_be_at_least_one_second(self): + with pytest.raises(ValueError): + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(milliseconds=500)) + + def test_interval_must_be_positive(self): + with pytest.raises(ValueError): + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=-1)) + + def test_valid_options(self): + options = ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=30)) + assert options.schedule_id == "s1" + assert options.interval == timedelta(seconds=30) + + +class TestCreationOptionsSerialization: + def test_round_trip_through_json(self): + start = datetime(2026, 1, 1, tzinfo=timezone.utc) + end = datetime(2026, 2, 1, tzinfo=timezone.utc) + options = ScheduleCreationOptions( + schedule_id="s1", orchestration_name="orch", 
interval=timedelta(minutes=5), + orchestration_input={"key": "value"}, orchestration_instance_id="inst-1", + start_at=start, end_at=end, start_immediately_if_late=True) + + encoded = shared.to_json(options.to_dict()) + decoded = ScheduleCreationOptions.from_dict(shared.from_json(encoded)) + + assert decoded.schedule_id == "s1" + assert decoded.orchestration_name == "orch" + assert decoded.interval == timedelta(minutes=5) + assert decoded.orchestration_input == {"key": "value"} + assert decoded.orchestration_instance_id == "inst-1" + assert decoded.start_at == start + assert decoded.end_at == end + assert decoded.start_immediately_if_late is True + + +class TestUpdateOptions: + def test_interval_validation(self): + with pytest.raises(ValueError): + ScheduleUpdateOptions(interval=timedelta(milliseconds=100)) + + def test_round_trip_through_json(self): + options = ScheduleUpdateOptions(orchestration_name="orch2", interval=timedelta(seconds=10)) + decoded = ScheduleUpdateOptions.from_dict(shared.from_json(shared.to_json(options.to_dict()))) + assert decoded.orchestration_name == "orch2" + assert decoded.interval == timedelta(seconds=10) + assert decoded.start_at is None + + +class TestScheduleConfiguration: + def test_from_create_options_rejects_start_after_end(self): + options = ScheduleCreationOptions( + schedule_id="s1", orchestration_name="orch", interval=timedelta(seconds=5), + start_at=datetime(2026, 2, 1, tzinfo=timezone.utc), + end_at=datetime(2026, 1, 1, tzinfo=timezone.utc)) + with pytest.raises(ValueError): + ScheduleConfiguration.from_create_options(options) + + def test_update_returns_changed_fields(self): + config = ScheduleConfiguration.from_create_options( + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=5))) + changed = config.update(ScheduleUpdateOptions(interval=timedelta(seconds=10), + orchestration_name="orch2")) + assert changed == {"interval", "orchestration_name"} + assert config.interval == 
timedelta(seconds=10) + assert config.orchestration_name == "orch2" + + def test_update_no_changes_returns_empty(self): + config = ScheduleConfiguration.from_create_options( + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=5))) + changed = config.update(ScheduleUpdateOptions(orchestration_name="orch")) + assert changed == set() + + def test_config_round_trip(self): + config = ScheduleConfiguration.from_create_options( + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=5), + start_at=datetime(2026, 1, 1, tzinfo=timezone.utc))) + restored = ScheduleConfiguration.from_dict(shared.from_json(shared.to_json(config.to_dict()))) + assert restored.schedule_id == "s1" + assert restored.interval == timedelta(seconds=5) + assert restored.start_at == datetime(2026, 1, 1, tzinfo=timezone.utc) + + +class TestScheduleState: + def test_round_trip_and_description(self): + state = ScheduleState() + state.status = ScheduleStatus.ACTIVE + state.schedule_created_at = datetime(2026, 1, 1, tzinfo=timezone.utc) + state.schedule_configuration = ScheduleConfiguration.from_create_options( + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=5))) + + restored = ScheduleState.from_dict(shared.from_json(shared.to_json(state.to_dict()))) + assert restored.status == ScheduleStatus.ACTIVE + assert restored.schedule_created_at == datetime(2026, 1, 1, tzinfo=timezone.utc) + + description = restored.to_description() + assert description.schedule_id == "s1" + assert description.status == ScheduleStatus.ACTIVE + assert description.interval == timedelta(seconds=5) + + def test_refresh_execution_token_changes_token(self): + state = ScheduleState() + original = state.execution_token + state.refresh_execution_token() + assert state.execution_token != original diff --git a/tests/durabletask/scheduled/test_schedule_entity.py 
b/tests/durabletask/scheduled/test_schedule_entity.py new file mode 100644 index 00000000..c65c6d18 --- /dev/null +++ b/tests/durabletask/scheduled/test_schedule_entity.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for the Schedule entity behavior, driven through _EntityExecutor.""" + +import logging +from datetime import datetime, timedelta, timezone + +import pytest + +from durabletask.entities import EntityInstanceId +from durabletask.internal import shared +from durabletask.internal.entity_state_shim import StateShim +from durabletask.scheduled.exceptions import ScheduleInvalidTransitionError +from durabletask.scheduled.models import (ScheduleCreationOptions, + ScheduleUpdateOptions) +from durabletask.scheduled.schedule_entity import (ENTITY_NAME, Schedule) +from durabletask.scheduled.schedule_status import ScheduleStatus +from durabletask.worker import _EntityExecutor, _Registry + +SCHEDULE_ID = "sched-1" + + +class Harness: + """Drives Schedule entity operations against a persistent in-memory state.""" + + def __init__(self): + registry = _Registry() + registry.add_entity(Schedule, ENTITY_NAME) + self.executor = _EntityExecutor(registry, logging.getLogger("test")) + self.state = StateShim(None) + self.entity_id = EntityInstanceId(ENTITY_NAME, SCHEDULE_ID) + + def run(self, operation, input=None): + before = len(self.state.get_operation_actions()) + encoded = shared.to_json(input) if input is not None else None + result = self.executor.execute("orch-1", self.entity_id, operation, self.state, encoded) + self.state.commit() + actions = self.state.get_operation_actions()[before:] + return result, actions + + @property + def state_dict(self): + return self.state._current_state # pyright: ignore[reportPrivateUsage] + + @property + def token(self): + return self.state_dict["execution_token"] + + +def _signal_actions(actions): + return [a for a in actions if a.HasField("sendSignal")] + + +def 
_start_actions(actions): + return [a for a in actions if a.HasField("startNewOrchestration")] + + +def _creation_options(**kwargs): + base = dict(schedule_id=SCHEDULE_ID, orchestration_name="my_orch", interval=timedelta(seconds=30)) + base.update(kwargs) + return ScheduleCreationOptions(**base).to_dict() + + +class TestCreate: + def test_create_activates_and_signals_run(self): + h = Harness() + _, actions = h.run("create_schedule", _creation_options()) + + assert h.state_dict["status"] == ScheduleStatus.ACTIVE.value + assert h.state_dict["schedule_created_at"] is not None + signals = _signal_actions(actions) + assert len(signals) == 1 + assert signals[0].sendSignal.name == "run_schedule" + assert signals[0].sendSignal.instanceId == f"@{ENTITY_NAME}@{SCHEDULE_ID}" + + def test_create_twice_updates_in_place(self): + h = Harness() + h.run("create_schedule", _creation_options()) + first_token = h.token + h.run("create_schedule", _creation_options(interval=timedelta(seconds=60))) + # Re-creation refreshes the execution token. 
+ assert h.token != first_token + assert h.state_dict["status"] == ScheduleStatus.ACTIVE.value + + +class TestPauseResume: + def test_pause_then_resume(self): + h = Harness() + h.run("create_schedule", _creation_options()) + + h.run("pause_schedule") + assert h.state_dict["status"] == ScheduleStatus.PAUSED.value + assert h.state_dict["next_run_at"] is None + + _, actions = h.run("resume_schedule") + assert h.state_dict["status"] == ScheduleStatus.ACTIVE.value + assert len(_signal_actions(actions)) == 1 + + def test_pause_when_not_active_raises(self): + h = Harness() + h.run("create_schedule", _creation_options()) + h.run("pause_schedule") + with pytest.raises(ScheduleInvalidTransitionError): + h.run("pause_schedule") + + +class TestUpdate: + def test_update_changes_config_and_resignals(self): + h = Harness() + h.run("create_schedule", _creation_options()) + _, actions = h.run("update_schedule", + ScheduleUpdateOptions(interval=timedelta(seconds=120)).to_dict()) + assert abs(h.state_dict["schedule_configuration"]["interval_seconds"] - 120) < 0.001 + assert len(_signal_actions(actions)) == 1 + + def test_update_no_change_does_not_signal(self): + h = Harness() + h.run("create_schedule", _creation_options()) + _, actions = h.run("update_schedule", + ScheduleUpdateOptions(orchestration_name="my_orch").to_dict()) + assert len(_signal_actions(actions)) == 0 + + +class TestRunSchedule: + def test_runs_orchestration_when_due_and_rearms(self): + h = Harness() + past = datetime.now(timezone.utc) - timedelta(hours=1) + h.run("create_schedule", _creation_options(start_at=past, start_immediately_if_late=True)) + + _, actions = h.run("run_schedule", h.token) + + starts = _start_actions(actions) + assert len(starts) == 1 + assert starts[0].startNewOrchestration.name == "my_orch" + assert h.state_dict["last_run_at"] is not None + + # Re-arm signal should carry a future scheduled time. 
+ signals = _signal_actions(actions) + assert len(signals) == 1 + assert signals[0].sendSignal.HasField("scheduledTime") + + def test_ignores_stale_token(self): + h = Harness() + h.run("create_schedule", _creation_options()) + _, actions = h.run("run_schedule", "stale-token") + assert len(_start_actions(actions)) == 0 + assert len(_signal_actions(actions)) == 0 + + def test_future_start_does_not_run_yet(self): + h = Harness() + future = datetime.now(timezone.utc) + timedelta(days=1) + h.run("create_schedule", _creation_options(start_at=future)) + _, actions = h.run("run_schedule", h.token) + assert len(_start_actions(actions)) == 0 + # Still re-arms with a future scheduled signal. + signals = _signal_actions(actions) + assert len(signals) == 1 + assert signals[0].sendSignal.HasField("scheduledTime") + + def test_past_end_time_deletes(self): + h = Harness() + start = datetime.now(timezone.utc) - timedelta(hours=2) + end = datetime.now(timezone.utc) - timedelta(hours=1) + h.run("create_schedule", _creation_options(start_at=start, end_at=end)) + _, actions = h.run("run_schedule", h.token) + delete_signals = [a for a in _signal_actions(actions) if a.sendSignal.name == "delete"] + assert len(delete_signals) == 1 + + +class TestDelete: + def test_delete_clears_state(self): + h = Harness() + h.run("create_schedule", _creation_options()) + h.run("delete") + assert h.state_dict is None diff --git a/tests/durabletask/scheduled/test_scheduled_e2e.py b/tests/durabletask/scheduled/test_scheduled_e2e.py new file mode 100644 index 00000000..911cfa64 --- /dev/null +++ b/tests/durabletask/scheduled/test_scheduled_e2e.py @@ -0,0 +1,221 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""End-to-end tests for scheduled tasks against the in-memory backend.""" + +import threading +import time +from datetime import datetime, timedelta, timezone + +import pytest + +from durabletask import client, task, worker +from durabletask.scheduled import (ScheduledTaskClient, ScheduleCreationOptions, + ScheduleQuery, ScheduleStatus, + ScheduleUpdateOptions, + configure_scheduled_tasks) +from durabletask.testing import create_test_backend + +from tests.durabletask._port_utils import find_free_port + +PORT = find_free_port() +HOST = f"localhost:{PORT}" + + +@pytest.fixture(autouse=True) +def backend(): + """Create an in-memory backend for each test.""" + b = create_test_backend(port=PORT) + yield b + b.stop() + b.reset() + + +class _RunTracker: + """Thread-safe tracker that records each target orchestration run.""" + + def __init__(self): + self._lock = threading.Lock() + self.inputs: list[object] = [] + + def record(self, value: object) -> None: + with self._lock: + self.inputs.append(value) + + @property + def count(self) -> int: + with self._lock: + return len(self.inputs) + + +# A module-level tracker is used because orchestrators must be registered by +# reference and run on worker threads. +tracker = _RunTracker() + + +def target_orchestrator(ctx: task.OrchestrationContext, value): + """The orchestration started on each schedule run.""" + if not ctx.is_replaying: + tracker.record(value) + return value + + +def _make_worker() -> worker.TaskHubGrpcWorker: + w = worker.TaskHubGrpcWorker(host_address=HOST) + w.add_orchestrator(target_orchestrator) + configure_scheduled_tasks(w) + return w + + +def _wait_until(predicate, timeout: float = 15, interval: float = 0.2) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + if predicate(): + return True + time.sleep(interval) + return False + + +def setup_function(_): + # Reset the shared tracker before each test. 
+ tracker.inputs.clear() + + +def test_create_describe_and_run(): + with _make_worker() as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST) as c: + scheduled = ScheduledTaskClient(c) + schedule = scheduled.create_schedule(ScheduleCreationOptions( + schedule_id="sched-run", + orchestration_name=task.get_name(target_orchestrator), + interval=timedelta(seconds=1), + orchestration_input="hello", + start_at=datetime.now(timezone.utc), + start_immediately_if_late=True, + )) + + assert _wait_until(lambda: tracker.count >= 1), "target orchestration did not run" + + description = schedule.describe() + assert description.schedule_id == "sched-run" + assert description.status == ScheduleStatus.ACTIVE + assert description.orchestration_name == task.get_name(target_orchestrator) + assert description.last_run_at is not None + + schedule.delete() + + +def test_get_nonexistent_returns_none(): + with _make_worker() as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST) as c: + scheduled = ScheduledTaskClient(c) + assert scheduled.get_schedule("does-not-exist") is None + + +def test_pause_and_resume(): + with _make_worker() as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST) as c: + scheduled = ScheduledTaskClient(c) + schedule = scheduled.create_schedule(ScheduleCreationOptions( + schedule_id="sched-pause", + orchestration_name=task.get_name(target_orchestrator), + interval=timedelta(seconds=1), + start_at=datetime.now(timezone.utc), + start_immediately_if_late=True, + )) + + assert _wait_until(lambda: tracker.count >= 1) + + schedule.pause() + assert schedule.describe().status == ScheduleStatus.PAUSED + + # After pausing, no further runs should occur. 
+ time.sleep(2) + count_after_pause = tracker.count + time.sleep(2) + assert tracker.count == count_after_pause, "schedule kept running while paused" + + schedule.resume() + assert schedule.describe().status == ScheduleStatus.ACTIVE + assert _wait_until(lambda: tracker.count > count_after_pause) + + schedule.delete() + + +def test_update_interval_and_input(): + with _make_worker() as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST) as c: + scheduled = ScheduledTaskClient(c) + schedule = scheduled.create_schedule(ScheduleCreationOptions( + schedule_id="sched-update", + orchestration_name=task.get_name(target_orchestrator), + interval=timedelta(seconds=30), + orchestration_input="before", + start_at=datetime.now(timezone.utc) + timedelta(hours=1), + )) + + schedule.update(ScheduleUpdateOptions( + orchestration_input="after", + interval=timedelta(seconds=2), + )) + + description = schedule.describe() + assert description.interval == timedelta(seconds=2) + assert description.orchestration_input == "after" + + schedule.delete() + + +def test_list_schedules_with_prefix_and_status(): + with _make_worker() as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST) as c: + scheduled = ScheduledTaskClient(c) + for i in range(3): + scheduled.create_schedule(ScheduleCreationOptions( + schedule_id=f"group-a-{i}", + orchestration_name=task.get_name(target_orchestrator), + interval=timedelta(seconds=30), + start_at=datetime.now(timezone.utc) + timedelta(hours=1), + )) + scheduled.create_schedule(ScheduleCreationOptions( + schedule_id="group-b-0", + orchestration_name=task.get_name(target_orchestrator), + interval=timedelta(seconds=30), + start_at=datetime.now(timezone.utc) + timedelta(hours=1), + )) + + all_schedules = scheduled.list_schedules() + ids = {s.schedule_id for s in all_schedules} + assert {"group-a-0", "group-a-1", "group-a-2", "group-b-0"}.issubset(ids) + + group_a = scheduled.list_schedules(ScheduleQuery(schedule_id_prefix="group-a-")) + 
assert {s.schedule_id for s in group_a} == {"group-a-0", "group-a-1", "group-a-2"} + + active = scheduled.list_schedules(ScheduleQuery(status=ScheduleStatus.ACTIVE)) + assert all(s.status == ScheduleStatus.ACTIVE for s in active) + + for s in all_schedules: + scheduled.get_schedule_client(s.schedule_id).delete() + + +def test_delete_removes_schedule(): + with _make_worker() as w: + w.start() + with client.TaskHubGrpcClient(host_address=HOST) as c: + scheduled = ScheduledTaskClient(c) + schedule = scheduled.create_schedule(ScheduleCreationOptions( + schedule_id="sched-delete", + orchestration_name=task.get_name(target_orchestrator), + interval=timedelta(seconds=30), + start_at=datetime.now(timezone.utc) + timedelta(hours=1), + )) + + assert scheduled.get_schedule("sched-delete") is not None + + schedule.delete() + assert _wait_until(lambda: scheduled.get_schedule("sched-delete") is None) diff --git a/tests/durabletask/scheduled/test_transitions.py b/tests/durabletask/scheduled/test_transitions.py new file mode 100644 index 00000000..1a7f2c05 --- /dev/null +++ b/tests/durabletask/scheduled/test_transitions.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for schedule state transition rules.""" + +from durabletask.scheduled import transitions +from durabletask.scheduled.schedule_status import ScheduleStatus + + +class TestCreateTransitions: + def test_create_from_uninitialized_to_active_is_valid(self): + assert transitions.is_valid_transition( + transitions.CREATE_SCHEDULE, ScheduleStatus.UNINITIALIZED, ScheduleStatus.ACTIVE) + + def test_create_from_active_to_active_is_valid(self): + assert transitions.is_valid_transition( + transitions.CREATE_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE) + + def test_create_from_paused_to_active_is_valid(self): + assert transitions.is_valid_transition( + transitions.CREATE_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.ACTIVE) + + +class TestUpdateTransitions: + def test_update_active_to_active_is_valid(self): + assert transitions.is_valid_transition( + transitions.UPDATE_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE) + + def test_update_paused_to_paused_is_valid(self): + assert transitions.is_valid_transition( + transitions.UPDATE_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.PAUSED) + + def test_update_from_uninitialized_is_invalid(self): + assert not transitions.is_valid_transition( + transitions.UPDATE_SCHEDULE, ScheduleStatus.UNINITIALIZED, ScheduleStatus.UNINITIALIZED) + + +class TestPauseResumeTransitions: + def test_pause_active_is_valid(self): + assert transitions.is_valid_transition( + transitions.PAUSE_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.PAUSED) + + def test_pause_paused_is_invalid(self): + assert not transitions.is_valid_transition( + transitions.PAUSE_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.PAUSED) + + def test_resume_paused_is_valid(self): + assert transitions.is_valid_transition( + transitions.RESUME_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.ACTIVE) + + def test_resume_active_is_invalid(self): + assert not transitions.is_valid_transition( + transitions.RESUME_SCHEDULE, ScheduleStatus.ACTIVE, 
ScheduleStatus.ACTIVE) + + +def test_unknown_operation_is_invalid(): + assert not transitions.is_valid_transition( + "unknown_op", ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE) From 58ee9c929cda161a6cb195d5ffd39b4b043f1f16 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 11:27:39 -0600 Subject: [PATCH 2/9] Let to_json hook take precedence over dataclass asdict in serializer The JSON serializer encoded dataclasses via dataclasses.asdict before ever checking for a to_json hook, so a dataclass could not override its own serialization -- a problem for dataclasses whose fields are not JSON-native (e.g. timedelta/datetime). The read path already consults from_json before its dataclass branch; this makes the write path symmetric by checking the to_json hook first and falling back to asdict/SimpleNamespace. --- durabletask/internal/json_codec.py | 15 ++++++++--- tests/durabletask/test_serialization.py | 34 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/durabletask/internal/json_codec.py b/durabletask/internal/json_codec.py index 8fda0eae..83a87f53 100644 --- a/durabletask/internal/json_codec.py +++ b/durabletask/internal/json_codec.py @@ -76,18 +76,25 @@ def _encode_custom_object(o: Any) -> Any: namedtuples are handled natively by the encoder (serialized as JSON arrays) and never reach this hook. """ - if dataclasses.is_dataclass(o) and not isinstance(o, type): - return dataclasses.asdict(o) - if isinstance(o, SimpleNamespace): - return vars(o) # Custom objects may opt in via a ``to_json`` hook. It is resolved off the # type and called with the instance (``type(o).to_json(o)``) so that both # instance methods and ``@staticmethod`` hooks work -- matching the calling # convention used by ``azure-functions-durable``. The hook returns a # JSON-serializable value (a structure or a string), not a JSON document. 
+ # + # The hook is checked before the dataclass / ``SimpleNamespace`` branches so + # a type may override the default structural encoding -- mirroring the read + # path, where :func:`coerce_to_type` consults ``from_json`` before its + # dataclass branch. This matters for dataclasses whose fields are not + # JSON-native (e.g. ``timedelta`` / ``datetime``), which ``asdict`` alone + # cannot serialize. to_json_hook = getattr(cast(Any, type(o)), "to_json", None) if callable(to_json_hook): return to_json_hook(o) + if dataclasses.is_dataclass(o) and not isinstance(o, type): + return dataclasses.asdict(o) + if isinstance(o, SimpleNamespace): + return vars(o) # This will raise a TypeError describing the unsupported type. raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable") diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index 33003c21..34c612d6 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -93,6 +93,40 @@ def test_instance_to_json_hook_receives_instance(): assert json.loads(json_codec.to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} +@dataclass +class HookedDataclass: + """Dataclass that overrides serialization via to_json/from_json. + + Its ``value`` field is not JSON-native on the wire (it is encoded as a + string), so plain ``dataclasses.asdict`` would not round-trip. The hooks + take precedence over the default dataclass encoding. + """ + + value: int + + def to_json(self) -> dict: + return {"value": str(self.value)} + + @classmethod + def from_json(cls, data: dict) -> "HookedDataclass": + return cls(int(data["value"])) + + def __eq__(self, other: object) -> bool: + return isinstance(other, HookedDataclass) and other.value == self.value + + +def test_dataclass_to_json_hook_takes_precedence_over_asdict(): + # A dataclass exposing to_json should serialize via the hook, not asdict. 
+ encoded = json_codec.to_json(HookedDataclass(7)) + assert json.loads(encoded) == {"value": "7"} + + +def test_dataclass_hook_round_trips_with_expected_type(): + encoded = json_codec.to_json(HookedDataclass(7)) + result = json_codec.from_json(encoded, HookedDataclass) + assert result == HookedDataclass(7) + + # ----- to_json ----- From 4df10bcbaa8f2ead9e3f7150206a68668b9b5436 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 11:34:16 -0600 Subject: [PATCH 3/9] Use to_json/from_json hooks for schedule options With the serializer now honoring the to_json hook for dataclasses, the schedule option types expose to_json/from_json instead of bespoke to_dict/from_dict. The schedule entity operations can again annotate their input as the option dataclass and let the worker reconstruct it via the from_json hook, removing the manual _coerce_options shim and the Any-typed parameter workaround. --- durabletask/scheduled/client.py | 4 ++-- durabletask/scheduled/models.py | 16 ++++++------- durabletask/scheduled/schedule_entity.py | 24 ++----------------- tests/durabletask/scheduled/test_models.py | 6 ++--- .../scheduled/test_schedule_entity.py | 6 ++--- 5 files changed, 18 insertions(+), 38 deletions(-) diff --git a/durabletask/scheduled/client.py b/durabletask/scheduled/client.py index a3474565..1dca6135 100644 --- a/durabletask/scheduled/client.py +++ b/durabletask/scheduled/client.py @@ -71,11 +71,11 @@ def _run_operation(self, operation_name: str, input: Any | None = None) -> None: def create(self, options: ScheduleCreationOptions) -> None: """Create or update this schedule with the given configuration.""" - self._run_operation(transitions.CREATE_SCHEDULE, options.to_dict()) + self._run_operation(transitions.CREATE_SCHEDULE, options.to_json()) def update(self, options: ScheduleUpdateOptions) -> None: """Update this schedule's configuration.""" - self._run_operation(transitions.UPDATE_SCHEDULE, options.to_dict()) + 
self._run_operation(transitions.UPDATE_SCHEDULE, options.to_json()) def pause(self) -> None: """Pause this schedule.""" diff --git a/durabletask/scheduled/models.py b/durabletask/scheduled/models.py index 2f03416d..77a7e6c1 100644 --- a/durabletask/scheduled/models.py +++ b/durabletask/scheduled/models.py @@ -55,7 +55,7 @@ def __post_init__(self): raise ValueError("orchestration_name cannot be empty.") _validate_interval(self.interval) - def to_dict(self) -> dict[str, Any]: + def to_json(self) -> dict[str, Any]: return { "schedule_id": self.schedule_id, "orchestration_name": self.orchestration_name, @@ -67,9 +67,9 @@ def to_dict(self) -> dict[str, Any]: "start_immediately_if_late": self.start_immediately_if_late, } - @staticmethod - def from_dict(data: dict[str, Any]) -> "ScheduleCreationOptions": - return ScheduleCreationOptions( + @classmethod + def from_json(cls, data: dict[str, Any]) -> "ScheduleCreationOptions": + return cls( schedule_id=data["schedule_id"], orchestration_name=data["orchestration_name"], interval=timedelta(seconds=data["interval_seconds"]), @@ -97,7 +97,7 @@ def __post_init__(self): if self.interval is not None: _validate_interval(self.interval) - def to_dict(self) -> dict[str, Any]: + def to_json(self) -> dict[str, Any]: return { "orchestration_name": self.orchestration_name, "orchestration_input": self.orchestration_input, @@ -108,9 +108,9 @@ def to_dict(self) -> dict[str, Any]: "start_immediately_if_late": self.start_immediately_if_late, } - @staticmethod - def from_dict(data: dict[str, Any]) -> "ScheduleUpdateOptions": - return ScheduleUpdateOptions( + @classmethod + def from_json(cls, data: dict[str, Any]) -> "ScheduleUpdateOptions": + return cls( orchestration_name=data.get("orchestration_name"), orchestration_input=data.get("orchestration_input"), orchestration_instance_id=data.get("orchestration_instance_id"), diff --git a/durabletask/scheduled/schedule_entity.py b/durabletask/scheduled/schedule_entity.py index 3bf02551..a0b60a6b 
100644 --- a/durabletask/scheduled/schedule_entity.py +++ b/durabletask/scheduled/schedule_entity.py @@ -35,17 +35,6 @@ def _ensure_aware(value: datetime | None) -> datetime | None: return value -def _coerce_options(input: Any, cls: type) -> Any: - """Coerce a round-tripped input (dict/SimpleNamespace) into the given options dataclass.""" - if input is None or isinstance(input, cls): - return input - if isinstance(input, SimpleNamespace): - input = vars(input) - if isinstance(input, dict): - return cls.from_dict(input) - return input - - class Schedule(DurableEntity): """Entity that manages the state and execution of a scheduled task. @@ -74,15 +63,8 @@ def _can_transition_to(self, state: ScheduleState, operation_name: str, target_status: ScheduleStatus) -> bool: return transitions.is_valid_transition(operation_name, state.status, target_status) - # NOTE: the input is intentionally annotated ``Any`` rather than - # ``ScheduleCreationOptions``. The worker reconstructs an entity operation's - # input from its parameter annotation; a dataclass annotation would map the - # wire dict by field name and drop our JSON-friendly fields (e.g. - # ``interval_seconds``). Keeping ``Any`` lets the raw dict reach - # ``_coerce_options``, which rebuilds the options via ``from_dict``. - def create_schedule(self, options: Any) -> None: + def create_schedule(self, options: ScheduleCreationOptions) -> None: """Create a new schedule. If one already exists, update it in place.""" - options = _coerce_options(options, ScheduleCreationOptions) state = self._load_state() if not self._can_transition_to(state, transitions.CREATE_SCHEDULE, ScheduleStatus.ACTIVE): @@ -111,10 +93,8 @@ def create_schedule(self, options: Any) -> None: state.execution_token, ) - # NOTE: input annotated ``Any`` for the same reason as ``create_schedule``. 
- def update_schedule(self, options: Any) -> None: + def update_schedule(self, options: ScheduleUpdateOptions) -> None: """Update an existing schedule's configuration.""" - options = _coerce_options(options, ScheduleUpdateOptions) state = self._load_state() if not self._can_transition_to(state, transitions.UPDATE_SCHEDULE, state.status): diff --git a/tests/durabletask/scheduled/test_models.py b/tests/durabletask/scheduled/test_models.py index 97f3dafc..61eef997 100644 --- a/tests/durabletask/scheduled/test_models.py +++ b/tests/durabletask/scheduled/test_models.py @@ -51,8 +51,8 @@ def test_round_trip_through_json(self): orchestration_input={"key": "value"}, orchestration_instance_id="inst-1", start_at=start, end_at=end, start_immediately_if_late=True) - encoded = shared.to_json(options.to_dict()) - decoded = ScheduleCreationOptions.from_dict(shared.from_json(encoded)) + encoded = shared.to_json(options) + decoded = shared.from_json(encoded, ScheduleCreationOptions) assert decoded.schedule_id == "s1" assert decoded.orchestration_name == "orch" @@ -71,7 +71,7 @@ def test_interval_validation(self): def test_round_trip_through_json(self): options = ScheduleUpdateOptions(orchestration_name="orch2", interval=timedelta(seconds=10)) - decoded = ScheduleUpdateOptions.from_dict(shared.from_json(shared.to_json(options.to_dict()))) + decoded = shared.from_json(shared.to_json(options), ScheduleUpdateOptions) assert decoded.orchestration_name == "orch2" assert decoded.interval == timedelta(seconds=10) assert decoded.start_at is None diff --git a/tests/durabletask/scheduled/test_schedule_entity.py b/tests/durabletask/scheduled/test_schedule_entity.py index c65c6d18..f5326f71 100644 --- a/tests/durabletask/scheduled/test_schedule_entity.py +++ b/tests/durabletask/scheduled/test_schedule_entity.py @@ -59,7 +59,7 @@ def _start_actions(actions): def _creation_options(**kwargs): base = dict(schedule_id=SCHEDULE_ID, orchestration_name="my_orch", interval=timedelta(seconds=30)) 
base.update(kwargs) - return ScheduleCreationOptions(**base).to_dict() + return ScheduleCreationOptions(**base) class TestCreate: @@ -110,7 +110,7 @@ def test_update_changes_config_and_resignals(self): h = Harness() h.run("create_schedule", _creation_options()) _, actions = h.run("update_schedule", - ScheduleUpdateOptions(interval=timedelta(seconds=120)).to_dict()) + ScheduleUpdateOptions(interval=timedelta(seconds=120))) assert abs(h.state_dict["schedule_configuration"]["interval_seconds"] - 120) < 0.001 assert len(_signal_actions(actions)) == 1 @@ -118,7 +118,7 @@ def test_update_no_change_does_not_signal(self): h = Harness() h.run("create_schedule", _creation_options()) _, actions = h.run("update_schedule", - ScheduleUpdateOptions(orchestration_name="my_orch").to_dict()) + ScheduleUpdateOptions(orchestration_name="my_orch")) assert len(_signal_actions(actions)) == 0 From 2dbff2346520573b2c3dd606afb329c51671fd17 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 26 Jun 2026 14:39:46 -0600 Subject: [PATCH 4/9] Serialization improvements (WIP) --- CHANGELOG.md | 7 +- durabletask/internal/json_codec.py | 87 +++++++++++-- durabletask/scheduled/client.py | 29 +---- durabletask/scheduled/models.py | 30 +++-- durabletask/scheduled/orchestrator.py | 17 +-- durabletask/scheduled/schedule_entity.py | 20 +-- durabletask/serialization.py | 4 +- tests/durabletask/scheduled/test_models.py | 7 +- .../scheduled/test_schedule_entity.py | 7 ++ tests/durabletask/test_serialization.py | 118 ++++++++++++++++++ 10 files changed, 252 insertions(+), 74 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b14da70d..ee182d21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,12 @@ ADDED custom status, entity state) routes through it. The default `JsonDataConverter` preserves existing behavior, so a custom converter (for example one backed by pydantic) is opt-in. Custom objects can opt in via a - `to_json()` hook and a `from_json(value)` classmethod. 
+ `to_json()` hook and a `from_json(value)` classmethod. Objects that contain + other hook-using objects round-trip automatically: nested `to_json()` hooks + fire at any depth during serialization, and a `from_json` hook may declare an + optional second parameter (`from_json(cls, value, converter)`) to reconstruct + nested typed values via `converter.coerce(child, ChildType)` instead of by + hand. - `OrchestrationContext.call_activity`, `call_sub_orchestrator`, and `call_entity` accept an optional `return_type`, and `wait_for_external_event` accepts an optional `data_type`. When provided, the result/event payload is diff --git a/durabletask/internal/json_codec.py b/durabletask/internal/json_codec.py index 83a87f53..f616b533 100644 --- a/durabletask/internal/json_codec.py +++ b/durabletask/internal/json_codec.py @@ -14,6 +14,7 @@ from __future__ import annotations import dataclasses +import inspect import json import types import typing @@ -47,7 +48,8 @@ def to_json(obj: Any) -> str: ) from e -def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any: +def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None, + converter: Any | None = None) -> Any: """Deserialize a JSON string, optionally coercing the result to a type. When ``expected_type`` is ``None`` (the default) the raw parsed JSON is @@ -62,11 +64,15 @@ def from_json(json_str: str | bytes | bytearray, expected_type: type | None = No classmethod are reconstructed via that hook, and ``Optional``/``Union`` and ``list`` type hints are honored recursively. The destination type is always supplied by the caller; it is never read from the payload. + + ``converter`` is the active :class:`~durabletask.serialization.DataConverter`. + It is forwarded to ``from_json`` hooks that opt in (see :func:`coerce_to_type`) + so they can reconstruct nested typed values via ``converter.coerce(...)``. 
""" if expected_type is None: return json.loads(json_str, object_hook=_legacy_object_hook) raw = json.loads(json_str, object_hook=_strip_legacy_marker) - return coerce_to_type(raw, expected_type) + return coerce_to_type(raw, expected_type, converter) def _encode_custom_object(o: Any) -> Any: @@ -92,7 +98,13 @@ def _encode_custom_object(o: Any) -> Any: if callable(to_json_hook): return to_json_hook(o) if dataclasses.is_dataclass(o) and not isinstance(o, type): - return dataclasses.asdict(o) + # Shallow-convert to a dict whose *values are the original field objects* + # (unlike ``dataclasses.asdict``, which deep-recurses and would convert a + # nested dataclass via ``asdict`` -- bypassing that child's ``to_json`` + # hook). ``json.dumps`` then recurses into each value and re-enters this + # hook for any nested custom object, so nested ``to_json`` hooks fire at + # every depth (including inside lists/dicts). + return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)} if isinstance(o, SimpleNamespace): return vars(o) # This will raise a TypeError describing the unsupported type. @@ -112,20 +124,31 @@ def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]: return d -def coerce_to_type(value: Any, expected_type: Any) -> Any: +def coerce_to_type(value: Any, expected_type: Any, converter: Any | None = None) -> Any: """Coerce an already-parsed JSON value to ``expected_type``. Handles ``None``/``Optional``/``Union`` and ``list`` type hints recursively, types exposing a ``from_json()`` classmethod, and dataclasses (including nested dataclass fields). The destination type is always caller-supplied and never derived from the payload, keeping deserialization secure. + + ``converter`` is the active :class:`~durabletask.serialization.DataConverter`. + A ``from_json`` hook may opt in to receiving it (by accepting a second + positional parameter) and delegate nested reconstruction back to the + converter, e.g. ``converter.coerce(child, ChildType)``. 
This keeps hooks free + of manual nested deserialization and routes children through the same policy. """ if expected_type is None or value is None: return value + # ``Any`` imposes no constraint -- and ``isinstance(x, Any)`` raises -- so + # short-circuit before any type inspection below. + if expected_type is typing.Any: + return value + origin = typing.get_origin(expected_type) if origin is not None: - return _coerce_generic(value, expected_type, origin) + return _coerce_generic(value, expected_type, origin, converter) if not isinstance(expected_type, type): # Not a concrete, instantiable type (e.g. a typing special form we don't @@ -137,10 +160,10 @@ def coerce_to_type(value: Any, expected_type: Any) -> Any: from_json_hook = getattr(expected_type, "from_json", None) if callable(from_json_hook): - return from_json_hook(value) + return _invoke_from_json(from_json_hook, value, converter) if dataclasses.is_dataclass(expected_type) and isinstance(value, dict): - return _build_dataclass(expected_type, cast(dict[str, Any], value)) + return _build_dataclass(expected_type, cast(dict[str, Any], value), converter) type_ctor = cast(Any, expected_type) try: @@ -153,12 +176,52 @@ def coerce_to_type(value: Any, expected_type: Any) -> Any: ) from e -def _coerce_generic(value: Any, expected_type: Any, origin: Any) -> Any: +def _invoke_from_json(from_json_hook: Any, value: Any, converter: Any | None) -> Any: + """Invoke a ``from_json`` hook, passing the converter if the hook accepts it. + + Hooks may be declared as ``from_json(cls, value)`` (the original contract) or + ``from_json(cls, value, converter)`` to opt into managed nested + reconstruction. Arity is detected from the bound hook's signature, so the + extra parameter is fully backwards compatible. When a converter-aware hook is + found but no converter was threaded (e.g. a direct ``json_codec`` call), the + shared default converter is resolved lazily so the hook always receives one. 
+ """ + wants_converter = False + try: + params = [ + p for p in inspect.signature(from_json_hook).parameters.values() + if p.kind in (inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + wants_converter = len(params) >= 2 + except (TypeError, ValueError): + wants_converter = False + + if wants_converter: + if converter is None: + converter = _default_converter() + return from_json_hook(value, converter) + return from_json_hook(value) + + +def _default_converter() -> Any: + # Lazy import to avoid a load-time cycle: ``serialization`` imports this + # module at import time, but by the time a hook actually runs both modules + # are fully initialized. + from durabletask.serialization import DEFAULT_DATA_CONVERTER + return DEFAULT_DATA_CONVERTER + + +def _coerce_generic(value: Any, expected_type: Any, origin: Any, converter: Any | None = None) -> Any: args = typing.get_args(expected_type) if origin is typing.Union or origin is types.UnionType: # If the value already matches a member type, keep it as-is. non_none = [a for a in args if a is not type(None)] for arg in non_none: + # An ``Any`` member imposes no constraint (and ``isinstance(x, Any)`` + # raises), so the value is acceptable as-is. + if arg is typing.Any: + return value if isinstance(arg, type) and isinstance(value, arg): return value # ``Optional[T]`` (exactly one non-None member): coerce to that member. @@ -166,16 +229,16 @@ def _coerce_generic(value: Any, expected_type: Any, origin: Any) -> Any: # the members, leave it untouched rather than guessing the first arg -- # forcing a coercion there can silently mis-construct the wrong type. 
if len(non_none) == 1: - return coerce_to_type(value, non_none[0]) + return coerce_to_type(value, non_none[0], converter) return value if origin in (list, Sequence) and isinstance(value, list): elem_type = args[0] if args else None - return [coerce_to_type(item, elem_type) for item in cast(list[Any], value)] + return [coerce_to_type(item, elem_type, converter) for item in cast(list[Any], value)] # Other generics (dict, tuple, ...) are returned as parsed JSON. return value -def _build_dataclass(cls: Any, data: dict[str, Any]) -> Any: +def _build_dataclass(cls: Any, data: dict[str, Any], converter: Any | None = None) -> Any: """Construct a dataclass from its dict payload, recursing into typed fields.""" try: hints = typing.get_type_hints(cls) @@ -186,5 +249,5 @@ def _build_dataclass(cls: Any, data: dict[str, Any]) -> Any: if field.name not in data: continue field_type = hints.get(field.name) - kwargs[field.name] = coerce_to_type(data[field.name], field_type) + kwargs[field.name] = coerce_to_type(data[field.name], field_type, converter) return cls(**kwargs) diff --git a/durabletask/scheduled/client.py b/durabletask/scheduled/client.py index 1dca6135..4e327b68 100644 --- a/durabletask/scheduled/client.py +++ b/durabletask/scheduled/client.py @@ -2,13 +2,10 @@ # Licensed under the MIT License. 
import logging -from dataclasses import asdict -from typing import Any from durabletask.client import (EntityQuery, OrchestrationStatus, TaskHubGrpcClient) from durabletask.entities import EntityInstanceId -from durabletask.internal import shared from durabletask.scheduled import transitions from durabletask.scheduled.exceptions import ScheduleNotFoundError from durabletask.scheduled.models import (ScheduleCreationOptions, @@ -22,20 +19,6 @@ logger = logging.getLogger("durabletask.scheduled") -def _parse_state(serialized_state: Any) -> ScheduleState | None: - if serialized_state is None: - return None - data = serialized_state - if isinstance(data, str): - if not data.strip(): - # A deleted (or never-initialized) entity reports empty state. - return None - data = shared.from_json(data) - if isinstance(data, dict): - return ScheduleState.from_dict(data) - return None - - class ScheduleClient: """Client for managing a single schedule instance.""" @@ -53,14 +36,14 @@ def schedule_id(self) -> str: """Gets the ID of this schedule.""" return self._schedule_id - def _run_operation(self, operation_name: str, input: Any | None = None) -> None: + def _run_operation(self, operation_name: str, input: object | None = None) -> None: request = ScheduleOperationRequest( entity_id=str(self._entity_id), operation_name=operation_name, input=input, ) instance_id = self._client.schedule_new_orchestration( - execute_schedule_operation_orchestrator, input=asdict(request)) + execute_schedule_operation_orchestrator, input=request) state = self._client.wait_for_orchestration_completion( instance_id, timeout=self._operation_timeout) if state is None or state.runtime_status != OrchestrationStatus.COMPLETED: @@ -71,11 +54,11 @@ def _run_operation(self, operation_name: str, input: Any | None = None) -> None: def create(self, options: ScheduleCreationOptions) -> None: """Create or update this schedule with the given configuration.""" - self._run_operation(transitions.CREATE_SCHEDULE, 
options.to_json()) + self._run_operation(transitions.CREATE_SCHEDULE, options) def update(self, options: ScheduleUpdateOptions) -> None: """Update this schedule's configuration.""" - self._run_operation(transitions.UPDATE_SCHEDULE, options.to_json()) + self._run_operation(transitions.UPDATE_SCHEDULE, options) def pause(self) -> None: """Pause this schedule.""" @@ -94,7 +77,7 @@ def describe(self) -> ScheduleDescription: metadata = self._client.get_entity(self._entity_id, include_state=True) if metadata is None: raise ScheduleNotFoundError(self._schedule_id) - state = _parse_state(metadata.get_state()) + state = metadata.get_typed_state(ScheduleState) if state is None: raise ScheduleNotFoundError(self._schedule_id) return state.to_description() @@ -136,7 +119,7 @@ def list_schedules(self, filter: ScheduleQuery | None = None) -> list[ScheduleDe ) results: list[ScheduleDescription] = [] for metadata in self._client.get_all_entities(query): - state = _parse_state(metadata.get_state()) + state = metadata.get_typed_state(ScheduleState) if state is None or state.schedule_configuration is None: continue if not self._matches_filter(state, filter): diff --git a/durabletask/scheduled/models.py b/durabletask/scheduled/models.py index 77a7e6c1..5aa723ce 100644 --- a/durabletask/scheduled/models.py +++ b/durabletask/scheduled/models.py @@ -220,7 +220,7 @@ def _validate(self): if self.start_at is not None and self.end_at is not None and self.start_at > self.end_at: raise ValueError("start_at cannot be later than end_at.") - def to_dict(self) -> dict[str, Any]: + def to_json(self) -> dict[str, Any]: return { "schedule_id": self.schedule_id, "orchestration_name": self.orchestration_name, @@ -232,9 +232,9 @@ def to_dict(self) -> dict[str, Any]: "start_immediately_if_late": self.start_immediately_if_late, } - @staticmethod - def from_dict(data: dict[str, Any]) -> "ScheduleConfiguration": - config = ScheduleConfiguration( + @classmethod + def from_json(cls, data: dict[str, Any]) -> 
"ScheduleConfiguration": + config = cls( data["schedule_id"], data["orchestration_name"], timedelta(seconds=data["interval_seconds"]), @@ -262,7 +262,10 @@ def __init__(self): def refresh_execution_token(self): self.execution_token = _new_token() - def to_dict(self) -> dict[str, Any]: + def to_json(self) -> dict[str, Any]: + # ``schedule_configuration`` is returned as the object itself; the + # serializer recurses into it and fires its own ``to_json`` hook. Only + # this type's non-JSON-native leaves (datetimes) are converted here. return { "status": self.status.value, "execution_token": self.execution_token, @@ -270,21 +273,24 @@ def to_dict(self) -> dict[str, Any]: "next_run_at": _to_iso(self.next_run_at), "schedule_created_at": _to_iso(self.schedule_created_at), "schedule_last_modified_at": _to_iso(self.schedule_last_modified_at), - "schedule_configuration": - self.schedule_configuration.to_dict() if self.schedule_configuration else None, + "schedule_configuration": self.schedule_configuration, } - @staticmethod - def from_dict(data: dict[str, Any]) -> "ScheduleState": - state = ScheduleState() + @classmethod + def from_json(cls, data: dict[str, Any], converter: Any) -> "ScheduleState": + # The nested configuration is reconstructed through the converter, which + # routes it to ``ScheduleConfiguration``'s own ``from_json`` hook (and + # honors a custom converter). Only this type's datetime leaves are + # rebuilt by hand. 
+ state = cls() state.status = ScheduleStatus(data["status"]) state.execution_token = data["execution_token"] state.last_run_at = _from_iso(data.get("last_run_at")) state.next_run_at = _from_iso(data.get("next_run_at")) state.schedule_created_at = _from_iso(data.get("schedule_created_at")) state.schedule_last_modified_at = _from_iso(data.get("schedule_last_modified_at")) - config_data = data.get("schedule_configuration") - state.schedule_configuration = ScheduleConfiguration.from_dict(config_data) if config_data else None + state.schedule_configuration = converter.coerce( + data.get("schedule_configuration"), ScheduleConfiguration) return state def to_description(self) -> ScheduleDescription: diff --git a/durabletask/scheduled/orchestrator.py b/durabletask/scheduled/orchestrator.py index 111c98e1..ab94d60a 100644 --- a/durabletask/scheduled/orchestrator.py +++ b/durabletask/scheduled/orchestrator.py @@ -3,7 +3,6 @@ from collections.abc import Generator from dataclasses import dataclass -from types import SimpleNamespace from typing import Any from durabletask import task @@ -12,7 +11,13 @@ @dataclass class ScheduleOperationRequest: - """Request describing an operation to execute against a schedule entity.""" + """Request describing an operation to execute against a schedule entity. + + A plain dataclass: the serializer round-trips it (and its ``input`` payload) + automatically. ``input`` stays an ``Any`` here -- it is reconstructed into the + concrete options type at the entity-method boundary from that method's + parameter annotation. + """ entity_id: str operation_name: str @@ -20,17 +25,13 @@ class ScheduleOperationRequest: def execute_schedule_operation_orchestrator( - ctx: task.OrchestrationContext, request: Any) -> Generator[task.Task[Any], Any, Any]: + ctx: task.OrchestrationContext, + request: ScheduleOperationRequest) -> Generator[task.Task[Any], Any, Any]: """Orchestrator that executes a single operation on a schedule entity. 
Client-side write operations route through this orchestrator so callers can await completion (and surface failures) of the underlying entity operation. """ - if isinstance(request, SimpleNamespace): - request = vars(request) - if isinstance(request, dict): - request = ScheduleOperationRequest(**request) - entity_id = EntityInstanceId.parse(request.entity_id) result = yield ctx.call_entity(entity_id, request.operation_name, request.input) return result diff --git a/durabletask/scheduled/schedule_entity.py b/durabletask/scheduled/schedule_entity.py index a0b60a6b..c72c9a28 100644 --- a/durabletask/scheduled/schedule_entity.py +++ b/durabletask/scheduled/schedule_entity.py @@ -3,7 +3,6 @@ import logging from datetime import datetime, timezone -from types import SimpleNamespace from typing import Any from durabletask.entities import DurableEntity, EntityInstanceId @@ -44,17 +43,14 @@ class Schedule(DurableEntity): """ def _load_state(self) -> ScheduleState: - raw = self.get_state() - if raw is None: - return ScheduleState() - if isinstance(raw, SimpleNamespace): - raw = vars(raw) - if isinstance(raw, dict): - return ScheduleState.from_dict(raw) - raise TypeError(f"Unexpected schedule state type: {type(raw).__name__}") + # The serializer reconstructs the persisted ``ScheduleState`` via its + # ``from_json`` hook; a missing/uninitialized entity yields a fresh one. + return self.get_state(ScheduleState, ScheduleState()) def _save_state(self, state: ScheduleState) -> None: - self.set_state(state.to_dict()) + # Store the object as-is; the serializer persists it via its ``to_json`` + # hook when the entity batch is committed. 
+ self.set_state(state) def _entity_id(self, schedule_id: str) -> EntityInstanceId: return EntityInstanceId(ENTITY_NAME, schedule_id) @@ -213,10 +209,6 @@ def run_schedule(self, execution_token: str) -> None: signal_time=state.next_run_at, ) - def delete(self, _: Any = None) -> None: - """Delete the schedule entity.""" - self.set_state(None) - def _start_orchestration(self, config: ScheduleConfiguration, scheduled_run_time: datetime) -> None: instance_id = config.orchestration_instance_id if not instance_id: diff --git a/durabletask/serialization.py b/durabletask/serialization.py index b1b469b4..c8db8a03 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -107,7 +107,7 @@ def deserialize(self, data: str | None, target_type: type | None = None) -> Any: if target_type is None: return json_codec.from_json(data) try: - return json_codec.from_json(data, target_type) + return json_codec.from_json(data, target_type, converter=self) except Exception as e: # Best-effort: fall back to the raw deserialized value rather than # failing the operation. Logged so the mismatch remains discoverable. 
@@ -118,7 +118,7 @@ def coerce(self, value: Any, target_type: type | None = None) -> Any: if target_type is None or value is None: return value try: - return json_codec.coerce_to_type(value, target_type) + return json_codec.coerce_to_type(value, target_type, converter=self) except Exception as e: self._log_coercion_fallback(target_type, e) return value diff --git a/tests/durabletask/scheduled/test_models.py b/tests/durabletask/scheduled/test_models.py index 61eef997..49b01202 100644 --- a/tests/durabletask/scheduled/test_models.py +++ b/tests/durabletask/scheduled/test_models.py @@ -108,7 +108,7 @@ def test_config_round_trip(self): ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", interval=timedelta(seconds=5), start_at=datetime(2026, 1, 1, tzinfo=timezone.utc))) - restored = ScheduleConfiguration.from_dict(shared.from_json(shared.to_json(config.to_dict()))) + restored = shared.from_json(shared.to_json(config), ScheduleConfiguration) assert restored.schedule_id == "s1" assert restored.interval == timedelta(seconds=5) assert restored.start_at == datetime(2026, 1, 1, tzinfo=timezone.utc) @@ -123,9 +123,12 @@ def test_round_trip_and_description(self): ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", interval=timedelta(seconds=5))) - restored = ScheduleState.from_dict(shared.from_json(shared.to_json(state.to_dict()))) + # The nested ``ScheduleConfiguration`` round-trips automatically. 
+ restored = shared.from_json(shared.to_json(state), ScheduleState) assert restored.status == ScheduleStatus.ACTIVE assert restored.schedule_created_at == datetime(2026, 1, 1, tzinfo=timezone.utc) + assert restored.schedule_configuration is not None + assert restored.schedule_configuration.interval == timedelta(seconds=5) description = restored.to_description() assert description.schedule_id == "s1" diff --git a/tests/durabletask/scheduled/test_schedule_entity.py b/tests/durabletask/scheduled/test_schedule_entity.py index f5326f71..7cbc6397 100644 --- a/tests/durabletask/scheduled/test_schedule_entity.py +++ b/tests/durabletask/scheduled/test_schedule_entity.py @@ -36,6 +36,13 @@ def run(self, operation, input=None): encoded = shared.to_json(input) if input is not None else None result = self.executor.execute("orch-1", self.entity_id, operation, self.state, encoded) self.state.commit() + # Mimic the wire round-trip: the worker serializes the entity state at + # the end of each batch, and the next batch receives it as deserialized + # JSON (a plain dict). This exercises the state ``to_json``/``from_json`` + # hooks between operations and keeps assertions dict-based. + current = self.state._current_state # pyright: ignore[reportPrivateUsage] + if current is not None: + self.state._current_state = shared.from_json(shared.to_json(current)) # pyright: ignore[reportPrivateUsage] actions = self.state.get_operation_actions()[before:] return result, actions diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index 34c612d6..a6b3f889 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -298,3 +298,121 @@ class B: # not force-coerced into the first union member. 
value = {"z": 1} assert json_codec.coerce_to_type(value, Union[A, B]) == {"z": 1} + + +# ----- nested to_json hooks via shallow encode ----- + + +@dataclass +class Container: + """A plain dataclass (no hooks) holding a hook-using child.""" + + name: str + widget: Widget + + +def test_to_json_plain_dataclass_recurses_into_child_hook(): + # The container has no to_json; the nested Widget must still serialize via + # its own hook (not be flattened by a deep asdict that ignores hooks). + encoded = json_codec.to_json(Container("c", Widget("gear", 5))) + assert json.loads(encoded) == {"name": "c", "widget": {"label": "gear", "size": 5}} + + +def test_to_json_list_of_hooked_children_uses_hooks(): + encoded = json_codec.to_json([Widget("a", 1), Widget("b", 2)]) + assert json.loads(encoded) == [{"label": "a", "size": 1}, {"label": "b", "size": 2}] + + +def test_round_trip_plain_dataclass_with_hooked_child(): + original = Container("c", Widget("gear", 5)) + restored = json_codec.from_json(json_codec.to_json(original), Container) + assert isinstance(restored, Container) + assert restored.widget == Widget("gear", 5) + + +# ----- converter-aware from_json hooks ----- + + +@dataclass +class Leaf: + """Child whose ``value`` is encoded as a string, requiring its own hook.""" + + value: int + + def to_json(self) -> dict: + return {"value": str(self.value)} + + @classmethod + def from_json(cls, data: dict) -> "Leaf": + return cls(int(data["value"])) + + +class Branch: + """Parent with its own hook that reconstructs a nested ``Leaf`` via the converter.""" + + def __init__(self, tag: str, leaf: Leaf): + self.tag = tag + self.leaf = leaf + + def to_json(self) -> dict: + return {"tag": self.tag, "leaf": self.leaf} + + @classmethod + def from_json(cls, data: dict, converter) -> "Branch": + return cls(data["tag"], converter.coerce(data["leaf"], Leaf)) + + def __eq__(self, other: object) -> bool: + return (isinstance(other, Branch) + and other.tag == self.tag and other.leaf == self.leaf) + 
+ +def test_converter_aware_from_json_reconstructs_nested_child(): + from durabletask.serialization import JsonDataConverter + + original = Branch("root", Leaf(7)) + converter = JsonDataConverter() + encoded = converter.serialize(original) + # The nested Leaf is encoded via its own hook. + assert json.loads(encoded) == {"tag": "root", "leaf": {"value": "7"}} + + restored = converter.deserialize(encoded, Branch) + assert isinstance(restored, Branch) + assert isinstance(restored.leaf, Leaf) + assert restored == Branch("root", Leaf(7)) + + +def test_converter_aware_from_json_resolves_default_converter_when_none(): + # Calling json_codec directly (no converter threaded) must still supply one + # to a converter-aware hook via the lazy default. + encoded = json_codec.to_json(Branch("root", Leaf(3))) + restored = json_codec.from_json(encoded, Branch) + assert restored == Branch("root", Leaf(3)) + + +def test_one_arg_from_json_hook_still_supported(): + # A legacy single-parameter from_json hook continues to work unchanged. 
+ encoded = json_codec.to_json(Widget("gear", 5)) + assert json_codec.from_json(encoded, Widget) == Widget("gear", 5) + + +# ----- Any handling ----- + + +def test_coerce_any_returns_value_unchanged(): + from typing import Any as TAny + value = {"anything": [1, 2, 3]} + assert json_codec.coerce_to_type(value, TAny) is value + + +def test_coerce_optional_any_field_does_not_raise(): + from typing import Any as TAny + + @dataclass + class HasAny: + name: str + payload: TAny | None = None + + result = json_codec.from_json( + json_codec.to_json(HasAny("x", {"k": "v"})), HasAny) + assert isinstance(result, HasAny) + assert result.payload == {"k": "v"} From 7ea7478bb512b950477282a86f7991e19e43c1d3 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Jun 2026 10:08:17 -0600 Subject: [PATCH 5/9] Revert JSON stuff to base --- durabletask/internal/client_helpers.py | 4 +- durabletask/internal/helpers.py | 9 +-- durabletask/internal/json_codec.py | 100 ++++--------------------- 3 files changed, 18 insertions(+), 95 deletions(-) diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py index fe0c828b..a42c3c43 100644 --- a/durabletask/internal/client_helpers.py +++ b/durabletask/internal/client_helpers.py @@ -205,16 +205,14 @@ def build_signal_entity_req( entity_instance_id: EntityInstanceId, operation_name: str, input: Any | None = None, - signal_time: datetime | None = None, data_converter: DataConverter | None = None) -> pb.SignalEntityRequest: """Build a SignalEntityRequest for signaling an entity.""" - scheduled_time = helpers.new_timestamp(signal_time) if signal_time is not None else None return pb.SignalEntityRequest( instanceId=str(entity_instance_id), name=operation_name, input=helpers.get_string_value(_serialize(input, data_converter)), requestId=str(uuid.uuid4()), - scheduledTime=scheduled_time, + scheduledTime=None, parentTraceContext=None, requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) ) diff --git 
a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 273096d6..2342afdd 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -255,16 +255,11 @@ def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: str | None, - request_id: str, - scheduled_time: datetime | None = None) -> pb.OrchestratorAction: - scheduled_timestamp: timestamp_pb2.Timestamp | None = None - if scheduled_time is not None: - scheduled_timestamp = timestamp_pb2.Timestamp() - scheduled_timestamp.FromDatetime(scheduled_time) + request_id: str) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( requestId=request_id, operation=operation, - scheduledTime=scheduled_timestamp, + scheduledTime=None, input=get_string_value(encoded_input), targetInstanceId=get_string_value(str(entity_id)), ))) diff --git a/durabletask/internal/json_codec.py b/durabletask/internal/json_codec.py index f616b533..8fda0eae 100644 --- a/durabletask/internal/json_codec.py +++ b/durabletask/internal/json_codec.py @@ -14,7 +14,6 @@ from __future__ import annotations import dataclasses -import inspect import json import types import typing @@ -48,8 +47,7 @@ def to_json(obj: Any) -> str: ) from e -def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None, - converter: Any | None = None) -> Any: +def from_json(json_str: str | bytes | bytearray, expected_type: type | None = None) -> Any: """Deserialize a JSON string, optionally coercing the result to a type. When ``expected_type`` is ``None`` (the default) the raw parsed JSON is @@ -64,15 +62,11 @@ def from_json(json_str: str | bytes | bytearray, expected_type: type | None = No classmethod are reconstructed via that hook, and ``Optional``/``Union`` and ``list`` type hints are honored recursively. 
The destination type is always supplied by the caller; it is never read from the payload. - - ``converter`` is the active :class:`~durabletask.serialization.DataConverter`. - It is forwarded to ``from_json`` hooks that opt in (see :func:`coerce_to_type`) - so they can reconstruct nested typed values via ``converter.coerce(...)``. """ if expected_type is None: return json.loads(json_str, object_hook=_legacy_object_hook) raw = json.loads(json_str, object_hook=_strip_legacy_marker) - return coerce_to_type(raw, expected_type, converter) + return coerce_to_type(raw, expected_type) def _encode_custom_object(o: Any) -> Any: @@ -82,31 +76,18 @@ def _encode_custom_object(o: Any) -> Any: namedtuples are handled natively by the encoder (serialized as JSON arrays) and never reach this hook. """ + if dataclasses.is_dataclass(o) and not isinstance(o, type): + return dataclasses.asdict(o) + if isinstance(o, SimpleNamespace): + return vars(o) # Custom objects may opt in via a ``to_json`` hook. It is resolved off the # type and called with the instance (``type(o).to_json(o)``) so that both # instance methods and ``@staticmethod`` hooks work -- matching the calling # convention used by ``azure-functions-durable``. The hook returns a # JSON-serializable value (a structure or a string), not a JSON document. - # - # The hook is checked before the dataclass / ``SimpleNamespace`` branches so - # a type may override the default structural encoding -- mirroring the read - # path, where :func:`coerce_to_type` consults ``from_json`` before its - # dataclass branch. This matters for dataclasses whose fields are not - # JSON-native (e.g. ``timedelta`` / ``datetime``), which ``asdict`` alone - # cannot serialize. 
to_json_hook = getattr(cast(Any, type(o)), "to_json", None) if callable(to_json_hook): return to_json_hook(o) - if dataclasses.is_dataclass(o) and not isinstance(o, type): - # Shallow-convert to a dict whose *values are the original field objects* - # (unlike ``dataclasses.asdict``, which deep-recurses and would convert a - # nested dataclass via ``asdict`` -- bypassing that child's ``to_json`` - # hook). ``json.dumps`` then recurses into each value and re-enters this - # hook for any nested custom object, so nested ``to_json`` hooks fire at - # every depth (including inside lists/dicts). - return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)} - if isinstance(o, SimpleNamespace): - return vars(o) # This will raise a TypeError describing the unsupported type. raise TypeError(f"Object of type '{type(o).__name__}' is not JSON serializable") @@ -124,31 +105,20 @@ def _strip_legacy_marker(d: dict[str, Any]) -> dict[str, Any]: return d -def coerce_to_type(value: Any, expected_type: Any, converter: Any | None = None) -> Any: +def coerce_to_type(value: Any, expected_type: Any) -> Any: """Coerce an already-parsed JSON value to ``expected_type``. Handles ``None``/``Optional``/``Union`` and ``list`` type hints recursively, types exposing a ``from_json()`` classmethod, and dataclasses (including nested dataclass fields). The destination type is always caller-supplied and never derived from the payload, keeping deserialization secure. - - ``converter`` is the active :class:`~durabletask.serialization.DataConverter`. - A ``from_json`` hook may opt in to receiving it (by accepting a second - positional parameter) and delegate nested reconstruction back to the - converter, e.g. ``converter.coerce(child, ChildType)``. This keeps hooks free - of manual nested deserialization and routes children through the same policy. 
""" if expected_type is None or value is None: return value - # ``Any`` imposes no constraint -- and ``isinstance(x, Any)`` raises -- so - # short-circuit before any type inspection below. - if expected_type is typing.Any: - return value - origin = typing.get_origin(expected_type) if origin is not None: - return _coerce_generic(value, expected_type, origin, converter) + return _coerce_generic(value, expected_type, origin) if not isinstance(expected_type, type): # Not a concrete, instantiable type (e.g. a typing special form we don't @@ -160,10 +130,10 @@ def coerce_to_type(value: Any, expected_type: Any, converter: Any | None = None) from_json_hook = getattr(expected_type, "from_json", None) if callable(from_json_hook): - return _invoke_from_json(from_json_hook, value, converter) + return from_json_hook(value) if dataclasses.is_dataclass(expected_type) and isinstance(value, dict): - return _build_dataclass(expected_type, cast(dict[str, Any], value), converter) + return _build_dataclass(expected_type, cast(dict[str, Any], value)) type_ctor = cast(Any, expected_type) try: @@ -176,52 +146,12 @@ def coerce_to_type(value: Any, expected_type: Any, converter: Any | None = None) ) from e -def _invoke_from_json(from_json_hook: Any, value: Any, converter: Any | None) -> Any: - """Invoke a ``from_json`` hook, passing the converter if the hook accepts it. - - Hooks may be declared as ``from_json(cls, value)`` (the original contract) or - ``from_json(cls, value, converter)`` to opt into managed nested - reconstruction. Arity is detected from the bound hook's signature, so the - extra parameter is fully backwards compatible. When a converter-aware hook is - found but no converter was threaded (e.g. a direct ``json_codec`` call), the - shared default converter is resolved lazily so the hook always receives one. 
- """ - wants_converter = False - try: - params = [ - p for p in inspect.signature(from_json_hook).parameters.values() - if p.kind in (inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD) - ] - wants_converter = len(params) >= 2 - except (TypeError, ValueError): - wants_converter = False - - if wants_converter: - if converter is None: - converter = _default_converter() - return from_json_hook(value, converter) - return from_json_hook(value) - - -def _default_converter() -> Any: - # Lazy import to avoid a load-time cycle: ``serialization`` imports this - # module at import time, but by the time a hook actually runs both modules - # are fully initialized. - from durabletask.serialization import DEFAULT_DATA_CONVERTER - return DEFAULT_DATA_CONVERTER - - -def _coerce_generic(value: Any, expected_type: Any, origin: Any, converter: Any | None = None) -> Any: +def _coerce_generic(value: Any, expected_type: Any, origin: Any) -> Any: args = typing.get_args(expected_type) if origin is typing.Union or origin is types.UnionType: # If the value already matches a member type, keep it as-is. non_none = [a for a in args if a is not type(None)] for arg in non_none: - # An ``Any`` member imposes no constraint (and ``isinstance(x, Any)`` - # raises), so the value is acceptable as-is. - if arg is typing.Any: - return value if isinstance(arg, type) and isinstance(value, arg): return value # ``Optional[T]`` (exactly one non-None member): coerce to that member. @@ -229,16 +159,16 @@ def _coerce_generic(value: Any, expected_type: Any, origin: Any, converter: Any # the members, leave it untouched rather than guessing the first arg -- # forcing a coercion there can silently mis-construct the wrong type. 
if len(non_none) == 1: - return coerce_to_type(value, non_none[0], converter) + return coerce_to_type(value, non_none[0]) return value if origin in (list, Sequence) and isinstance(value, list): elem_type = args[0] if args else None - return [coerce_to_type(item, elem_type, converter) for item in cast(list[Any], value)] + return [coerce_to_type(item, elem_type) for item in cast(list[Any], value)] # Other generics (dict, tuple, ...) are returned as parsed JSON. return value -def _build_dataclass(cls: Any, data: dict[str, Any], converter: Any | None = None) -> Any: +def _build_dataclass(cls: Any, data: dict[str, Any]) -> Any: """Construct a dataclass from its dict payload, recursing into typed fields.""" try: hints = typing.get_type_hints(cls) @@ -249,5 +179,5 @@ def _build_dataclass(cls: Any, data: dict[str, Any], converter: Any | None = Non if field.name not in data: continue field_type = hints.get(field.name) - kwargs[field.name] = coerce_to_type(data[field.name], field_type, converter) + kwargs[field.name] = coerce_to_type(data[field.name], field_type) return cls(**kwargs) From 10657aa0760f5de68468145a5354dc9ec45f65c5 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Jun 2026 10:11:41 -0600 Subject: [PATCH 6/9] Revert serialization test --- tests/durabletask/test_serialization.py | 34 ------------------------- 1 file changed, 34 deletions(-) diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index e3b4a518..ae797c2f 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -103,40 +103,6 @@ def test_instance_to_json_hook_receives_instance(): assert json.loads(to_json(Widget("gear", 5))) == {"label": "gear", "size": 5} -@dataclass -class HookedDataclass: - """Dataclass that overrides serialization via to_json/from_json. - - Its ``value`` field is not JSON-native on the wire (it is encoded as a - string), so plain ``dataclasses.asdict`` would not round-trip. 
The hooks - take precedence over the default dataclass encoding. - """ - - value: int - - def to_json(self) -> dict: - return {"value": str(self.value)} - - @classmethod - def from_json(cls, data: dict) -> "HookedDataclass": - return cls(int(data["value"])) - - def __eq__(self, other: object) -> bool: - return isinstance(other, HookedDataclass) and other.value == self.value - - -def test_dataclass_to_json_hook_takes_precedence_over_asdict(): - # A dataclass exposing to_json should serialize via the hook, not asdict. - encoded = json_codec.to_json(HookedDataclass(7)) - assert json.loads(encoded) == {"value": "7"} - - -def test_dataclass_hook_round_trips_with_expected_type(): - encoded = json_codec.to_json(HookedDataclass(7)) - result = json_codec.from_json(encoded, HookedDataclass) - assert result == HookedDataclass(7) - - # ----- to_json ----- From e8186e72508e988c7d407fe5da1e7fe402a7bc17 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Jun 2026 10:43:05 -0600 Subject: [PATCH 7/9] Update scheduled tasks with new serialization fixes --- durabletask/internal/client_helpers.py | 4 +- durabletask/internal/helpers.py | 6 +- durabletask/scheduled/models.py | 16 ++-- durabletask/scheduled/orchestrator.py | 8 +- durabletask/serialization.py | 8 ++ tests/durabletask/scheduled/test_models.py | 14 ++-- .../scheduled/test_schedule_entity.py | 83 +++++++++++-------- tests/durabletask/test_serialization.py | 29 ++++++- 8 files changed, 112 insertions(+), 56 deletions(-) diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py index a42c3c43..fe0c828b 100644 --- a/durabletask/internal/client_helpers.py +++ b/durabletask/internal/client_helpers.py @@ -205,14 +205,16 @@ def build_signal_entity_req( entity_instance_id: EntityInstanceId, operation_name: str, input: Any | None = None, + signal_time: datetime | None = None, data_converter: DataConverter | None = None) -> pb.SignalEntityRequest: """Build a SignalEntityRequest for signaling an 
entity.""" + scheduled_time = helpers.new_timestamp(signal_time) if signal_time is not None else None return pb.SignalEntityRequest( instanceId=str(entity_instance_id), name=operation_name, input=helpers.get_string_value(_serialize(input, data_converter)), requestId=str(uuid.uuid4()), - scheduledTime=None, + scheduledTime=scheduled_time, parentTraceContext=None, requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) ) diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 2342afdd..cb0a815b 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -255,11 +255,13 @@ def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: str | None, - request_id: str) -> pb.OrchestratorAction: + request_id: str, + signal_time: datetime | None = None) -> pb.OrchestratorAction: + scheduled_time = new_timestamp(signal_time) if signal_time is not None else None return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( requestId=request_id, operation=operation, - scheduledTime=None, + scheduledTime=scheduled_time, input=get_string_value(encoded_input), targetInstanceId=get_string_value(str(entity_id)), ))) diff --git a/durabletask/scheduled/models.py b/durabletask/scheduled/models.py index 5aa723ce..ab9a4dbe 100644 --- a/durabletask/scheduled/models.py +++ b/durabletask/scheduled/models.py @@ -277,11 +277,12 @@ def to_json(self) -> dict[str, Any]: } @classmethod - def from_json(cls, data: dict[str, Any], converter: Any) -> "ScheduleState": - # The nested configuration is reconstructed through the converter, which - # routes it to ``ScheduleConfiguration``'s own ``from_json`` hook (and - # honors a custom converter). Only this type's datetime leaves are - # rebuilt by hand. 
+ def from_json(cls, data: dict[str, Any]) -> "ScheduleState": + # The nested configuration is reconstructed by calling its own + # ``from_json`` hook directly. ``ScheduleConfiguration`` is an internal + # type, so there is no need to route it through a (possibly custom) + # converter -- keeping this hook converter-free means it round-trips + # under any code path, not only the worker's threaded converter. state = cls() state.status = ScheduleStatus(data["status"]) state.execution_token = data["execution_token"] @@ -289,8 +290,9 @@ def from_json(cls, data: dict[str, Any], converter: Any) -> "ScheduleState": state.next_run_at = _from_iso(data.get("next_run_at")) state.schedule_created_at = _from_iso(data.get("schedule_created_at")) state.schedule_last_modified_at = _from_iso(data.get("schedule_last_modified_at")) - state.schedule_configuration = converter.coerce( - data.get("schedule_configuration"), ScheduleConfiguration) + config_data = data.get("schedule_configuration") + state.schedule_configuration = ( + ScheduleConfiguration.from_json(config_data) if config_data is not None else None) return state def to_description(self) -> ScheduleDescription: diff --git a/durabletask/scheduled/orchestrator.py b/durabletask/scheduled/orchestrator.py index ab94d60a..0493b744 100644 --- a/durabletask/scheduled/orchestrator.py +++ b/durabletask/scheduled/orchestrator.py @@ -13,10 +13,10 @@ class ScheduleOperationRequest: """Request describing an operation to execute against a schedule entity. - A plain dataclass: the serializer round-trips it (and its ``input`` payload) - automatically. ``input`` stays an ``Any`` here -- it is reconstructed into the - concrete options type at the entity-method boundary from that method's - parameter annotation. + A plain dataclass: the serializer round-trips it automatically. 
``input`` is + typed ``Any``, so it is reconstructed as the raw deserialized payload; the + concrete options type is rebuilt later, at the entity-method boundary, from + that method's parameter annotation. """ entity_id: str diff --git a/durabletask/serialization.py b/durabletask/serialization.py index 5d1dad89..e6c68935 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -360,6 +360,9 @@ def _coerce_to_type(value: Any, expected_type: Any, converter: DataConverter | N if expected_type is None or value is None: return value + if expected_type is typing.Any: + return value + origin = typing.get_origin(expected_type) if origin is not None: return _coerce_generic(value, expected_type, origin, converter) @@ -447,6 +450,11 @@ def _coerce_generic(value: Any, expected_type: Any, origin: Any, # If the value already matches a member type, keep it as-is. non_none = [a for a in args if a is not type(None)] for arg in non_none: + # ``Any`` imposes no constraint, so the value already satisfies the + # union. (Checked explicitly because ``isinstance(value, Any)`` would + # raise -- ``typing.Any`` is a class on 3.11+ but not isinstance-able.) + if arg is typing.Any: + return value if isinstance(arg, type) and isinstance(value, arg): return value # ``Optional[T]`` (exactly one non-None member): coerce to that member. 
diff --git a/tests/durabletask/scheduled/test_models.py b/tests/durabletask/scheduled/test_models.py index 49b01202..ddd7c560 100644 --- a/tests/durabletask/scheduled/test_models.py +++ b/tests/durabletask/scheduled/test_models.py @@ -7,12 +7,14 @@ import pytest -from durabletask.internal import shared +from durabletask.serialization import JsonDataConverter from durabletask.scheduled.models import (ScheduleConfiguration, ScheduleCreationOptions, ScheduleState, ScheduleUpdateOptions) from durabletask.scheduled.schedule_status import ScheduleStatus +converter = JsonDataConverter() + class TestCreationOptionsValidation: def test_requires_schedule_id(self): @@ -51,8 +53,8 @@ def test_round_trip_through_json(self): orchestration_input={"key": "value"}, orchestration_instance_id="inst-1", start_at=start, end_at=end, start_immediately_if_late=True) - encoded = shared.to_json(options) - decoded = shared.from_json(encoded, ScheduleCreationOptions) + encoded = converter.serialize(options) + decoded = converter.deserialize(encoded, ScheduleCreationOptions) assert decoded.schedule_id == "s1" assert decoded.orchestration_name == "orch" @@ -71,7 +73,7 @@ def test_interval_validation(self): def test_round_trip_through_json(self): options = ScheduleUpdateOptions(orchestration_name="orch2", interval=timedelta(seconds=10)) - decoded = shared.from_json(shared.to_json(options), ScheduleUpdateOptions) + decoded = converter.deserialize(converter.serialize(options), ScheduleUpdateOptions) assert decoded.orchestration_name == "orch2" assert decoded.interval == timedelta(seconds=10) assert decoded.start_at is None @@ -108,7 +110,7 @@ def test_config_round_trip(self): ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", interval=timedelta(seconds=5), start_at=datetime(2026, 1, 1, tzinfo=timezone.utc))) - restored = shared.from_json(shared.to_json(config), ScheduleConfiguration) + restored = converter.deserialize(converter.serialize(config), ScheduleConfiguration) assert 
restored.schedule_id == "s1" assert restored.interval == timedelta(seconds=5) assert restored.start_at == datetime(2026, 1, 1, tzinfo=timezone.utc) @@ -124,7 +126,7 @@ def test_round_trip_and_description(self): interval=timedelta(seconds=5))) # The nested ``ScheduleConfiguration`` round-trips automatically. - restored = shared.from_json(shared.to_json(state), ScheduleState) + restored = converter.deserialize(converter.serialize(state), ScheduleState) assert restored.status == ScheduleStatus.ACTIVE assert restored.schedule_created_at == datetime(2026, 1, 1, tzinfo=timezone.utc) assert restored.schedule_configuration is not None diff --git a/tests/durabletask/scheduled/test_schedule_entity.py b/tests/durabletask/scheduled/test_schedule_entity.py index 7cbc6397..91733579 100644 --- a/tests/durabletask/scheduled/test_schedule_entity.py +++ b/tests/durabletask/scheduled/test_schedule_entity.py @@ -5,66 +5,76 @@ import logging from datetime import datetime, timedelta, timezone +from typing import Any import pytest +import durabletask.internal.orchestrator_service_pb2 as pb from durabletask.entities import EntityInstanceId -from durabletask.internal import shared from durabletask.internal.entity_state_shim import StateShim from durabletask.scheduled.exceptions import ScheduleInvalidTransitionError -from durabletask.scheduled.models import (ScheduleCreationOptions, +from durabletask.scheduled.models import (ScheduleCreationOptions, ScheduleState, ScheduleUpdateOptions) from durabletask.scheduled.schedule_entity import (ENTITY_NAME, Schedule) from durabletask.scheduled.schedule_status import ScheduleStatus +from durabletask.serialization import JsonDataConverter from durabletask.worker import _EntityExecutor, _Registry SCHEDULE_ID = "sched-1" class Harness: - """Drives Schedule entity operations against a persistent in-memory state.""" + """Drives Schedule entity operations against a persistent in-memory state. 
+ + Mirrors the worker's entity-batch lifecycle: the ``StateShim`` holds the + serialized state between operations, so each ``run`` deserializes on read and + serializes on write through the data converter -- exactly the wire round-trip + a real entity experiences across batches. + """ def __init__(self): registry = _Registry() registry.add_entity(Schedule, ENTITY_NAME) - self.executor = _EntityExecutor(registry, logging.getLogger("test")) - self.state = StateShim(None) + self.converter = JsonDataConverter() + self.executor = _EntityExecutor(registry, logging.getLogger("test"), self.converter) + self.shim = StateShim(None, self.converter) self.entity_id = EntityInstanceId(ENTITY_NAME, SCHEDULE_ID) - def run(self, operation, input=None): - before = len(self.state.get_operation_actions()) - encoded = shared.to_json(input) if input is not None else None - result = self.executor.execute("orch-1", self.entity_id, operation, self.state, encoded) - self.state.commit() - # Mimic the wire round-trip: the worker serializes the entity state at - # the end of each batch, and the next batch receives it as deserialized - # JSON (a plain dict). This exercises the state ``to_json``/``from_json`` - # hooks between operations and keeps assertions dict-based. 
- current = self.state._current_state # pyright: ignore[reportPrivateUsage] - if current is not None: - self.state._current_state = shared.from_json(shared.to_json(current)) # pyright: ignore[reportPrivateUsage] - actions = self.state.get_operation_actions()[before:] + def run(self, operation: str, input: Any = None) -> tuple[str | None, list[pb.OperationAction]]: + before = len(self.shim.get_operation_actions()) + encoded = self.converter.serialize(input) if input is not None else None + result = self.executor.execute("orch-1", self.entity_id, operation, self.shim, encoded) + self.shim.commit() + actions = self.shim.get_operation_actions()[before:] return result, actions + def state(self) -> ScheduleState | None: + """Reconstruct the typed state object, the way the entity itself would.""" + return self.shim.get_state(ScheduleState) + @property - def state_dict(self): - return self.state._current_state # pyright: ignore[reportPrivateUsage] + def current(self) -> ScheduleState: + """Like :meth:`state` but asserts the state exists (most operations).""" + state = self.state() + assert state is not None + return state @property - def token(self): - return self.state_dict["execution_token"] + def token(self) -> str: + return self.current.execution_token -def _signal_actions(actions): +def _signal_actions(actions: list[pb.OperationAction]) -> list[pb.OperationAction]: return [a for a in actions if a.HasField("sendSignal")] -def _start_actions(actions): +def _start_actions(actions: list[pb.OperationAction]) -> list[pb.OperationAction]: return [a for a in actions if a.HasField("startNewOrchestration")] -def _creation_options(**kwargs): - base = dict(schedule_id=SCHEDULE_ID, orchestration_name="my_orch", interval=timedelta(seconds=30)) +def _creation_options(**kwargs: Any) -> ScheduleCreationOptions: + base: dict[str, Any] = dict( + schedule_id=SCHEDULE_ID, orchestration_name="my_orch", interval=timedelta(seconds=30)) base.update(kwargs) return 
ScheduleCreationOptions(**base) @@ -74,8 +84,9 @@ def test_create_activates_and_signals_run(self): h = Harness() _, actions = h.run("create_schedule", _creation_options()) - assert h.state_dict["status"] == ScheduleStatus.ACTIVE.value - assert h.state_dict["schedule_created_at"] is not None + state = h.current + assert state.status == ScheduleStatus.ACTIVE + assert state.schedule_created_at is not None signals = _signal_actions(actions) assert len(signals) == 1 assert signals[0].sendSignal.name == "run_schedule" @@ -88,7 +99,7 @@ def test_create_twice_updates_in_place(self): h.run("create_schedule", _creation_options(interval=timedelta(seconds=60))) # Re-creation refreshes the execution token. assert h.token != first_token - assert h.state_dict["status"] == ScheduleStatus.ACTIVE.value + assert h.current.status == ScheduleStatus.ACTIVE class TestPauseResume: @@ -97,11 +108,11 @@ def test_pause_then_resume(self): h.run("create_schedule", _creation_options()) h.run("pause_schedule") - assert h.state_dict["status"] == ScheduleStatus.PAUSED.value - assert h.state_dict["next_run_at"] is None + assert h.current.status == ScheduleStatus.PAUSED + assert h.current.next_run_at is None _, actions = h.run("resume_schedule") - assert h.state_dict["status"] == ScheduleStatus.ACTIVE.value + assert h.current.status == ScheduleStatus.ACTIVE assert len(_signal_actions(actions)) == 1 def test_pause_when_not_active_raises(self): @@ -118,7 +129,9 @@ def test_update_changes_config_and_resignals(self): h.run("create_schedule", _creation_options()) _, actions = h.run("update_schedule", ScheduleUpdateOptions(interval=timedelta(seconds=120))) - assert abs(h.state_dict["schedule_configuration"]["interval_seconds"] - 120) < 0.001 + config = h.current.schedule_configuration + assert config is not None + assert config.interval == timedelta(seconds=120) assert len(_signal_actions(actions)) == 1 def test_update_no_change_does_not_signal(self): @@ -140,7 +153,7 @@ def 
test_runs_orchestration_when_due_and_rearms(self): starts = _start_actions(actions) assert len(starts) == 1 assert starts[0].startNewOrchestration.name == "my_orch" - assert h.state_dict["last_run_at"] is not None + assert h.current.last_run_at is not None # Re-arm signal should carry a future scheduled time. signals = _signal_actions(actions) @@ -180,4 +193,4 @@ def test_delete_clears_state(self): h = Harness() h.run("create_schedule", _creation_options()) h.run("delete") - assert h.state_dict is None + assert h.state() is None diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py index ae797c2f..02861dcd 100644 --- a/tests/durabletask/test_serialization.py +++ b/tests/durabletask/test_serialization.py @@ -12,7 +12,7 @@ from collections import namedtuple from dataclasses import dataclass from types import SimpleNamespace -from typing import List, Optional, Union, get_args +from typing import Any, List, Optional, Union, get_args import pytest @@ -422,6 +422,33 @@ def test_coerce_dict_values_recursively(): assert isinstance(result["home"], Address) +def test_coerce_bare_any_returns_value_unchanged(): + # ``Any`` carries no type info; the parsed value is already the result. + value = {"k": "v"} + assert coerce_to_type(value, Any) is value + + +def test_coerce_optional_any_returns_value_unchanged(): + # ``Any | None`` must pass the value through rather than raising on the + # ``isinstance(value, Any)`` check inside the union loop. + value = {"k": "v"} + assert coerce_to_type(value, Optional[Any]) == {"k": "v"} + + +def test_coerce_dataclass_with_any_field_round_trips(): + # A plain dataclass whose field is annotated ``Any | None`` must reconstruct + # with that field left as the raw parsed JSON, not crash. 
+ @dataclass + class Envelope: + name: str + payload: Any | None = None + + restored = from_json(to_json(Envelope("x", {"nested": [1, 2]})), Envelope) + assert isinstance(restored, Envelope) + assert restored.name == "x" + assert restored.payload == {"nested": [1, 2]} + + # ----- from_json converter hook (PR #154 follow-up) ----- From cd53b73b2f3ab3f182f96ca633ed2569a5c1b075 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Jun 2026 10:59:11 -0600 Subject: [PATCH 8/9] PR Feedback --- durabletask/scheduled/client.py | 19 +++--- durabletask/scheduled/registration.py | 3 +- durabletask/testing/in_memory_backend.py | 14 +++- .../test_delayed_signal_normalization.py | 65 +++++++++++++++++++ 4 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 tests/durabletask/test_delayed_signal_normalization.py diff --git a/durabletask/scheduled/client.py b/durabletask/scheduled/client.py index 4e327b68..868e5564 100644 --- a/durabletask/scheduled/client.py +++ b/durabletask/scheduled/client.py @@ -108,10 +108,11 @@ def get_schedule(self, schedule_id: str) -> ScheduleDescription | None: except ScheduleNotFoundError: return None - def list_schedules(self, filter: ScheduleQuery | None = None) -> list[ScheduleDescription]: + def list_schedules(self, schedule_query: ScheduleQuery | None = None) -> list[ScheduleDescription]: """List schedules matching the given filter criteria.""" - prefix = filter.schedule_id_prefix if filter and filter.schedule_id_prefix else "" - page_size = filter.page_size if filter and filter.page_size else ScheduleQuery.DEFAULT_PAGE_SIZE + prefix = schedule_query.schedule_id_prefix if schedule_query and schedule_query.schedule_id_prefix else "" + page_size = (schedule_query.page_size if schedule_query and schedule_query.page_size + else ScheduleQuery.DEFAULT_PAGE_SIZE) query = EntityQuery( instance_id_starts_with=f"@{ENTITY_NAME}@{prefix}", include_state=True, @@ -122,20 +123,20 @@ def list_schedules(self, filter: ScheduleQuery | None = 
None) -> list[ScheduleDe state = metadata.get_typed_state(ScheduleState) if state is None or state.schedule_configuration is None: continue - if not self._matches_filter(state, filter): + if not self._matches_filter(state, schedule_query): continue results.append(state.to_description()) return results @staticmethod - def _matches_filter(state: ScheduleState, filter: ScheduleQuery | None) -> bool: - if filter is None: + def _matches_filter(state: ScheduleState, schedule_query: ScheduleQuery | None) -> bool: + if schedule_query is None: return True - if filter.status is not None and state.status != filter.status: + if schedule_query.status is not None and state.status != schedule_query.status: return False created_at = state.schedule_created_at - if filter.created_from is not None and not (created_at and created_at > filter.created_from): + if schedule_query.created_from is not None and not (created_at and created_at > schedule_query.created_from): return False - if filter.created_to is not None and not (created_at and created_at < filter.created_to): + if schedule_query.created_to is not None and not (created_at and created_at < schedule_query.created_to): return False return True diff --git a/durabletask/scheduled/registration.py b/durabletask/scheduled/registration.py index 858bca76..4c98a611 100644 --- a/durabletask/scheduled/registration.py +++ b/durabletask/scheduled/registration.py @@ -2,8 +2,7 @@ # Licensed under the MIT License. 
from durabletask.worker import TaskHubGrpcWorker -from durabletask.scheduled.orchestrator import \ - execute_schedule_operation_orchestrator +from durabletask.scheduled.orchestrator import execute_schedule_operation_orchestrator from durabletask.scheduled.schedule_entity import ENTITY_NAME, Schedule diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index d112ef25..3f45bc5f 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -153,6 +153,13 @@ def __init__(self, max_history_size: int = 10000, port: int = 50051): self._logger = logging.getLogger(__name__) self._shutdown_event = threading.Event() self._work_available = threading.Event() + # Monotonic lifecycle counter, bumped on every stop()/reset(). Background + # timers (e.g. delayed entity signals) capture it when scheduled and bail + # if it has changed by the time they fire, so a timer created before a + # stop/reset cannot mutate a subsequently-restarted or cleared backend. + # Unlike ``_shutdown_event`` (which reset() clears to allow restart), this + # only ever moves forward, so it reliably invalidates stale timers. + self._generation = 0 def start(self) -> str: """ @@ -183,6 +190,7 @@ def stop(self, grace: float | None = None): """ self._shutdown_event.set() self._work_available.set() # Unblock GetWorkItems loops + self._generation += 1 if self._server: stop_future = self._server.stop(grace) stop_future.wait() @@ -208,6 +216,7 @@ def reset(self): self._state_waiters.clear() self._shutdown_event.clear() self._work_available.clear() + self._generation += 1 # gRPC Service Methods @@ -1799,11 +1808,14 @@ def _schedule_delayed_entity_operation(self, entity_id: str, event: pb.HistoryEv in-memory backend. """ delay = max(0.0, (fire_at - datetime.now(timezone.utc)).total_seconds()) + # Capture the lifecycle generation so a timer outliving a stop()/reset() + # does not enqueue into a restarted or cleared backend. 
+ scheduled_generation = self._generation def fire() -> None: time.sleep(delay) with self._lock: - if self._shutdown_event.is_set(): + if self._shutdown_event.is_set() or self._generation != scheduled_generation: return self._queue_entity_operation(entity_id, event) diff --git a/tests/durabletask/test_delayed_signal_normalization.py b/tests/durabletask/test_delayed_signal_normalization.py new file mode 100644 index 00000000..c2d1cbfc --- /dev/null +++ b/tests/durabletask/test_delayed_signal_normalization.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Unit tests for delayed-signal datetime normalization and in-memory backend +timer lifecycle (regression coverage for PR #160 review feedback).""" + +import time +from datetime import datetime, timedelta, timezone + +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.entities import EntityInstanceId +from durabletask.internal import helpers +from durabletask.internal.client_helpers import build_signal_entity_req +from durabletask.testing import create_test_backend + +from tests.durabletask._port_utils import find_free_port + + +class TestSignalTimeNormalization: + def test_build_signal_entity_req_naive_matches_aware(self): + entity_id = EntityInstanceId("Recorder", "k") + naive = datetime(2030, 1, 1, 12, 0, 0) + aware = datetime(2030, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + req_naive = build_signal_entity_req(entity_id, "ping", signal_time=naive) + req_aware = build_signal_entity_req(entity_id, "ping", signal_time=aware) + + assert req_naive.scheduledTime.seconds == req_aware.scheduledTime.seconds + + def test_new_signal_entity_action_naive_matches_aware(self): + entity_id = EntityInstanceId("Recorder", "k") + naive = datetime(2030, 1, 1, 12, 0, 0) + aware = datetime(2030, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + a_naive = helpers.new_signal_entity_action(1, entity_id, "ping", None, "r1", naive) + a_aware = 
helpers.new_signal_entity_action(1, entity_id, "ping", None, "r2", aware) + + naive_ts = a_naive.sendEntityMessage.entityOperationSignaled.scheduledTime + aware_ts = a_aware.sendEntityMessage.entityOperationSignaled.scheduledTime + assert naive_ts.seconds == aware_ts.seconds + + +class TestDelayedTimerLifecycle: + def test_delayed_operation_does_not_fire_into_reset_backend(self): + backend = create_test_backend(port=find_free_port()) + try: + entity_id = "@recorder@k" + event = pb.HistoryEvent( + eventId=-1, + entityOperationSignaled=pb.EntityOperationSignaledEvent(operation="ping"), + ) + # Schedule the op to fire shortly in the future, then reset before it + # fires so the timer wakes into a reset (new-generation) backend. + fire_at = datetime.now(timezone.utc) + timedelta(seconds=0.3) + backend._schedule_delayed_entity_operation( # pyright: ignore[reportPrivateUsage] + entity_id, event, fire_at) + backend.reset() + + # Give the timer thread time to wake; it must observe the new + # generation and refuse to recreate state in the reset backend. + time.sleep(0.6) + assert backend._entities == {} # pyright: ignore[reportPrivateUsage] + finally: + backend.stop() + backend.reset() From 9c84751f3dde88f6751224081011c161151af14e Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Mon, 29 Jun 2026 13:10:24 -0600 Subject: [PATCH 9/9] Address PR feedback: capability flag, datetime filter fix, helper consolidation - Advertise WORKER_CAPABILITY_SCHEDULED_TASKS from configure_scheduled_tasks via a new TaskHubGrpcWorker.add_capability(), mirroring .NET's UseScheduledTasks (currently inert against the DTS backend, which only gates on HistoryStreaming). - Fix naive-vs-aware datetime TypeError in schedule list filtering: ScheduleQuery normalizes its created_from/created_to bounds to aware UTC, and the client filter defensively normalizes the stored timestamp. Bounds remain exclusive to match the .NET ScheduledTasks implementation. 
- Consolidate the duplicated _ensure_aware helpers into a single public helpers.ensure_aware, removing the cross-module private-usage warning. - Document that list_schedules applies status/created filters client-side, so pages may be underfilled (matches .NET). - Add unit tests for query normalization, exclusive filter bounds, and the scheduled-tasks capability advertisement. --- durabletask/internal/helpers.py | 17 +++- durabletask/scheduled/client.py | 16 +++- durabletask/scheduled/models.py | 9 ++ durabletask/scheduled/registration.py | 2 + durabletask/scheduled/schedule_entity.py | 15 +--- durabletask/worker.py | 20 +++++ .../test_client_filter_and_capability.py | 84 +++++++++++++++++++ tests/durabletask/scheduled/test_models.py | 25 +++++- 8 files changed, 173 insertions(+), 15 deletions(-) create mode 100644 tests/durabletask/scheduled/test_client_filter_and_capability.py diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index cb0a815b..dd8940c5 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import traceback -from datetime import datetime +from datetime import datetime, timezone from google.protobuf import timestamp_pb2, wrappers_pb2 @@ -302,6 +302,21 @@ def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: return ts +def ensure_aware(value: datetime | None) -> datetime | None: + """Return ``value`` as a timezone-aware datetime, assuming UTC when naive. + + A naive datetime is tagged as UTC; an already-aware datetime is returned + unchanged. Useful before comparing user-supplied datetimes against the + SDK's always-aware-UTC timestamps to avoid "can't compare offset-naive and + offset-aware datetimes". 
+ """ + if value is None: + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + def new_create_sub_orchestration_action( id: int, name: str, diff --git a/durabletask/scheduled/client.py b/durabletask/scheduled/client.py index 868e5564..dc6767ac 100644 --- a/durabletask/scheduled/client.py +++ b/durabletask/scheduled/client.py @@ -6,6 +6,7 @@ from durabletask.client import (EntityQuery, OrchestrationStatus, TaskHubGrpcClient) from durabletask.entities import EntityInstanceId +from durabletask.internal.helpers import ensure_aware from durabletask.scheduled import transitions from durabletask.scheduled.exceptions import ScheduleNotFoundError from durabletask.scheduled.models import (ScheduleCreationOptions, @@ -109,7 +110,14 @@ def get_schedule(self, schedule_id: str) -> ScheduleDescription | None: return None def list_schedules(self, schedule_query: ScheduleQuery | None = None) -> list[ScheduleDescription]: - """List schedules matching the given filter criteria.""" + """List schedules matching the given filter criteria. + + > [!NOTE] + > The ``status`` and ``created_from``/``created_to`` filters are applied + > client-side after each page of entities is fetched, so an individual + > page may contain fewer than ``page_size`` matches (or none) even when + > more matching schedules exist. This mirrors the .NET implementation. 
+ """ prefix = schedule_query.schedule_id_prefix if schedule_query and schedule_query.schedule_id_prefix else "" page_size = (schedule_query.page_size if schedule_query and schedule_query.page_size else ScheduleQuery.DEFAULT_PAGE_SIZE) @@ -134,7 +142,11 @@ def _matches_filter(state: ScheduleState, schedule_query: ScheduleQuery | None) return True if schedule_query.status is not None and state.status != schedule_query.status: return False - created_at = state.schedule_created_at + # ``ScheduleQuery`` normalizes its bounds to aware UTC; defensively + # normalize the stored timestamp too (a payload could in principle carry + # a naive value) so the comparison can never raise on naive-vs-aware. + # Bounds are exclusive, matching the .NET ScheduledTasks implementation. + created_at = ensure_aware(state.schedule_created_at) if schedule_query.created_from is not None and not (created_at and created_at > schedule_query.created_from): return False if schedule_query.created_to is not None and not (created_at and created_at < schedule_query.created_to): diff --git a/durabletask/scheduled/models.py b/durabletask/scheduled/models.py index ab9a4dbe..60f59194 100644 --- a/durabletask/scheduled/models.py +++ b/durabletask/scheduled/models.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from typing import Any +from durabletask.internal.helpers import ensure_aware from durabletask.scheduled.schedule_status import ScheduleStatus MINIMUM_INTERVAL = timedelta(seconds=1) @@ -133,6 +134,14 @@ class ScheduleQuery: created_to: datetime | None = None page_size: int | None = None + def __post_init__(self): + # Coerce the time-window bounds to timezone-aware UTC. Schedule + # timestamps are always stored as aware UTC, so normalizing here ensures + # a naive bound supplied by a caller can never reach the filter + # comparison and raise "can't compare offset-naive and offset-aware". 
+ self.created_from = ensure_aware(self.created_from) + self.created_to = ensure_aware(self.created_to) + @dataclass class ScheduleDescription: diff --git a/durabletask/scheduled/registration.py b/durabletask/scheduled/registration.py index 4c98a611..2a835094 100644 --- a/durabletask/scheduled/registration.py +++ b/durabletask/scheduled/registration.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import durabletask.internal.orchestrator_service_pb2 as pb from durabletask.worker import TaskHubGrpcWorker from durabletask.scheduled.orchestrator import execute_schedule_operation_orchestrator from durabletask.scheduled.schedule_entity import ENTITY_NAME, Schedule @@ -18,3 +19,4 @@ def configure_scheduled_tasks(worker: TaskHubGrpcWorker) -> None: """ worker.add_entity(Schedule, ENTITY_NAME) worker.add_orchestrator(execute_schedule_operation_orchestrator) + worker.add_capability(pb.WORKER_CAPABILITY_SCHEDULED_TASKS) diff --git a/durabletask/scheduled/schedule_entity.py b/durabletask/scheduled/schedule_entity.py index c72c9a28..90f385e7 100644 --- a/durabletask/scheduled/schedule_entity.py +++ b/durabletask/scheduled/schedule_entity.py @@ -6,6 +6,7 @@ from typing import Any from durabletask.entities import DurableEntity, EntityInstanceId +from durabletask.internal.helpers import ensure_aware from durabletask.scheduled import transitions from durabletask.scheduled.exceptions import ScheduleInvalidTransitionError from durabletask.scheduled.models import (ScheduleConfiguration, @@ -26,14 +27,6 @@ def _now() -> datetime: return datetime.now(timezone.utc) -def _ensure_aware(value: datetime | None) -> datetime | None: - if value is None: - return None - if value.tzinfo is None: - return value.replace(tzinfo=timezone.utc) - return value - - class Schedule(DurableEntity): """Entity that manages the state and execution of a scheduled task. 
@@ -184,7 +177,7 @@ def run_schedule(self, execution_token: str) -> None: if state.status != ScheduleStatus.ACTIVE: raise ValueError("Schedule must be in Active status to run.") - end_at = _ensure_aware(config.end_at) + end_at = ensure_aware(config.end_at) if end_at is not None and _now() > end_at: logger.info(f"Schedule '{config.schedule_id}' has passed its end time; deleting.") state.next_run_at = None @@ -225,10 +218,10 @@ def _start_orchestration(self, config: ScheduleConfiguration, scheduled_run_time def _determine_next_run_time(self, state: ScheduleState, config: ScheduleConfiguration) -> datetime: if state.next_run_at is not None: - return _ensure_aware(state.next_run_at) # type: ignore[return-value] + return ensure_aware(state.next_run_at) # type: ignore[return-value] now = _now() - start_time = _ensure_aware(config.start_at) or _ensure_aware(state.schedule_created_at) or now + start_time = ensure_aware(config.start_at) or ensure_aware(state.schedule_created_at) or now time_since_start = now - start_time # Next run is in the future relative to the start time. diff --git a/durabletask/worker.py b/durabletask/worker.py index f948b285..ea42411e 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -577,6 +577,11 @@ def __init__( self._work_item_filters: WorkItemFilters | None = None self._auto_generate_work_item_filters: bool = False self._runLoop: Thread | None = None + # Extra worker capabilities advertised to the backend in + # GetWorkItemsRequest (in addition to ones derived from worker state such + # as LARGE_PAYLOADS). Feature-enablement helpers like + # durabletask.scheduled.configure_scheduled_tasks register theirs here. 
+ self._capabilities: set[int] = set() @property def concurrency_options(self) -> ConcurrencyOptions: @@ -636,6 +641,20 @@ def add_entity(self, fn: task.Entity[Any, Any], name: str | None = None) -> str: ) return self._registry.add_entity(fn, name) + def add_capability(self, capability: int) -> None: + """Advertise a worker capability to the backend in ``GetWorkItemsRequest``. + + Most users do not call this directly; feature-enablement helpers such as + :func:`durabletask.scheduled.configure_scheduled_tasks` use it to + advertise the capabilities (``pb.WORKER_CAPABILITY_*``) their feature + relies on. + """ + if self._is_running: + raise RuntimeError( + "Capabilities cannot be added while the worker is running." + ) + self._capabilities.add(capability) + def use_versioning(self, version: VersioningOptions) -> None: """Initializes versioning options for sub-orchestrators and activities.""" if self._is_running: @@ -861,6 +880,7 @@ def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: capabilities: list[Any] = [] if self._payload_store is not None: capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS) + capabilities.extend(sorted(self._capabilities)) get_work_items_request = pb.GetWorkItemsRequest( maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items, diff --git a/tests/durabletask/scheduled/test_client_filter_and_capability.py b/tests/durabletask/scheduled/test_client_filter_and_capability.py new file mode 100644 index 00000000..133de4eb --- /dev/null +++ b/tests/durabletask/scheduled/test_client_filter_and_capability.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for ScheduleClient filtering and scheduled-tasks worker capability.""" + +from datetime import datetime, timedelta, timezone + +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.scheduled import configure_scheduled_tasks +from durabletask.scheduled.client import ScheduledTaskClient +from durabletask.scheduled.models import (ScheduleConfiguration, + ScheduleCreationOptions, ScheduleQuery, + ScheduleState) +from durabletask.scheduled.schedule_status import ScheduleStatus +from durabletask.worker import TaskHubGrpcWorker + + +def _state_created_at(created_at: datetime) -> ScheduleState: + state = ScheduleState() + state.status = ScheduleStatus.ACTIVE + state.schedule_created_at = created_at + state.schedule_configuration = ScheduleConfiguration.from_create_options( + ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch", + interval=timedelta(seconds=5))) + return state + + +class TestMatchesFilter: + def test_naive_bound_does_not_crash_against_aware_created_at(self): + state = _state_created_at(datetime(2026, 1, 15, tzinfo=timezone.utc)) + # A naive bound is normalized by ScheduleQuery, so the comparison must + # not raise "can't compare offset-naive and offset-aware". + q = ScheduleQuery(created_from=datetime(2026, 1, 1, 0, 0, 0)) + assert ScheduledTaskClient._matches_filter(state, q) is True + + def test_created_from_is_exclusive(self): + # Bounds match .NET's exclusive semantics: a schedule created exactly at + # the bound is excluded. 
+ boundary = datetime(2026, 1, 1, tzinfo=timezone.utc) + state = _state_created_at(boundary) + q = ScheduleQuery(created_from=boundary) + assert ScheduledTaskClient._matches_filter(state, q) is False + + def test_created_to_is_exclusive(self): + boundary = datetime(2026, 1, 1, tzinfo=timezone.utc) + state = _state_created_at(boundary) + q = ScheduleQuery(created_to=boundary) + assert ScheduledTaskClient._matches_filter(state, q) is False + + def test_inside_window_matches(self): + state = _state_created_at(datetime(2026, 1, 15, tzinfo=timezone.utc)) + q = ScheduleQuery( + created_from=datetime(2026, 1, 1, tzinfo=timezone.utc), + created_to=datetime(2026, 2, 1, tzinfo=timezone.utc), + ) + assert ScheduledTaskClient._matches_filter(state, q) is True + + def test_outside_window_is_excluded(self): + state = _state_created_at(datetime(2026, 3, 1, tzinfo=timezone.utc)) + q = ScheduleQuery(created_to=datetime(2026, 2, 1, tzinfo=timezone.utc)) + assert ScheduledTaskClient._matches_filter(state, q) is False + + def test_status_filter(self): + state = _state_created_at(datetime(2026, 1, 1, tzinfo=timezone.utc)) + assert ScheduledTaskClient._matches_filter(state, ScheduleQuery(status=ScheduleStatus.ACTIVE)) is True + assert ScheduledTaskClient._matches_filter(state, ScheduleQuery(status=ScheduleStatus.PAUSED)) is False + + +class TestScheduledTasksCapability: + def test_configure_advertises_scheduled_tasks_capability(self): + worker = TaskHubGrpcWorker() + configure_scheduled_tasks(worker) + assert pb.WORKER_CAPABILITY_SCHEDULED_TASKS in worker._capabilities # pyright: ignore[reportPrivateUsage] + + def test_capability_absent_by_default(self): + worker = TaskHubGrpcWorker() + assert pb.WORKER_CAPABILITY_SCHEDULED_TASKS not in worker._capabilities # pyright: ignore[reportPrivateUsage] + + def test_add_capability_rejected_while_running(self): + import pytest + worker = TaskHubGrpcWorker() + worker._is_running = True # pyright: ignore[reportPrivateUsage] + with 
pytest.raises(RuntimeError): + worker.add_capability(pb.WORKER_CAPABILITY_SCHEDULED_TASKS) diff --git a/tests/durabletask/scheduled/test_models.py b/tests/durabletask/scheduled/test_models.py index ddd7c560..7f1d09f0 100644 --- a/tests/durabletask/scheduled/test_models.py +++ b/tests/durabletask/scheduled/test_models.py @@ -9,7 +9,7 @@ from durabletask.serialization import JsonDataConverter from durabletask.scheduled.models import (ScheduleConfiguration, - ScheduleCreationOptions, + ScheduleCreationOptions, ScheduleQuery, ScheduleState, ScheduleUpdateOptions) from durabletask.scheduled.schedule_status import ScheduleStatus @@ -142,3 +142,26 @@ def test_refresh_execution_token_changes_token(self): original = state.execution_token state.refresh_execution_token() assert state.execution_token != original + + +class TestScheduleQueryNormalization: + def test_naive_bounds_are_coerced_to_aware_utc(self): + q = ScheduleQuery( + created_from=datetime(2026, 1, 1, 0, 0, 0), + created_to=datetime(2026, 2, 1, 0, 0, 0), + ) + assert q.created_from is not None and q.created_to is not None + assert q.created_from == datetime(2026, 1, 1, tzinfo=timezone.utc) + assert q.created_to == datetime(2026, 2, 1, tzinfo=timezone.utc) + assert q.created_from.tzinfo is timezone.utc + assert q.created_to.tzinfo is timezone.utc + + def test_aware_bounds_are_preserved(self): + start = datetime(2026, 1, 1, tzinfo=timezone.utc) + q = ScheduleQuery(created_from=start) + assert q.created_from == start + + def test_none_bounds_stay_none(self): + q = ScheduleQuery() + assert q.created_from is None + assert q.created_to is None