diff --git a/CHANGELOG.md b/CHANGELOG.md index a85930d8..6376c6c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,18 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ADDED +- Added `durabletask.scheduled`, a recurring schedule feature built on durable + entities. Use `configure_scheduled_tasks(worker)` to enable it on a worker, + then manage schedules from the client via `ScheduledTaskClient` (and the + per-schedule `ScheduleClient`). Supports creating, describing, listing, + updating, pausing, resuming, and deleting schedules with configurable + `interval`, `start_at`, `end_at`, and `start_immediately_if_late` options. +- Added an optional `signal_time` parameter to `EntityContext.signal_entity` + and `DurableEntity.signal_entity`, allowing an entity signal to be scheduled + for future delivery. +- Added an optional `signal_time` parameter to `OrchestrationContext.signal_entity` + and to the client `signal_entity` methods (sync and async), allowing entity + signals to be scheduled for future delivery from orchestrations and clients. - Added a pluggable `DataConverter` (`durabletask.serialization`) accepted by `TaskHubGrpcWorker`, `TaskHubGrpcClient`, and `AsyncTaskHubGrpcClient` via a `data_converter` argument. 
Every payload boundary (inputs, outputs, events, diff --git a/durabletask/client.py b/durabletask/client.py index 725bd3dc..c810248e 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -825,8 +825,10 @@ def purge_orchestrations_by(self, def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, - input: Any | None = None) -> None: - req = build_signal_entity_req(entity_instance_id, operation_name, input, self._data_converter) + input: Any | None = None, + signal_time: datetime | None = None) -> None: + req = build_signal_entity_req( + entity_instance_id, operation_name, input, signal_time, self._data_converter) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") if self._payload_store is not None: payload_helpers.externalize_payloads( @@ -1308,8 +1310,10 @@ async def purge_orchestrations_by(self, async def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, - input: Any | None = None) -> None: - req = build_signal_entity_req(entity_instance_id, operation_name, input, self._data_converter) + input: Any | None = None, + signal_time: datetime | None = None) -> None: + req = build_signal_entity_req( + entity_instance_id, operation_name, input, signal_time, self._data_converter) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") if self._payload_store is not None: await payload_helpers.externalize_payloads_async( diff --git a/durabletask/entities/durable_entity.py b/durabletask/entities/durable_entity.py index b93b21af..cfb85a8d 100644 --- a/durabletask/entities/durable_entity.py +++ b/durabletask/entities/durable_entity.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. 
from typing import Any, TypeVar, overload +from datetime import datetime from durabletask.entities.entity_context import EntityContext from durabletask.entities.entity_instance_id import EntityInstanceId @@ -52,7 +53,9 @@ def set_state(self, state: Any): """ self.entity_context.set_state(state) - def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Any | None = None) -> None: + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, + input: Any | None = None, + signal_time: datetime | None = None) -> None: """Signal another entity to perform an operation. Parameters @@ -63,8 +66,12 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in The operation to perform on the entity. input : Any, optional The input to provide to the entity for the operation. + signal_time : datetime, optional + The time at which the signal should be delivered. If None, the signal is + delivered as soon as possible. Use this to schedule a future operation, + for example to have an entity wake itself up at a later time. """ - self.entity_context.signal_entity(entity_instance_id, operation, input) + self.entity_context.signal_entity(entity_instance_id, operation, input, signal_time) def schedule_new_orchestration(self, orchestration_name: str, input: Any | None = None, instance_id: str | None = None) -> str: """Schedule a new orchestration instance. diff --git a/durabletask/entities/entity_context.py b/durabletask/entities/entity_context.py index 7a44c22e..9721b365 100644 --- a/durabletask/entities/entity_context.py +++ b/durabletask/entities/entity_context.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from datetime import datetime from typing import TYPE_CHECKING, Any, TypeVar, overload import uuid +from google.protobuf import timestamp_pb2 from durabletask.entities.entity_instance_id import EntityInstanceId from durabletask.internal import helpers from durabletask.internal.entity_state_shim import StateShim @@ -88,7 +90,9 @@ def set_state(self, new_state: Any) -> None: """ self._state.set_state(new_state) - def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, input: Any | None = None) -> None: + def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, + input: Any | None = None, + signal_time: datetime | None = None) -> None: """Signal another entity to perform an operation. Parameters @@ -99,15 +103,23 @@ def signal_entity(self, entity_instance_id: EntityInstanceId, operation: str, in The operation to perform on the entity. input : Any, optional The input to provide to the entity for the operation. + signal_time : datetime, optional + The time at which the signal should be delivered. If None, the signal is + delivered as soon as possible. Use this to schedule a future operation, + for example to have an entity wake itself up at a later time. 
""" encoded_input: str | None = self._data_converter.serialize(input) + scheduled_time: timestamp_pb2.Timestamp | None = None + if signal_time is not None: + scheduled_time = timestamp_pb2.Timestamp() + scheduled_time.FromDatetime(signal_time) self._state.add_operation_action( pb.OperationAction( sendSignal=pb.SendSignalAction( instanceId=str(entity_instance_id), name=operation, input=helpers.get_string_value(encoded_input), - scheduledTime=None, + scheduledTime=scheduled_time, requestTime=None, parentTraceContext=None, ) diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py index a42c3c43..fe0c828b 100644 --- a/durabletask/internal/client_helpers.py +++ b/durabletask/internal/client_helpers.py @@ -205,14 +205,16 @@ def build_signal_entity_req( entity_instance_id: EntityInstanceId, operation_name: str, input: Any | None = None, + signal_time: datetime | None = None, data_converter: DataConverter | None = None) -> pb.SignalEntityRequest: """Build a SignalEntityRequest for signaling an entity.""" + scheduled_time = helpers.new_timestamp(signal_time) if signal_time is not None else None return pb.SignalEntityRequest( instanceId=str(entity_instance_id), name=operation_name, input=helpers.get_string_value(_serialize(input, data_converter)), requestId=str(uuid.uuid4()), - scheduledTime=None, + scheduledTime=scheduled_time, parentTraceContext=None, requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) ) diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 2342afdd..dd8940c5 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
import traceback -from datetime import datetime +from datetime import datetime, timezone from google.protobuf import timestamp_pb2, wrappers_pb2 @@ -255,11 +255,13 @@ def new_signal_entity_action(id: int, entity_id: EntityInstanceId, operation: str, encoded_input: str | None, - request_id: str) -> pb.OrchestratorAction: + request_id: str, + signal_time: datetime | None = None) -> pb.OrchestratorAction: + scheduled_time = new_timestamp(signal_time) if signal_time is not None else None return pb.OrchestratorAction(id=id, sendEntityMessage=pb.SendEntityMessageAction(entityOperationSignaled=pb.EntityOperationSignaledEvent( requestId=request_id, operation=operation, - scheduledTime=None, + scheduledTime=scheduled_time, input=get_string_value(encoded_input), targetInstanceId=get_string_value(str(entity_id)), ))) @@ -300,6 +302,21 @@ def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: return ts +def ensure_aware(value: datetime | None) -> datetime | None: + """Return ``value`` as a timezone-aware datetime, assuming UTC when naive. + + A naive datetime is tagged as UTC; an already-aware datetime is returned + unchanged. Useful before comparing user-supplied datetimes against the + SDK's always-aware-UTC timestamps to avoid "can't compare offset-naive and + offset-aware datetimes". + """ + if value is None: + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + def new_create_sub_orchestration_action( id: int, name: str, diff --git a/durabletask/scheduled/__init__.py b/durabletask/scheduled/__init__.py new file mode 100644 index 00000000..36d215a7 --- /dev/null +++ b/durabletask/scheduled/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Scheduled tasks support for the Durable Task SDK. + +This package provides a recurring schedule feature built on top of durable +entities and a helper orchestrator. 
Register the entity and orchestrator with a +worker via :func:`configure_scheduled_tasks`, then manage schedules from the +client via :class:`ScheduledTaskClient`. +""" + +from durabletask.scheduled.client import ScheduleClient, ScheduledTaskClient +from durabletask.scheduled.exceptions import (ScheduleClientValidationError, + ScheduleError, + ScheduleInvalidTransitionError, + ScheduleNotFoundError) +from durabletask.scheduled.models import (ScheduleCreationOptions, + ScheduleDescription, ScheduleQuery, + ScheduleUpdateOptions) +from durabletask.scheduled.registration import configure_scheduled_tasks +from durabletask.scheduled.schedule_status import ScheduleStatus + +__all__ = [ + "ScheduledTaskClient", + "ScheduleClient", + "ScheduleCreationOptions", + "ScheduleUpdateOptions", + "ScheduleDescription", + "ScheduleQuery", + "ScheduleStatus", + "ScheduleError", + "ScheduleNotFoundError", + "ScheduleClientValidationError", + "ScheduleInvalidTransitionError", + "configure_scheduled_tasks", +] + +PACKAGE_NAME = "durabletask.scheduled" diff --git a/durabletask/scheduled/client.py b/durabletask/scheduled/client.py new file mode 100644 index 00000000..dc6767ac --- /dev/null +++ b/durabletask/scheduled/client.py @@ -0,0 +1,154 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import logging + +from durabletask.client import (EntityQuery, OrchestrationStatus, + TaskHubGrpcClient) +from durabletask.entities import EntityInstanceId +from durabletask.internal.helpers import ensure_aware +from durabletask.scheduled import transitions +from durabletask.scheduled.exceptions import ScheduleNotFoundError +from durabletask.scheduled.models import (ScheduleCreationOptions, + ScheduleDescription, ScheduleQuery, + ScheduleState, ScheduleUpdateOptions) +from durabletask.scheduled.orchestrator import ( + ScheduleOperationRequest, execute_schedule_operation_orchestrator) +from durabletask.scheduled.schedule_entity import (DELETE_OPERATION, + ENTITY_NAME) + +logger = logging.getLogger("durabletask.scheduled") + + +class ScheduleClient: + """Client for managing a single schedule instance.""" + + def __init__(self, client: TaskHubGrpcClient, schedule_id: str, + *, operation_timeout: float = 60): + if not schedule_id: + raise ValueError("schedule_id cannot be empty.") + self._client = client + self._schedule_id = schedule_id + self._entity_id = EntityInstanceId(ENTITY_NAME, schedule_id) + self._operation_timeout = operation_timeout + + @property + def schedule_id(self) -> str: + """Gets the ID of this schedule.""" + return self._schedule_id + + def _run_operation(self, operation_name: str, input: object | None = None) -> None: + request = ScheduleOperationRequest( + entity_id=str(self._entity_id), + operation_name=operation_name, + input=input, + ) + instance_id = self._client.schedule_new_orchestration( + execute_schedule_operation_orchestrator, input=request) + state = self._client.wait_for_orchestration_completion( + instance_id, timeout=self._operation_timeout) + if state is None or state.runtime_status != OrchestrationStatus.COMPLETED: + failure = state.failure_details if state else None + message = failure.message if failure else "unknown error" + raise RuntimeError( + f"Failed to '{operation_name}' schedule '{self._schedule_id}': {message}") + + 
def create(self, options: ScheduleCreationOptions) -> None: + """Create or update this schedule with the given configuration.""" + self._run_operation(transitions.CREATE_SCHEDULE, options) + + def update(self, options: ScheduleUpdateOptions) -> None: + """Update this schedule's configuration.""" + self._run_operation(transitions.UPDATE_SCHEDULE, options) + + def pause(self) -> None: + """Pause this schedule.""" + self._run_operation(transitions.PAUSE_SCHEDULE) + + def resume(self) -> None: + """Resume this schedule.""" + self._run_operation(transitions.RESUME_SCHEDULE) + + def delete(self) -> None: + """Delete this schedule.""" + self._run_operation(DELETE_OPERATION) + + def describe(self) -> ScheduleDescription: + """Retrieve the current details of this schedule.""" + metadata = self._client.get_entity(self._entity_id, include_state=True) + if metadata is None: + raise ScheduleNotFoundError(self._schedule_id) + state = metadata.get_typed_state(ScheduleState) + if state is None: + raise ScheduleNotFoundError(self._schedule_id) + return state.to_description() + + +class ScheduledTaskClient: + """Client for managing scheduled tasks in a Durable Task application.""" + + def __init__(self, client: TaskHubGrpcClient, *, operation_timeout: float = 60): + self._client = client + self._operation_timeout = operation_timeout + + def get_schedule_client(self, schedule_id: str) -> ScheduleClient: + """Get a handle to manage a specific schedule.""" + return ScheduleClient(self._client, schedule_id, + operation_timeout=self._operation_timeout) + + def create_schedule(self, options: ScheduleCreationOptions) -> ScheduleClient: + """Create a new schedule and return a client for managing it.""" + schedule_client = self.get_schedule_client(options.schedule_id) + schedule_client.create(options) + return schedule_client + + def get_schedule(self, schedule_id: str) -> ScheduleDescription | None: + """Get a schedule description by ID, or None if it does not exist.""" + try: + return 
self.get_schedule_client(schedule_id).describe() + except ScheduleNotFoundError: + return None + + def list_schedules(self, schedule_query: ScheduleQuery | None = None) -> list[ScheduleDescription]: + """List schedules matching the given filter criteria. + + > [!NOTE] + > The ``status`` and ``created_from``/``created_to`` filters are applied + > client-side after each page of entities is fetched, so an individual + > page may contain fewer than ``page_size`` matches (or none) even when + > more matching schedules exist. This mirrors the .NET implementation. + """ + prefix = schedule_query.schedule_id_prefix if schedule_query and schedule_query.schedule_id_prefix else "" + page_size = (schedule_query.page_size if schedule_query and schedule_query.page_size + else ScheduleQuery.DEFAULT_PAGE_SIZE) + query = EntityQuery( + instance_id_starts_with=f"@{ENTITY_NAME}@{prefix}", + include_state=True, + page_size=page_size, + ) + results: list[ScheduleDescription] = [] + for metadata in self._client.get_all_entities(query): + state = metadata.get_typed_state(ScheduleState) + if state is None or state.schedule_configuration is None: + continue + if not self._matches_filter(state, schedule_query): + continue + results.append(state.to_description()) + return results + + @staticmethod + def _matches_filter(state: ScheduleState, schedule_query: ScheduleQuery | None) -> bool: + if schedule_query is None: + return True + if schedule_query.status is not None and state.status != schedule_query.status: + return False + # ``ScheduleQuery`` normalizes its bounds to aware UTC; defensively + # normalize the stored timestamp too (a payload could in principle carry + # a naive value) so the comparison can never raise on naive-vs-aware. + # Bounds are exclusive, matching the .NET ScheduledTasks implementation. 
+ created_at = ensure_aware(state.schedule_created_at) + if schedule_query.created_from is not None and not (created_at and created_at > schedule_query.created_from): + return False + if schedule_query.created_to is not None and not (created_at and created_at < schedule_query.created_to): + return False + return True diff --git a/durabletask/scheduled/exceptions.py b/durabletask/scheduled/exceptions.py new file mode 100644 index 00000000..f984056e --- /dev/null +++ b/durabletask/scheduled/exceptions.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +class ScheduleError(Exception): + """Base class for schedule-related errors.""" + + +class ScheduleNotFoundError(ScheduleError): + """Raised when a requested schedule does not exist.""" + + def __init__(self, schedule_id: str): + self.schedule_id = schedule_id + super().__init__(f"Schedule with ID '{schedule_id}' was not found.") + + +class ScheduleClientValidationError(ScheduleError): + """Raised when a schedule operation fails client-side validation.""" + + def __init__(self, schedule_id: str, message: str): + self.schedule_id = schedule_id + super().__init__(f"Validation failed for schedule '{schedule_id}': {message}") + + +class ScheduleInvalidTransitionError(ScheduleError): + """Raised when an operation is not valid for the schedule's current status.""" + + def __init__(self, schedule_id: str, from_status: object, to_status: object, operation_name: str): + self.schedule_id = schedule_id + self.from_status = from_status + self.to_status = to_status + self.operation_name = operation_name + super().__init__( + f"Invalid state transition for schedule '{schedule_id}': operation " + f"'{operation_name}' cannot transition from '{from_status}' to '{to_status}'." 
+ ) diff --git a/durabletask/scheduled/models.py b/durabletask/scheduled/models.py new file mode 100644 index 00000000..60f59194 --- /dev/null +++ b/durabletask/scheduled/models.py @@ -0,0 +1,326 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from durabletask.internal.helpers import ensure_aware +from durabletask.scheduled.schedule_status import ScheduleStatus + +MINIMUM_INTERVAL = timedelta(seconds=1) + + +def _validate_interval(interval: timedelta) -> timedelta: + if interval <= timedelta(0): + raise ValueError("Interval must be positive.") + if interval < MINIMUM_INTERVAL: + raise ValueError("Interval must be at least 1 second.") + return interval + + +def _to_iso(value: datetime | None) -> str | None: + return value.isoformat() if value is not None else None + + +def _from_iso(value: str | None) -> datetime | None: + return datetime.fromisoformat(value) if value else None + + +def _interval_to_seconds(value: timedelta | None) -> float | None: + return value.total_seconds() if value is not None else None + + +def _interval_from_seconds(value: float | None) -> timedelta | None: + return timedelta(seconds=value) if value is not None else None + + +@dataclass +class ScheduleCreationOptions: + """Options for creating a new schedule.""" + + schedule_id: str + orchestration_name: str + interval: timedelta + orchestration_input: Any | None = None + orchestration_instance_id: str | None = None + start_at: datetime | None = None + end_at: datetime | None = None + start_immediately_if_late: bool = False + + def __post_init__(self): + if not self.schedule_id: + raise ValueError("schedule_id cannot be empty.") + if not self.orchestration_name: + raise ValueError("orchestration_name cannot be empty.") + _validate_interval(self.interval) + + def to_json(self) -> dict[str, Any]: + return { + "schedule_id": self.schedule_id, + 
"orchestration_name": self.orchestration_name, + "interval_seconds": self.interval.total_seconds(), + "orchestration_input": self.orchestration_input, + "orchestration_instance_id": self.orchestration_instance_id, + "start_at": _to_iso(self.start_at), + "end_at": _to_iso(self.end_at), + "start_immediately_if_late": self.start_immediately_if_late, + } + + @classmethod + def from_json(cls, data: dict[str, Any]) -> "ScheduleCreationOptions": + return cls( + schedule_id=data["schedule_id"], + orchestration_name=data["orchestration_name"], + interval=timedelta(seconds=data["interval_seconds"]), + orchestration_input=data.get("orchestration_input"), + orchestration_instance_id=data.get("orchestration_instance_id"), + start_at=_from_iso(data.get("start_at")), + end_at=_from_iso(data.get("end_at")), + start_immediately_if_late=bool(data.get("start_immediately_if_late", False)), + ) + + +@dataclass +class ScheduleUpdateOptions: + """Options for updating an existing schedule. Only set fields are applied.""" + + orchestration_name: str | None = None + orchestration_input: Any | None = None + orchestration_instance_id: str | None = None + start_at: datetime | None = None + end_at: datetime | None = None + interval: timedelta | None = None + start_immediately_if_late: bool | None = None + + def __post_init__(self): + if self.interval is not None: + _validate_interval(self.interval) + + def to_json(self) -> dict[str, Any]: + return { + "orchestration_name": self.orchestration_name, + "orchestration_input": self.orchestration_input, + "orchestration_instance_id": self.orchestration_instance_id, + "start_at": _to_iso(self.start_at), + "end_at": _to_iso(self.end_at), + "interval_seconds": _interval_to_seconds(self.interval), + "start_immediately_if_late": self.start_immediately_if_late, + } + + @classmethod + def from_json(cls, data: dict[str, Any]) -> "ScheduleUpdateOptions": + return cls( + orchestration_name=data.get("orchestration_name"), + 
orchestration_input=data.get("orchestration_input"), + orchestration_instance_id=data.get("orchestration_instance_id"), + start_at=_from_iso(data.get("start_at")), + end_at=_from_iso(data.get("end_at")), + interval=_interval_from_seconds(data.get("interval_seconds")), + start_immediately_if_late=data.get("start_immediately_if_late"), + ) + + +@dataclass +class ScheduleQuery: + """Query parameters for filtering schedules.""" + + DEFAULT_PAGE_SIZE = 100 + + status: ScheduleStatus | None = None + schedule_id_prefix: str | None = None + created_from: datetime | None = None + created_to: datetime | None = None + page_size: int | None = None + + def __post_init__(self): + # Coerce the time-window bounds to timezone-aware UTC. Schedule + # timestamps are always stored as aware UTC, so normalizing here ensures + # a naive bound supplied by a caller can never reach the filter + # comparison and raise "can't compare offset-naive and offset-aware". + self.created_from = ensure_aware(self.created_from) + self.created_to = ensure_aware(self.created_to) + + +@dataclass +class ScheduleDescription: + """A read-only snapshot of a schedule's configuration and runtime state.""" + + schedule_id: str + orchestration_name: str | None = None + orchestration_input: Any | None = None + orchestration_instance_id: str | None = None + start_at: datetime | None = None + end_at: datetime | None = None + interval: timedelta | None = None + start_immediately_if_late: bool | None = None + status: ScheduleStatus = ScheduleStatus.UNINITIALIZED + execution_token: str = "" + last_run_at: datetime | None = None + next_run_at: datetime | None = None + + +class ScheduleConfiguration: + """Internal configuration for a scheduled task. 
Persisted as part of the entity state.""" + + def __init__(self, schedule_id: str, orchestration_name: str, interval: timedelta): + if not schedule_id: + raise ValueError("schedule_id cannot be empty.") + if not orchestration_name: + raise ValueError("orchestration_name cannot be empty.") + self.schedule_id = schedule_id + self.orchestration_name = orchestration_name + self.interval = _validate_interval(interval) + self.orchestration_input: Any | None = None + self.orchestration_instance_id: str | None = None + self.start_at: datetime | None = None + self.end_at: datetime | None = None + self.start_immediately_if_late: bool = False + + @staticmethod + def from_create_options(options: ScheduleCreationOptions) -> "ScheduleConfiguration": + config = ScheduleConfiguration(options.schedule_id, options.orchestration_name, options.interval) + config.orchestration_input = options.orchestration_input + config.orchestration_instance_id = options.orchestration_instance_id + config.start_at = options.start_at + config.end_at = options.end_at + config.start_immediately_if_late = options.start_immediately_if_late + config._validate() + return config + + def update(self, options: ScheduleUpdateOptions) -> set[str]: + """Apply the update options and return the set of changed field names.""" + updated: set[str] = set() + + if options.orchestration_name and options.orchestration_name != self.orchestration_name: + self.orchestration_name = options.orchestration_name + updated.add("orchestration_name") + + if options.orchestration_input is not None and options.orchestration_input != self.orchestration_input: + self.orchestration_input = options.orchestration_input + updated.add("orchestration_input") + + if options.orchestration_instance_id and options.orchestration_instance_id != self.orchestration_instance_id: + self.orchestration_instance_id = options.orchestration_instance_id + updated.add("orchestration_instance_id") + + if options.start_at is not None and options.start_at != 
self.start_at: + self.start_at = options.start_at + updated.add("start_at") + + if options.end_at is not None and options.end_at != self.end_at: + self.end_at = options.end_at + updated.add("end_at") + + if options.interval is not None and options.interval != self.interval: + self.interval = _validate_interval(options.interval) + updated.add("interval") + + if options.start_immediately_if_late is not None \ + and options.start_immediately_if_late != self.start_immediately_if_late: + self.start_immediately_if_late = options.start_immediately_if_late + updated.add("start_immediately_if_late") + + self._validate() + return updated + + def _validate(self): + if self.start_at is not None and self.end_at is not None and self.start_at > self.end_at: + raise ValueError("start_at cannot be later than end_at.") + + def to_json(self) -> dict[str, Any]: + return { + "schedule_id": self.schedule_id, + "orchestration_name": self.orchestration_name, + "interval_seconds": self.interval.total_seconds(), + "orchestration_input": self.orchestration_input, + "orchestration_instance_id": self.orchestration_instance_id, + "start_at": _to_iso(self.start_at), + "end_at": _to_iso(self.end_at), + "start_immediately_if_late": self.start_immediately_if_late, + } + + @classmethod + def from_json(cls, data: dict[str, Any]) -> "ScheduleConfiguration": + config = cls( + data["schedule_id"], + data["orchestration_name"], + timedelta(seconds=data["interval_seconds"]), + ) + config.orchestration_input = data.get("orchestration_input") + config.orchestration_instance_id = data.get("orchestration_instance_id") + config.start_at = _from_iso(data.get("start_at")) + config.end_at = _from_iso(data.get("end_at")) + config.start_immediately_if_late = bool(data.get("start_immediately_if_late", False)) + return config + + +class ScheduleState: + """Internal runtime state for a schedule. 
Persisted as the entity state.""" + + def __init__(self): + self.status: ScheduleStatus = ScheduleStatus.UNINITIALIZED + self.execution_token: str = _new_token() + self.last_run_at: datetime | None = None + self.next_run_at: datetime | None = None + self.schedule_created_at: datetime | None = None + self.schedule_last_modified_at: datetime | None = None + self.schedule_configuration: ScheduleConfiguration | None = None + + def refresh_execution_token(self): + self.execution_token = _new_token() + + def to_json(self) -> dict[str, Any]: + # ``schedule_configuration`` is returned as the object itself; the + # serializer recurses into it and fires its own ``to_json`` hook. Only + # this type's non-JSON-native leaves (datetimes) are converted here. + return { + "status": self.status.value, + "execution_token": self.execution_token, + "last_run_at": _to_iso(self.last_run_at), + "next_run_at": _to_iso(self.next_run_at), + "schedule_created_at": _to_iso(self.schedule_created_at), + "schedule_last_modified_at": _to_iso(self.schedule_last_modified_at), + "schedule_configuration": self.schedule_configuration, + } + + @classmethod + def from_json(cls, data: dict[str, Any]) -> "ScheduleState": + # The nested configuration is reconstructed by calling its own + # ``from_json`` hook directly. ``ScheduleConfiguration`` is an internal + # type, so there is no need to route it through a (possibly custom) + # converter -- keeping this hook converter-free means it round-trips + # under any code path, not only the worker's threaded converter. 
+ state = cls() + state.status = ScheduleStatus(data["status"]) + state.execution_token = data["execution_token"] + state.last_run_at = _from_iso(data.get("last_run_at")) + state.next_run_at = _from_iso(data.get("next_run_at")) + state.schedule_created_at = _from_iso(data.get("schedule_created_at")) + state.schedule_last_modified_at = _from_iso(data.get("schedule_last_modified_at")) + config_data = data.get("schedule_configuration") + state.schedule_configuration = ( + ScheduleConfiguration.from_json(config_data) if config_data is not None else None) + return state + + def to_description(self) -> ScheduleDescription: + config = self.schedule_configuration + return ScheduleDescription( + schedule_id=config.schedule_id if config else "", + orchestration_name=config.orchestration_name if config else None, + orchestration_input=config.orchestration_input if config else None, + orchestration_instance_id=config.orchestration_instance_id if config else None, + start_at=config.start_at if config else None, + end_at=config.end_at if config else None, + interval=config.interval if config else None, + start_immediately_if_late=config.start_immediately_if_late if config else None, + status=self.status, + execution_token=self.execution_token, + last_run_at=self.last_run_at, + next_run_at=self.next_run_at, + ) + + +def _new_token() -> str: + return uuid.uuid4().hex diff --git a/durabletask/scheduled/orchestrator.py b/durabletask/scheduled/orchestrator.py new file mode 100644 index 00000000..0493b744 --- /dev/null +++ b/durabletask/scheduled/orchestrator.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any + +from durabletask import task +from durabletask.entities import EntityInstanceId + + +@dataclass +class ScheduleOperationRequest: + """Request describing an operation to execute against a schedule entity. 
+ + A plain dataclass: the serializer round-trips it automatically. ``input`` is + typed ``Any``, so it is reconstructed as the raw deserialized payload; the + concrete options type is rebuilt later, at the entity-method boundary, from + that method's parameter annotation. + """ + + entity_id: str + operation_name: str + input: Any | None = None + + +def execute_schedule_operation_orchestrator( + ctx: task.OrchestrationContext, + request: ScheduleOperationRequest) -> Generator[task.Task[Any], Any, Any]: + """Orchestrator that executes a single operation on a schedule entity. + + Client-side write operations route through this orchestrator so callers can await + completion (and surface failures) of the underlying entity operation. + """ + entity_id = EntityInstanceId.parse(request.entity_id) + result = yield ctx.call_entity(entity_id, request.operation_name, request.input) + return result diff --git a/durabletask/scheduled/registration.py b/durabletask/scheduled/registration.py new file mode 100644 index 00000000..2a835094 --- /dev/null +++ b/durabletask/scheduled/registration.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import durabletask.internal.orchestrator_service_pb2 as pb +from durabletask.worker import TaskHubGrpcWorker +from durabletask.scheduled.orchestrator import execute_schedule_operation_orchestrator +from durabletask.scheduled.schedule_entity import ENTITY_NAME, Schedule + + +def configure_scheduled_tasks(worker: TaskHubGrpcWorker) -> None: + """Register the scheduled tasks entity and orchestrator with a worker. + + Call this before starting the worker to enable scheduled tasks support. + + Parameters + ---------- + worker : TaskHubGrpcWorker + The worker to register the schedule entity and operation orchestrator with. 
+ """ + worker.add_entity(Schedule, ENTITY_NAME) + worker.add_orchestrator(execute_schedule_operation_orchestrator) + worker.add_capability(pb.WORKER_CAPABILITY_SCHEDULED_TASKS) diff --git a/durabletask/scheduled/schedule_entity.py b/durabletask/scheduled/schedule_entity.py new file mode 100644 index 00000000..90f385e7 --- /dev/null +++ b/durabletask/scheduled/schedule_entity.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging +from datetime import datetime, timezone +from typing import Any + +from durabletask.entities import DurableEntity, EntityInstanceId +from durabletask.internal.helpers import ensure_aware +from durabletask.scheduled import transitions +from durabletask.scheduled.exceptions import ScheduleInvalidTransitionError +from durabletask.scheduled.models import (ScheduleConfiguration, + ScheduleCreationOptions, ScheduleState, + ScheduleUpdateOptions) +from durabletask.scheduled.schedule_status import ScheduleStatus + +ENTITY_NAME = "schedule" +"""The lowercased entity name used for schedule entity instances.""" + +DELETE_OPERATION = "delete" +RUN_SCHEDULE_OPERATION = "run_schedule" + +logger = logging.getLogger("durabletask.scheduled") + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +class Schedule(DurableEntity): + """Entity that manages the state and execution of a scheduled task. + + The Schedule entity maintains the configuration and runtime state of a scheduled + task, handling operations like creation, updates, pausing/resuming, and executing + the target orchestration according to the defined schedule. + """ + + def _load_state(self) -> ScheduleState: + # The serializer reconstructs the persisted ``ScheduleState`` via its + # ``from_json`` hook; a missing/uninitialized entity yields a fresh one. 
+ return self.get_state(ScheduleState, ScheduleState()) + + def _save_state(self, state: ScheduleState) -> None: + # Store the object as-is; the serializer persists it via its ``to_json`` + # hook when the entity batch is committed. + self.set_state(state) + + def _entity_id(self, schedule_id: str) -> EntityInstanceId: + return EntityInstanceId(ENTITY_NAME, schedule_id) + + def _can_transition_to(self, state: ScheduleState, operation_name: str, + target_status: ScheduleStatus) -> bool: + return transitions.is_valid_transition(operation_name, state.status, target_status) + + def create_schedule(self, options: ScheduleCreationOptions) -> None: + """Create a new schedule. If one already exists, update it in place.""" + state = self._load_state() + + if not self._can_transition_to(state, transitions.CREATE_SCHEDULE, ScheduleStatus.ACTIVE): + raise ScheduleInvalidTransitionError( + options.schedule_id if options else "", state.status, ScheduleStatus.ACTIVE, + transitions.CREATE_SCHEDULE) + + already_exists = state.schedule_created_at is not None + state.schedule_configuration = ScheduleConfiguration.from_create_options(options) + + if already_exists: + state.schedule_last_modified_at = _now() + state.refresh_execution_token() + state.next_run_at = None + else: + state.status = ScheduleStatus.ACTIVE + state.schedule_created_at = state.schedule_last_modified_at = _now() + + logger.info(f"Created schedule '{state.schedule_configuration.schedule_id}'.") + self._save_state(state) + + # Signal to run the schedule and let run_schedule decide whether to run now or later. 
+ self.signal_entity( + self._entity_id(state.schedule_configuration.schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + ) + + def update_schedule(self, options: ScheduleUpdateOptions) -> None: + """Update an existing schedule's configuration.""" + state = self._load_state() + + if not self._can_transition_to(state, transitions.UPDATE_SCHEDULE, state.status): + raise ScheduleInvalidTransitionError( + state.schedule_configuration.schedule_id if state.schedule_configuration else "", + state.status, state.status, transitions.UPDATE_SCHEDULE) + + if state.schedule_configuration is None: + raise ValueError("Schedule configuration is missing.") + + updated_fields = state.schedule_configuration.update(options) + if not updated_fields: + logger.debug("Schedule configuration is already up to date.") + self._save_state(state) + return + + state.schedule_last_modified_at = _now() + + if updated_fields & {"start_at", "interval", "start_immediately_if_late"}: + state.next_run_at = None + + state.refresh_execution_token() + logger.info(f"Updated schedule '{state.schedule_configuration.schedule_id}'.") + self._save_state(state) + + if state.status == ScheduleStatus.ACTIVE: + self.signal_entity( + self._entity_id(state.schedule_configuration.schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + ) + + def pause_schedule(self, _: Any = None) -> None: + """Pause the schedule.""" + state = self._load_state() + schedule_id = state.schedule_configuration.schedule_id if state.schedule_configuration else "" + + if not self._can_transition_to(state, transitions.PAUSE_SCHEDULE, ScheduleStatus.PAUSED): + raise ScheduleInvalidTransitionError( + schedule_id, state.status, ScheduleStatus.PAUSED, transitions.PAUSE_SCHEDULE) + + if state.schedule_configuration is None: + raise ValueError("Schedule configuration is missing.") + + state.status = ScheduleStatus.PAUSED + state.next_run_at = None + state.refresh_execution_token() + logger.info(f"Paused schedule 
'{schedule_id}'.") + self._save_state(state) + + def resume_schedule(self, _: Any = None) -> None: + """Resume a paused schedule.""" + state = self._load_state() + schedule_id = state.schedule_configuration.schedule_id if state.schedule_configuration else "" + + if not self._can_transition_to(state, transitions.RESUME_SCHEDULE, ScheduleStatus.ACTIVE): + raise ScheduleInvalidTransitionError( + schedule_id, state.status, ScheduleStatus.ACTIVE, transitions.RESUME_SCHEDULE) + + if state.schedule_configuration is None: + raise ValueError("Schedule configuration is missing.") + + state.status = ScheduleStatus.ACTIVE + state.next_run_at = None + logger.info(f"Resumed schedule '{schedule_id}'.") + self._save_state(state) + + self.signal_entity( + self._entity_id(schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + ) + + def run_schedule(self, execution_token: str) -> None: + """Heartbeat operation: starts the target orchestration when due and re-arms itself.""" + state = self._load_state() + + if state.status == ScheduleStatus.UNINITIALIZED: + # This signal is no longer useful since the schedule was deleted. 
+ self.set_state(None) + return + + config = state.schedule_configuration + if config is None: + raise ValueError("Schedule configuration is missing.") + + if execution_token != state.execution_token: + logger.debug(f"Ignoring stale run signal for schedule '{config.schedule_id}'.") + return + + if state.status != ScheduleStatus.ACTIVE: + raise ValueError("Schedule must be in Active status to run.") + + end_at = ensure_aware(config.end_at) + if end_at is not None and _now() > end_at: + logger.info(f"Schedule '{config.schedule_id}' has passed its end time; deleting.") + state.next_run_at = None + self._save_state(state) + self.signal_entity(self._entity_id(config.schedule_id), DELETE_OPERATION) + return + + state.next_run_at = self._determine_next_run_time(state, config) + + if state.next_run_at <= _now(): + self._start_orchestration(config, state.next_run_at) + state.last_run_at = state.next_run_at + state.next_run_at = None + state.next_run_at = self._determine_next_run_time(state, config) + + self._save_state(state) + + self.signal_entity( + self._entity_id(config.schedule_id), + RUN_SCHEDULE_OPERATION, + state.execution_token, + signal_time=state.next_run_at, + ) + + def _start_orchestration(self, config: ScheduleConfiguration, scheduled_run_time: datetime) -> None: + instance_id = config.orchestration_instance_id + if not instance_id: + instance_id = f"{config.schedule_id}-{scheduled_run_time.isoformat()}" + + logger.info( + f"Starting orchestration '{config.orchestration_name}' with instance ID '{instance_id}' " + f"for schedule '{config.schedule_id}'.") + self.schedule_new_orchestration( + config.orchestration_name, + config.orchestration_input, + instance_id=instance_id, + ) + + def _determine_next_run_time(self, state: ScheduleState, config: ScheduleConfiguration) -> datetime: + if state.next_run_at is not None: + return ensure_aware(state.next_run_at) # type: ignore[return-value] + + now = _now() + start_time = ensure_aware(config.start_at) or 
ensure_aware(state.schedule_created_at) or now + time_since_start = now - start_time + + # Next run is in the future relative to the start time. + if time_since_start.total_seconds() < 0: + return start_time + + is_first_run = state.last_run_at is None + if is_first_run and config.start_immediately_if_late: + return now + + intervals_elapsed = int(time_since_start.total_seconds() // config.interval.total_seconds()) + return start_time + config.interval * (intervals_elapsed + 1) diff --git a/durabletask/scheduled/schedule_status.py b/durabletask/scheduled/schedule_status.py new file mode 100644 index 00000000..6a21afd2 --- /dev/null +++ b/durabletask/scheduled/schedule_status.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from enum import Enum + + +class ScheduleStatus(str, Enum): + """Represents the current status of a schedule.""" + + UNINITIALIZED = "Uninitialized" + """Schedule has not been created.""" + + ACTIVE = "Active" + """Schedule is active and running.""" + + PAUSED = "Paused" + """Schedule is paused.""" diff --git a/durabletask/scheduled/transitions.py b/durabletask/scheduled/transitions.py new file mode 100644 index 00000000..ad2f21ec --- /dev/null +++ b/durabletask/scheduled/transitions.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from durabletask.scheduled.schedule_status import ScheduleStatus + +# Operation names used by the Schedule entity. These must match the entity +# method names so that the transition table can be keyed by operation. +CREATE_SCHEDULE = "create_schedule" +UPDATE_SCHEDULE = "update_schedule" +PAUSE_SCHEDULE = "pause_schedule" +RESUME_SCHEDULE = "resume_schedule" + + +def is_valid_transition(operation_name: str, from_status: ScheduleStatus, + target_status: ScheduleStatus) -> bool: + """Check whether a transition to the target status is valid for the given operation. 
+ + Parameters + ---------- + operation_name : str + The name of the operation being performed. + from_status : ScheduleStatus + The current schedule status. + target_status : ScheduleStatus + The status the schedule would transition to. + + Returns + ------- + bool + True if the transition is valid; otherwise False. + """ + if operation_name == CREATE_SCHEDULE: + return ( + (from_status == ScheduleStatus.UNINITIALIZED and target_status == ScheduleStatus.ACTIVE) + or (from_status == ScheduleStatus.ACTIVE and target_status == ScheduleStatus.ACTIVE) + or (from_status == ScheduleStatus.PAUSED and target_status == ScheduleStatus.ACTIVE) + ) + if operation_name == UPDATE_SCHEDULE: + return ( + (from_status == ScheduleStatus.ACTIVE and target_status == ScheduleStatus.ACTIVE) + or (from_status == ScheduleStatus.PAUSED and target_status == ScheduleStatus.PAUSED) + ) + if operation_name == PAUSE_SCHEDULE: + return from_status == ScheduleStatus.ACTIVE and target_status == ScheduleStatus.PAUSED + if operation_name == RESUME_SCHEDULE: + return from_status == ScheduleStatus.PAUSED and target_status == ScheduleStatus.ACTIVE + return False diff --git a/durabletask/serialization.py b/durabletask/serialization.py index 5d1dad89..e6c68935 100644 --- a/durabletask/serialization.py +++ b/durabletask/serialization.py @@ -360,6 +360,9 @@ def _coerce_to_type(value: Any, expected_type: Any, converter: DataConverter | N if expected_type is None or value is None: return value + if expected_type is typing.Any: + return value + origin = typing.get_origin(expected_type) if origin is not None: return _coerce_generic(value, expected_type, origin, converter) @@ -447,6 +450,11 @@ def _coerce_generic(value: Any, expected_type: Any, origin: Any, # If the value already matches a member type, keep it as-is. non_none = [a for a in args if a is not type(None)] for arg in non_none: + # ``Any`` imposes no constraint, so the value already satisfies the + # union. 
(Checked explicitly because ``isinstance(value, Any)`` would + # raise -- ``typing.Any`` is a class on 3.11+ but not isinstance-able.) + if arg is typing.Any: + return value if isinstance(arg, type) and isinstance(value, arg): return value # ``Optional[T]`` (exactly one non-None member): coerce to that member. diff --git a/durabletask/task.py b/durabletask/task.py index 18af6ccc..2085fd9c 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -220,7 +220,8 @@ def signal_entity( self, entity_id: EntityInstanceId, operation_name: str, - input: Any = None + input: Any = None, + signal_time: datetime | None = None ) -> None: """Signal an entity function for execution. @@ -232,6 +233,10 @@ def signal_entity( The name of the operation to invoke on the entity. input: TInput | None The optional JSON-serializable input to pass to the entity function. + signal_time: datetime | None + The optional time at which the signal should be delivered. If None, the + signal is delivered as soon as possible. Use this to schedule a future + operation on the entity. """ pass diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py index 0eb761c7..3f45bc5f 100644 --- a/durabletask/testing/in_memory_backend.py +++ b/durabletask/testing/in_memory_backend.py @@ -153,6 +153,13 @@ def __init__(self, max_history_size: int = 10000, port: int = 50051): self._logger = logging.getLogger(__name__) self._shutdown_event = threading.Event() self._work_available = threading.Event() + # Monotonic lifecycle counter, bumped on every stop()/reset(). Background + # timers (e.g. delayed entity signals) capture it when scheduled and bail + # if it has changed by the time they fire, so a timer created before a + # stop/reset cannot mutate a subsequently-restarted or cleared backend. + # Unlike ``_shutdown_event`` (which reset() clears to allow restart), this + # only ever moves forward, so it reliably invalidates stale timers. 
+ self._generation = 0 def start(self) -> str: """ @@ -183,6 +190,7 @@ def stop(self, grace: float | None = None): """ self._shutdown_event.set() self._work_available.set() # Unblock GetWorkItems loops + self._generation += 1 if self._server: stop_future = self._server.stop(grace) stop_future.wait() @@ -208,6 +216,7 @@ def reset(self): self._state_waiters.clear() self._shutdown_event.clear() self._work_available.clear() + self._generation += 1 # gRPC Service Methods @@ -919,9 +928,14 @@ def CompleteEntityTask(self, request: pb.EntityBatchResult, context: grpc.Servic try: if action.HasField("sendSignal"): signal = action.sendSignal + scheduled_time = ( + signal.scheduledTime.ToDatetime(tzinfo=timezone.utc) + if signal.HasField("scheduledTime") else None + ) self._signal_entity_internal( signal.instanceId, signal.name, - signal.input.value if signal.input else None + signal.input.value if signal.input else None, + scheduled_time=scheduled_time, ) elif action.HasField("startNewOrchestration"): start_orch = action.startNewOrchestration @@ -945,6 +959,10 @@ def CompleteEntityTask(self, request: pb.EntityBatchResult, context: grpc.Servic def SignalEntity(self, request: pb.SignalEntityRequest, context: grpc.ServicerContext) -> pb.SignalEntityResponse: """Signals an entity, queueing an operation for processing.""" + scheduled_time = ( + request.scheduledTime.ToDatetime(tzinfo=timezone.utc) + if request.HasField("scheduledTime") else None + ) with self._lock: entity_id = request.instanceId entity = self._entities.get(entity_id) @@ -967,8 +985,11 @@ def SignalEntity(self, request: pb.SignalEntityRequest, context: grpc.ServicerCo targetInstanceId=wrappers_pb2.StringValue(value=entity_id), ) ) - entity.pending_operations.append(event) - self._enqueue_entity(entity_id) + if scheduled_time is not None and scheduled_time > datetime.now(timezone.utc): + self._schedule_delayed_entity_operation(entity_id, event, scheduled_time) + else: + entity.pending_operations.append(event) + 
self._enqueue_entity(entity_id) self._logger.info(f"Signaled entity '{entity_id}' operation '{request.name}'") return pb.SignalEntityResponse() @@ -1604,11 +1625,19 @@ def _process_send_entity_message_action(self, instance: OrchestrationInstance, instance.history.append(history_event) if target_id: - self._queue_entity_operation(target_id, pb.HistoryEvent( + operation_event = pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), entityOperationSignaled=signaled, - )) + ) + scheduled_time = ( + signaled.scheduledTime.ToDatetime(tzinfo=timezone.utc) + if signaled.HasField("scheduledTime") else None + ) + if scheduled_time is not None and scheduled_time > datetime.now(timezone.utc): + self._schedule_delayed_entity_operation(target_id, operation_event, scheduled_time) + else: + self._queue_entity_operation(target_id, operation_event) elif msg.HasField("entityOperationCalled"): called = msg.entityOperationCalled @@ -1748,8 +1777,13 @@ def _queue_entity_operation(self, entity_id: str, event: pb.HistoryEvent): self._enqueue_entity(entity_id) def _signal_entity_internal(self, entity_id: str, operation: str, - input_value: str | None = None): - """Internal method to signal an entity (from entity side-effect actions).""" + input_value: str | None = None, + scheduled_time: datetime | None = None): + """Internal method to signal an entity (from entity side-effect actions). + + If ``scheduled_time`` is set and in the future, the operation is delayed + until that time (mirroring delayed-signal delivery on a real backend). 
+ """ event = pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), @@ -1760,7 +1794,33 @@ def _signal_entity_internal(self, entity_id: str, operation: str, targetInstanceId=wrappers_pb2.StringValue(value=entity_id), ) ) - self._queue_entity_operation(entity_id, event) + if scheduled_time is not None and scheduled_time > datetime.now(timezone.utc): + self._schedule_delayed_entity_operation(entity_id, event, scheduled_time) + else: + self._queue_entity_operation(entity_id, event) + + def _schedule_delayed_entity_operation(self, entity_id: str, event: pb.HistoryEvent, + fire_at: datetime): + """Schedules an entity operation to be enqueued at a future time. + + Uses a background timer thread, mirroring the timer-firing mechanism used + for orchestration timers, so that delayed entity signals are honored by the + in-memory backend. + """ + delay = max(0.0, (fire_at - datetime.now(timezone.utc)).total_seconds()) + # Capture the lifecycle generation so a timer outliving a stop()/reset() + # does not enqueue into a restarted or cleared backend. + scheduled_generation = self._generation + + def fire() -> None: + time.sleep(delay) + with self._lock: + if self._shutdown_event.is_set() or self._generation != scheduled_generation: + return + self._queue_entity_operation(entity_id, event) + + timer_thread = threading.Thread(target=fire, daemon=True) + timer_thread.start() def _enqueue_entity(self, entity_id: str): """Enqueues an entity for processing.""" diff --git a/durabletask/worker.py b/durabletask/worker.py index 66381c97..ea42411e 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -577,6 +577,11 @@ def __init__( self._work_item_filters: WorkItemFilters | None = None self._auto_generate_work_item_filters: bool = False self._runLoop: Thread | None = None + # Extra worker capabilities advertised to the backend in + # GetWorkItemsRequest (in addition to ones derived from worker state such + # as LARGE_PAYLOADS). 
Feature-enablement helpers like + # durabletask.scheduled.configure_scheduled_tasks register theirs here. + self._capabilities: set[int] = set() @property def concurrency_options(self) -> ConcurrencyOptions: @@ -636,6 +641,20 @@ def add_entity(self, fn: task.Entity[Any, Any], name: str | None = None) -> str: ) return self._registry.add_entity(fn, name) + def add_capability(self, capability: int) -> None: + """Advertise a worker capability to the backend in ``GetWorkItemsRequest``. + + Most users do not call this directly; feature-enablement helpers such as + :func:`durabletask.scheduled.configure_scheduled_tasks` use it to + advertise the capabilities (``pb.WORKER_CAPABILITY_*``) their feature + relies on. + """ + if self._is_running: + raise RuntimeError( + "Capabilities cannot be added while the worker is running." + ) + self._capabilities.add(capability) + def use_versioning(self, version: VersioningOptions) -> None: """Initializes versioning options for sub-orchestrators and activities.""" if self._is_running: @@ -861,6 +880,7 @@ def should_invalidate_connection(rpc_error: grpc.RpcError) -> bool: capabilities: list[Any] = [] if self._payload_store is not None: capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS) + capabilities.extend(sorted(self._capabilities)) get_work_items_request = pb.GetWorkItemsRequest( maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items, @@ -1742,12 +1762,13 @@ def signal_entity( self, entity_id: EntityInstanceId, operation_name: str, - input: Any = None + input: Any = None, + signal_time: datetime | None = None ) -> None: id = self.next_sequence_number() self.signal_entity_function_helper( - id, entity_id, operation_name, input + id, entity_id, operation_name, input, signal_time ) def lock_entities(self, entities: list[EntityInstanceId]) -> task.CompletableTask[EntityLock]: @@ -1918,7 
+1939,8 @@ def signal_entity_function_helper( id: int | None, entity_id: EntityInstanceId, operation: str, - input: Any = None + input: Any = None, + signal_time: datetime | None = None ) -> None: if id is None: id = self.next_sequence_number() @@ -1930,7 +1952,7 @@ def signal_entity_function_helper( encoded_input = self._data_converter.serialize(input) - action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid()) + action = ph.new_signal_entity_action(id, entity_id, operation, encoded_input, self.new_uuid(), signal_time) self._pending_actions[id] = action def lock_entities_function_helper(self, id: int | None, entities: list[EntityInstanceId]) -> None: diff --git a/examples/scheduled_tasks.py b/examples/scheduled_tasks.py new file mode 100644 index 00000000..841d3969 --- /dev/null +++ b/examples/scheduled_tasks.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""End-to-end sample showing how to create and manage a recurring schedule. + +The schedule periodically starts a target orchestration. This sample uses the +Durable Task Scheduler worker/client (compatible with the DTS emulator). +""" +import os +import time +from collections.abc import Generator +from datetime import datetime, timedelta, timezone +from typing import Any + +from azure.identity import DefaultAzureCredential + +from durabletask import task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.scheduled import (ScheduledTaskClient, ScheduleCreationOptions, + configure_scheduled_tasks) + + +def greet_orchestrator(ctx: task.OrchestrationContext, name: str) -> Generator[task.Task[Any], Any, Any]: + """The target orchestration that the schedule will start on each run.""" + yield ctx.create_timer(timedelta(seconds=1)) + return f"Hello, {name}!" 
+ + +# Use environment variables if provided, otherwise use default emulator values +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") + +print(f"Using taskhub: {taskhub_name}") +print(f"Using endpoint: {endpoint}") + +secure_channel = endpoint.startswith("https://") +credential = DefaultAzureCredential() if secure_channel else None + +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) as worker: + worker.add_orchestrator(greet_orchestrator) + # Register the schedule entity and operation orchestrator. + configure_scheduled_tasks(worker) + worker.start() + + client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=credential) + scheduled_tasks = ScheduledTaskClient(client) + + # Create a schedule that runs the greet orchestration every 5 seconds. + schedule = scheduled_tasks.create_schedule(ScheduleCreationOptions( + schedule_id="greet-every-5s", + orchestration_name=task.get_name(greet_orchestrator), + interval=timedelta(seconds=5), + orchestration_input="world", + start_at=datetime.now(timezone.utc), + start_immediately_if_late=True, + )) + + print(f"Created schedule '{schedule.schedule_id}'.") + print(f"Description: {scheduled_tasks.get_schedule(schedule.schedule_id)}") + + # Let it run for a bit. + time.sleep(12) + + # Pause, then resume the schedule. + schedule.pause() + print("Schedule paused.") + time.sleep(2) + schedule.resume() + print("Schedule resumed.") + + # List all schedules. + print(f"All schedules: {[s.schedule_id for s in scheduled_tasks.list_schedules()]}") + + # Clean up. 
+ schedule.delete() + print("Schedule deleted.") diff --git a/tests/durabletask-azuremanaged/entities/test_dts_delayed_signals_e2e.py b/tests/durabletask-azuremanaged/entities/test_dts_delayed_signals_e2e.py new file mode 100644 index 00000000..0033bd20 --- /dev/null +++ b/tests/durabletask-azuremanaged/entities/test_dts_delayed_signals_e2e.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Live DTS E2E tests for delayed (scheduled) entity signals. + +These tests assume a sidecar/emulator is running. Example command: + docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest +""" + +import os +import threading +import time +import uuid +from datetime import datetime, timedelta, timezone + +import pytest + +from durabletask import entities, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + +pytestmark = pytest.mark.dts + +taskhub_name = os.getenv("TASKHUB", "default") +endpoint = os.getenv("ENDPOINT", "http://localhost:8080") +secure_channel = endpoint.startswith("https://") + + +class _Recorder: + def __init__(self): + self._lock = threading.Lock() + self.times: list[datetime] = [] + + def record(self) -> None: + with self._lock: + self.times.append(datetime.now(timezone.utc)) + + @property + def count(self) -> int: + with self._lock: + return len(self.times) + + +recorder = _Recorder() + + +class Recorder(entities.DurableEntity): + def ping(self, _=None): + recorder.record() + + +def setup_function(_): + recorder.times.clear() + + +def _wait_until(predicate, timeout: float = 30, interval: float = 0.5) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + if predicate(): + return True + time.sleep(interval) + return False + + +def test_client_delayed_signal_is_deferred(): + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + 
taskhub=taskhub_name, token_credential=None) as w: + w.add_entity(Recorder) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=None) + entity_id = entities.EntityInstanceId("Recorder", f"client-delay-{uuid.uuid4().hex[:8]}") + sent_at = datetime.now(timezone.utc) + c.signal_entity(entity_id, "ping", signal_time=sent_at + timedelta(seconds=4)) + + time.sleep(1) + assert recorder.count == 0, "delayed signal fired too early" + + assert _wait_until(lambda: recorder.count >= 1) + elapsed = (recorder.times[0] - sent_at).total_seconds() + assert elapsed >= 3, f"signal fired too early ({elapsed}s)" + + +def test_orchestration_delayed_signal_is_deferred(): + def signaling_orchestrator(ctx: task.OrchestrationContext, key: str): + entity_id = entities.EntityInstanceId("Recorder", key) + ctx.signal_entity(entity_id, "ping", signal_time=ctx.current_utc_datetime + timedelta(seconds=4)) + + with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=None) as w: + w.add_orchestrator(signaling_orchestrator) + w.add_entity(Recorder) + w.start() + + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel, + taskhub=taskhub_name, token_credential=None) + key = f"orch-delay-{uuid.uuid4().hex[:8]}" + sent_at = datetime.now(timezone.utc) + instance_id = c.schedule_new_orchestration(signaling_orchestrator, input=key) + c.wait_for_orchestration_completion(instance_id, timeout=30) + + time.sleep(1) + assert recorder.count == 0, "delayed signal fired too early" + + assert _wait_until(lambda: recorder.count >= 1) + elapsed = (recorder.times[0] - sent_at).total_seconds() + assert elapsed >= 3, f"signal fired too early ({elapsed}s)" diff --git a/tests/durabletask-azuremanaged/scheduled/__init__.py b/tests/durabletask-azuremanaged/scheduled/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null 
+++ b/tests/durabletask-azuremanaged/scheduled/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
diff --git a/tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py b/tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py
new file mode 100644
index 00000000..06d3d76a
--- /dev/null
+++ b/tests/durabletask-azuremanaged/scheduled/test_dts_scheduled_e2e.py
@@ -0,0 +1,192 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""End-to-end tests for scheduled tasks against a live Durable Task Scheduler.
+
+These tests assume a sidecar/emulator is running. Example command:
+    docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
+"""
+
+import os
+import threading
+import time
+import uuid
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from durabletask import task
+from durabletask.azuremanaged.client import DurableTaskSchedulerClient
+from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
+from durabletask.scheduled import (ScheduledTaskClient, ScheduleCreationOptions,
+                                   ScheduleQuery, ScheduleStatus,
+                                   ScheduleUpdateOptions,
+                                   configure_scheduled_tasks)
+
+
+# NOTE: These tests assume a sidecar process is running. Example command:
+# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
+pytestmark = pytest.mark.dts
+
+taskhub_name = os.getenv("TASKHUB", "default")
+endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
+secure_channel = endpoint.startswith("https://")
+
+
+class _RunTracker:
+    def __init__(self):
+        self._lock = threading.Lock()
+        self.inputs: list[object] = []
+
+    def record(self, value: object) -> None:
+        with self._lock:
+            self.inputs.append(value)
+
+    @property
+    def count(self) -> int:
+        with self._lock:
+            return len(self.inputs)
+
+
+tracker = _RunTracker()
+
+
+def target_orchestrator(ctx: task.OrchestrationContext, value):
+    if not ctx.is_replaying:  # record each run once, not on every replay
+        tracker.record(value)
+    return value
+
+
+def _make_worker() -> DurableTaskSchedulerWorker:
+    w = DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=secure_channel,
+                                   taskhub=taskhub_name, token_credential=None)
+    w.add_orchestrator(target_orchestrator)
+    configure_scheduled_tasks(w)
+    return w
+
+
+def _make_client() -> DurableTaskSchedulerClient:
+    return DurableTaskSchedulerClient(host_address=endpoint, secure_channel=secure_channel,
+                                      taskhub=taskhub_name, token_credential=None)
+
+
+def _wait_until(predicate, timeout: float = 30, interval: float = 0.5) -> bool:
+    deadline = time.monotonic() + timeout  # monotonic: unaffected by wall-clock adjustments
+    while time.monotonic() < deadline:
+        if predicate():
+            return True
+        time.sleep(interval)
+    return False
+
+
+def setup_function(_):
+    tracker.inputs.clear()
+
+
+def test_create_describe_and_run():
+    schedule_id = f"sched-run-{uuid.uuid4().hex[:8]}"
+    with _make_worker() as w:
+        w.start()
+        c = _make_client()
+        scheduled = ScheduledTaskClient(c)
+        schedule = scheduled.create_schedule(ScheduleCreationOptions(
+            schedule_id=schedule_id,
+            orchestration_name=task.get_name(target_orchestrator),
+            interval=timedelta(seconds=1),
+            orchestration_input="hello",
+            start_at=datetime.now(timezone.utc),
+            start_immediately_if_late=True,
+        ))
+        try:
+            assert _wait_until(lambda: tracker.count >= 1), "target orchestration did not run"
+
+            description = schedule.describe()
+            assert description.schedule_id == schedule_id
+            assert description.status == ScheduleStatus.ACTIVE
+            assert description.last_run_at is not None
+        finally:
+            schedule.delete()
+
+
+def test_pause_and_resume():
+    schedule_id = f"sched-pause-{uuid.uuid4().hex[:8]}"
+    with _make_worker() as w:
+        w.start()
+        c = _make_client()
+        scheduled = ScheduledTaskClient(c)
+        schedule = scheduled.create_schedule(ScheduleCreationOptions(
+            schedule_id=schedule_id,
+            orchestration_name=task.get_name(target_orchestrator),
+            interval=timedelta(seconds=1),
+            start_at=datetime.now(timezone.utc),
+            start_immediately_if_late=True,
+        ))
+        try:
+            assert _wait_until(lambda: tracker.count >= 1)
+
+            schedule.pause()
+            assert schedule.describe().status == ScheduleStatus.PAUSED
+
+            time.sleep(3)
+            count_after_pause = tracker.count
+            time.sleep(3)
+            assert tracker.count == count_after_pause, "schedule kept running while paused"
+
+            schedule.resume()
+            assert schedule.describe().status == ScheduleStatus.ACTIVE
+            assert _wait_until(lambda: tracker.count > count_after_pause)
+        finally:
+            schedule.delete()
+
+
+def test_update_interval_and_input():
+    schedule_id = f"sched-update-{uuid.uuid4().hex[:8]}"
+    with _make_worker() as w:
+        w.start()
+        c = _make_client()
+        scheduled = ScheduledTaskClient(c)
+        schedule = scheduled.create_schedule(ScheduleCreationOptions(
+            schedule_id=schedule_id,
+            orchestration_name=task.get_name(target_orchestrator),
+            interval=timedelta(seconds=30),
+            orchestration_input="before",
+            start_at=datetime.now(timezone.utc) + timedelta(hours=1),
+        ))
+        try:
+            schedule.update(ScheduleUpdateOptions(
+                orchestration_input="after",
+                interval=timedelta(seconds=2),
+            ))
+
+            description = schedule.describe()
+            assert description.interval == timedelta(seconds=2)
+            assert description.orchestration_input == "after"
+        finally:
+            schedule.delete()
+
+
+def test_list_and_delete():
+    prefix = f"grp-{uuid.uuid4().hex[:8]}-"
+    with _make_worker() as w:
+        w.start()
+        c = _make_client()
+        scheduled = ScheduledTaskClient(c)
+        created = []
+        for i in range(3):
+            schedule_id = f"{prefix}{i}"
+            scheduled.create_schedule(ScheduleCreationOptions(
+                schedule_id=schedule_id,
+                orchestration_name=task.get_name(target_orchestrator),
+                interval=timedelta(seconds=30),
+                start_at=datetime.now(timezone.utc) + timedelta(hours=1),
+            ))
+            created.append(schedule_id)
+
+        try:
+            listed = scheduled.list_schedules(ScheduleQuery(schedule_id_prefix=prefix))
+            assert {s.schedule_id for s in listed} == set(created)
+        finally:
+            for schedule_id in created:
+                scheduled.get_schedule_client(schedule_id).delete()
+
+        assert _wait_until(lambda: scheduled.get_schedule(created[0]) is None)
diff --git a/tests/durabletask/entities/test_delayed_signals_e2e.py b/tests/durabletask/entities/test_delayed_signals_e2e.py
new file mode 100644
index 00000000..b759e5cd
--- /dev/null
+++ b/tests/durabletask/entities/test_delayed_signals_e2e.py
@@ -0,0 +1,121 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""E2E tests for delayed (scheduled) entity signals via the client and orchestrator."""
+
+import threading
+import time
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from durabletask import client, entities, task, worker
+from durabletask.testing import create_test_backend
+
+from tests.durabletask._port_utils import find_free_port
+
+PORT = find_free_port()
+HOST = f"localhost:{PORT}"
+
+
+@pytest.fixture(autouse=True)
+def backend():
+    b = create_test_backend(port=PORT)
+    yield b
+    b.stop()
+    b.reset()
+
+
+class _Recorder:
+    """Thread-safe recorder of operation invocation times."""
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self.times: list[datetime] = []
+
+    def record(self) -> None:
+        with self._lock:
+            self.times.append(datetime.now(timezone.utc))
+
+    @property
+    def count(self) -> int:
+        with self._lock:
+            return len(self.times)
+
+
+recorder = _Recorder()
+
+
+class Recorder(entities.DurableEntity):
+    def ping(self, _=None):
+        recorder.record()
+
+
+def setup_function(_):
+    recorder.times.clear()
+
+
+def _wait_until(predicate, timeout: float = 10, interval: float = 0.1) -> bool:
+    deadline = time.monotonic() + timeout  # monotonic: unaffected by wall-clock adjustments
+    while time.monotonic() < deadline:
+        if predicate():
+            return True
+        time.sleep(interval)
+    return False
+
+
+def test_client_delayed_signal_is_deferred():
+    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
+        w.add_entity(Recorder)
+        w.start()
+
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            entity_id = entities.EntityInstanceId("Recorder", "client-delay")
+            signal_at = datetime.now(timezone.utc) + timedelta(seconds=2)
+            sent_at = datetime.now(timezone.utc)
+            c.signal_entity(entity_id, "ping", signal_time=signal_at)
+
+            # Should not have fired immediately.
+            time.sleep(0.5)
+            assert recorder.count == 0, "delayed signal fired too early"
+
+            # Should fire around the scheduled time.
+            assert _wait_until(lambda: recorder.count >= 1)
+            elapsed = (recorder.times[0] - sent_at).total_seconds()
+            assert elapsed >= 1.5, f"signal fired too early ({elapsed}s)"
+
+
+def test_client_immediate_signal_still_works():
+    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
+        w.add_entity(Recorder)
+        w.start()
+
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            entity_id = entities.EntityInstanceId("Recorder", "client-now")
+            c.signal_entity(entity_id, "ping")
+            assert _wait_until(lambda: recorder.count >= 1)
+
+
+def test_orchestration_delayed_signal_is_deferred():
+    def signaling_orchestrator(ctx: task.OrchestrationContext, _):
+        entity_id = entities.EntityInstanceId("Recorder", "orch-delay")
+        signal_at = ctx.current_utc_datetime + timedelta(seconds=2)
+        ctx.signal_entity(entity_id, "ping", signal_time=signal_at)
+
+    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
+        w.add_orchestrator(signaling_orchestrator)
+        w.add_entity(Recorder)
+        w.start()
+
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            sent_at = datetime.now(timezone.utc)
+            instance_id = c.schedule_new_orchestration(signaling_orchestrator)
+            c.wait_for_orchestration_completion(instance_id, timeout=30)
+
+            # Orchestration completes immediately, but the signal should be deferred.
+            time.sleep(0.5)
+            assert recorder.count == 0, "delayed signal fired too early"
+
+            assert _wait_until(lambda: recorder.count >= 1)
+            elapsed = (recorder.times[0] - sent_at).total_seconds()
+            assert elapsed >= 1.5, f"signal fired too early ({elapsed}s)"
diff --git a/tests/durabletask/scheduled/__init__.py b/tests/durabletask/scheduled/__init__.py
new file mode 100644
index 00000000..59e481eb
--- /dev/null
+++ b/tests/durabletask/scheduled/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
diff --git a/tests/durabletask/scheduled/test_client_filter_and_capability.py b/tests/durabletask/scheduled/test_client_filter_and_capability.py
new file mode 100644
index 00000000..133de4eb
--- /dev/null
+++ b/tests/durabletask/scheduled/test_client_filter_and_capability.py
@@ -0,0 +1,85 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Unit tests for ScheduleClient filtering and scheduled-tasks worker capability."""
+
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+import durabletask.internal.orchestrator_service_pb2 as pb
+from durabletask.scheduled import configure_scheduled_tasks
+from durabletask.scheduled.client import ScheduledTaskClient
+from durabletask.scheduled.models import (ScheduleConfiguration,
+                                          ScheduleCreationOptions, ScheduleQuery,
+                                          ScheduleState)
+from durabletask.scheduled.schedule_status import ScheduleStatus
+from durabletask.worker import TaskHubGrpcWorker
+
+
+def _state_created_at(created_at: datetime) -> ScheduleState:
+    state = ScheduleState()
+    state.status = ScheduleStatus.ACTIVE
+    state.schedule_created_at = created_at
+    state.schedule_configuration = ScheduleConfiguration.from_create_options(
+        ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                interval=timedelta(seconds=5)))
+    return state
+
+
+class TestMatchesFilter:
+    def test_naive_bound_does_not_crash_against_aware_created_at(self):
+        state = _state_created_at(datetime(2026, 1, 15, tzinfo=timezone.utc))
+        # A naive bound is normalized by ScheduleQuery, so the comparison must
+        # not raise "can't compare offset-naive and offset-aware".
+        q = ScheduleQuery(created_from=datetime(2026, 1, 1, 0, 0, 0))
+        assert ScheduledTaskClient._matches_filter(state, q) is True
+
+    def test_created_from_is_exclusive(self):
+        # Bounds match .NET's exclusive semantics: a schedule created exactly at
+        # the bound is excluded.
+        boundary = datetime(2026, 1, 1, tzinfo=timezone.utc)
+        state = _state_created_at(boundary)
+        q = ScheduleQuery(created_from=boundary)
+        assert ScheduledTaskClient._matches_filter(state, q) is False
+
+    def test_created_to_is_exclusive(self):
+        boundary = datetime(2026, 1, 1, tzinfo=timezone.utc)
+        state = _state_created_at(boundary)
+        q = ScheduleQuery(created_to=boundary)
+        assert ScheduledTaskClient._matches_filter(state, q) is False
+
+    def test_inside_window_matches(self):
+        state = _state_created_at(datetime(2026, 1, 15, tzinfo=timezone.utc))
+        q = ScheduleQuery(
+            created_from=datetime(2026, 1, 1, tzinfo=timezone.utc),
+            created_to=datetime(2026, 2, 1, tzinfo=timezone.utc),
+        )
+        assert ScheduledTaskClient._matches_filter(state, q) is True
+
+    def test_outside_window_is_excluded(self):
+        state = _state_created_at(datetime(2026, 3, 1, tzinfo=timezone.utc))
+        q = ScheduleQuery(created_to=datetime(2026, 2, 1, tzinfo=timezone.utc))
+        assert ScheduledTaskClient._matches_filter(state, q) is False
+
+    def test_status_filter(self):
+        state = _state_created_at(datetime(2026, 1, 1, tzinfo=timezone.utc))
+        assert ScheduledTaskClient._matches_filter(state, ScheduleQuery(status=ScheduleStatus.ACTIVE)) is True
+        assert ScheduledTaskClient._matches_filter(state, ScheduleQuery(status=ScheduleStatus.PAUSED)) is False
+
+
+class TestScheduledTasksCapability:
+    def test_configure_advertises_scheduled_tasks_capability(self):
+        worker = TaskHubGrpcWorker()
+        configure_scheduled_tasks(worker)
+        assert pb.WORKER_CAPABILITY_SCHEDULED_TASKS in worker._capabilities  # pyright: ignore[reportPrivateUsage]
+
+    def test_capability_absent_by_default(self):
+        worker = TaskHubGrpcWorker()
+        assert pb.WORKER_CAPABILITY_SCHEDULED_TASKS not in worker._capabilities  # pyright: ignore[reportPrivateUsage]
+
+    def test_add_capability_rejected_while_running(self):
+        worker = TaskHubGrpcWorker()
+        worker._is_running = True  # pyright: ignore[reportPrivateUsage]
+        with pytest.raises(RuntimeError):
+            worker.add_capability(pb.WORKER_CAPABILITY_SCHEDULED_TASKS)
diff --git a/tests/durabletask/scheduled/test_models.py b/tests/durabletask/scheduled/test_models.py
new file mode 100644
index 00000000..7f1d09f0
--- /dev/null
+++ b/tests/durabletask/scheduled/test_models.py
@@ -0,0 +1,167 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Unit tests for scheduled tasks models, validation, and serialization."""
+
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from durabletask.serialization import JsonDataConverter
+from durabletask.scheduled.models import (ScheduleConfiguration,
+                                          ScheduleCreationOptions, ScheduleQuery,
+                                          ScheduleState, ScheduleUpdateOptions)
+from durabletask.scheduled.schedule_status import ScheduleStatus
+
+converter = JsonDataConverter()
+
+
+class TestCreationOptionsValidation:
+    def test_requires_schedule_id(self):
+        with pytest.raises(ValueError):
+            ScheduleCreationOptions(schedule_id="", orchestration_name="orch",
+                                    interval=timedelta(seconds=5))
+
+    def test_requires_orchestration_name(self):
+        with pytest.raises(ValueError):
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="",
+                                    interval=timedelta(seconds=5))
+
+    def test_interval_must_be_at_least_one_second(self):
+        with pytest.raises(ValueError):
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                    interval=timedelta(milliseconds=500))
+
+    def test_interval_must_be_positive(self):
+        with pytest.raises(ValueError):
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                    interval=timedelta(seconds=-1))
+
+    def test_valid_options(self):
+        options = ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                          interval=timedelta(seconds=30))
+        assert options.schedule_id == "s1"
+        assert options.interval == timedelta(seconds=30)
+
+
+class TestCreationOptionsSerialization:
+    def test_round_trip_through_json(self):
+        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
+        end = datetime(2026, 2, 1, tzinfo=timezone.utc)
+        options = ScheduleCreationOptions(
+            schedule_id="s1", orchestration_name="orch", interval=timedelta(minutes=5),
+            orchestration_input={"key": "value"}, orchestration_instance_id="inst-1",
+            start_at=start, end_at=end, start_immediately_if_late=True)
+
+        encoded = converter.serialize(options)
+        decoded = converter.deserialize(encoded, ScheduleCreationOptions)
+
+        assert decoded.schedule_id == "s1"
+        assert decoded.orchestration_name == "orch"
+        assert decoded.interval == timedelta(minutes=5)
+        assert decoded.orchestration_input == {"key": "value"}
+        assert decoded.orchestration_instance_id == "inst-1"
+        assert decoded.start_at == start
+        assert decoded.end_at == end
+        assert decoded.start_immediately_if_late is True
+
+
+class TestUpdateOptions:
+    def test_interval_validation(self):
+        with pytest.raises(ValueError):
+            ScheduleUpdateOptions(interval=timedelta(milliseconds=100))
+
+    def test_round_trip_through_json(self):
+        options = ScheduleUpdateOptions(orchestration_name="orch2", interval=timedelta(seconds=10))
+        decoded = converter.deserialize(converter.serialize(options), ScheduleUpdateOptions)
+        assert decoded.orchestration_name == "orch2"
+        assert decoded.interval == timedelta(seconds=10)
+        assert decoded.start_at is None
+
+
+class TestScheduleConfiguration:
+    def test_from_create_options_rejects_start_after_end(self):
+        options = ScheduleCreationOptions(
+            schedule_id="s1", orchestration_name="orch", interval=timedelta(seconds=5),
+            start_at=datetime(2026, 2, 1, tzinfo=timezone.utc),
+            end_at=datetime(2026, 1, 1, tzinfo=timezone.utc))
+        with pytest.raises(ValueError):
+            ScheduleConfiguration.from_create_options(options)
+
+    def test_update_returns_changed_fields(self):
+        config = ScheduleConfiguration.from_create_options(
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                    interval=timedelta(seconds=5)))
+        changed = config.update(ScheduleUpdateOptions(interval=timedelta(seconds=10),
+                                                      orchestration_name="orch2"))
+        assert changed == {"interval", "orchestration_name"}
+        assert config.interval == timedelta(seconds=10)
+        assert config.orchestration_name == "orch2"
+
+    def test_update_no_changes_returns_empty(self):
+        config = ScheduleConfiguration.from_create_options(
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                    interval=timedelta(seconds=5)))
+        changed = config.update(ScheduleUpdateOptions(orchestration_name="orch"))
+        assert changed == set()
+
+    def test_config_round_trip(self):
+        config = ScheduleConfiguration.from_create_options(
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                    interval=timedelta(seconds=5),
+                                    start_at=datetime(2026, 1, 1, tzinfo=timezone.utc)))
+        restored = converter.deserialize(converter.serialize(config), ScheduleConfiguration)
+        assert restored.schedule_id == "s1"
+        assert restored.interval == timedelta(seconds=5)
+        assert restored.start_at == datetime(2026, 1, 1, tzinfo=timezone.utc)
+
+
+class TestScheduleState:
+    def test_round_trip_and_description(self):
+        state = ScheduleState()
+        state.status = ScheduleStatus.ACTIVE
+        state.schedule_created_at = datetime(2026, 1, 1, tzinfo=timezone.utc)
+        state.schedule_configuration = ScheduleConfiguration.from_create_options(
+            ScheduleCreationOptions(schedule_id="s1", orchestration_name="orch",
+                                    interval=timedelta(seconds=5)))
+
+        # The nested ``ScheduleConfiguration`` round-trips automatically.
+        restored = converter.deserialize(converter.serialize(state), ScheduleState)
+        assert restored.status == ScheduleStatus.ACTIVE
+        assert restored.schedule_created_at == datetime(2026, 1, 1, tzinfo=timezone.utc)
+        assert restored.schedule_configuration is not None
+        assert restored.schedule_configuration.interval == timedelta(seconds=5)
+
+        description = restored.to_description()
+        assert description.schedule_id == "s1"
+        assert description.status == ScheduleStatus.ACTIVE
+        assert description.interval == timedelta(seconds=5)
+
+    def test_refresh_execution_token_changes_token(self):
+        state = ScheduleState()
+        original = state.execution_token
+        state.refresh_execution_token()
+        assert state.execution_token != original
+
+
+class TestScheduleQueryNormalization:
+    def test_naive_bounds_are_coerced_to_aware_utc(self):
+        q = ScheduleQuery(
+            created_from=datetime(2026, 1, 1, 0, 0, 0),
+            created_to=datetime(2026, 2, 1, 0, 0, 0),
+        )
+        assert q.created_from is not None and q.created_to is not None
+        assert q.created_from == datetime(2026, 1, 1, tzinfo=timezone.utc)
+        assert q.created_to == datetime(2026, 2, 1, tzinfo=timezone.utc)
+        assert q.created_from.tzinfo is timezone.utc
+        assert q.created_to.tzinfo is timezone.utc
+
+    def test_aware_bounds_are_preserved(self):
+        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
+        q = ScheduleQuery(created_from=start)
+        assert q.created_from == start
+
+    def test_none_bounds_stay_none(self):
+        q = ScheduleQuery()
+        assert q.created_from is None
+        assert q.created_to is None
diff --git a/tests/durabletask/scheduled/test_schedule_entity.py b/tests/durabletask/scheduled/test_schedule_entity.py
new file mode 100644
index 00000000..91733579
--- /dev/null
+++ b/tests/durabletask/scheduled/test_schedule_entity.py
@@ -0,0 +1,196 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Unit tests for the Schedule entity behavior, driven through _EntityExecutor."""
+
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import pytest
+
+import durabletask.internal.orchestrator_service_pb2 as pb
+from durabletask.entities import EntityInstanceId
+from durabletask.internal.entity_state_shim import StateShim
+from durabletask.scheduled.exceptions import ScheduleInvalidTransitionError
+from durabletask.scheduled.models import (ScheduleCreationOptions, ScheduleState,
+                                          ScheduleUpdateOptions)
+from durabletask.scheduled.schedule_entity import (ENTITY_NAME, Schedule)
+from durabletask.scheduled.schedule_status import ScheduleStatus
+from durabletask.serialization import JsonDataConverter
+from durabletask.worker import _EntityExecutor, _Registry
+
+SCHEDULE_ID = "sched-1"
+
+
+class Harness:
+    """Drives Schedule entity operations against a persistent in-memory state.
+
+    Mirrors the worker's entity-batch lifecycle: the ``StateShim`` holds the
+    serialized state between operations, so each ``run`` deserializes on read and
+    serializes on write through the data converter -- exactly the wire round-trip
+    a real entity experiences across batches.
+    """
+
+    def __init__(self):
+        registry = _Registry()
+        registry.add_entity(Schedule, ENTITY_NAME)
+        self.converter = JsonDataConverter()
+        self.executor = _EntityExecutor(registry, logging.getLogger("test"), self.converter)
+        self.shim = StateShim(None, self.converter)
+        self.entity_id = EntityInstanceId(ENTITY_NAME, SCHEDULE_ID)
+
+    def run(self, operation: str, input: Any = None) -> tuple[str | None, list[pb.OperationAction]]:
+        before = len(self.shim.get_operation_actions())  # offset: return only this call's actions
+        encoded = self.converter.serialize(input) if input is not None else None  # None = no input
+        result = self.executor.execute("orch-1", self.entity_id, operation, self.shim, encoded)
+        self.shim.commit()
+        actions = self.shim.get_operation_actions()[before:]
+        return result, actions
+
+    def state(self) -> ScheduleState | None:
+        """Reconstruct the typed state object, the way the entity itself would."""
+        return self.shim.get_state(ScheduleState)
+
+    @property
+    def current(self) -> ScheduleState:
+        """Like :meth:`state` but asserts the state exists (most operations)."""
+        state = self.state()
+        assert state is not None
+        return state
+
+    @property
+    def token(self) -> str:
+        return self.current.execution_token
+
+
+def _signal_actions(actions: list[pb.OperationAction]) -> list[pb.OperationAction]:
+    return [a for a in actions if a.HasField("sendSignal")]
+
+
+def _start_actions(actions: list[pb.OperationAction]) -> list[pb.OperationAction]:
+    return [a for a in actions if a.HasField("startNewOrchestration")]
+
+
+def _creation_options(**kwargs: Any) -> ScheduleCreationOptions:
+    base: dict[str, Any] = dict(
+        schedule_id=SCHEDULE_ID, orchestration_name="my_orch", interval=timedelta(seconds=30))
+    base.update(kwargs)  # caller-supplied overrides win over the defaults above
+    return ScheduleCreationOptions(**base)
+
+
+class TestCreate:
+    def test_create_activates_and_signals_run(self):
+        h = Harness()
+        _, actions = h.run("create_schedule", _creation_options())
+
+        state = h.current
+        assert state.status == ScheduleStatus.ACTIVE
+        assert state.schedule_created_at is not None
+        signals = _signal_actions(actions)
+        assert len(signals) == 1
+        assert signals[0].sendSignal.name == "run_schedule"
+        assert signals[0].sendSignal.instanceId == f"@{ENTITY_NAME}@{SCHEDULE_ID}"
+
+    def test_create_twice_updates_in_place(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+        first_token = h.token
+        h.run("create_schedule", _creation_options(interval=timedelta(seconds=60)))
+        # Re-creation refreshes the execution token.
+        assert h.token != first_token
+        assert h.current.status == ScheduleStatus.ACTIVE
+
+
+class TestPauseResume:
+    def test_pause_then_resume(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+
+        h.run("pause_schedule")
+        assert h.current.status == ScheduleStatus.PAUSED
+        assert h.current.next_run_at is None
+
+        _, actions = h.run("resume_schedule")
+        assert h.current.status == ScheduleStatus.ACTIVE
+        assert len(_signal_actions(actions)) == 1
+
+    def test_pause_when_not_active_raises(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+        h.run("pause_schedule")
+        with pytest.raises(ScheduleInvalidTransitionError):
+            h.run("pause_schedule")
+
+
+class TestUpdate:
+    def test_update_changes_config_and_resignals(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+        _, actions = h.run("update_schedule",
+                           ScheduleUpdateOptions(interval=timedelta(seconds=120)))
+        config = h.current.schedule_configuration
+        assert config is not None
+        assert config.interval == timedelta(seconds=120)
+        assert len(_signal_actions(actions)) == 1
+
+    def test_update_no_change_does_not_signal(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+        _, actions = h.run("update_schedule",
+                           ScheduleUpdateOptions(orchestration_name="my_orch"))
+        assert len(_signal_actions(actions)) == 0
+
+
+class TestRunSchedule:
+    def test_runs_orchestration_when_due_and_rearms(self):
+        h = Harness()
+        past = datetime.now(timezone.utc) - timedelta(hours=1)
+        h.run("create_schedule", _creation_options(start_at=past, start_immediately_if_late=True))
+
+        _, actions = h.run("run_schedule", h.token)
+
+        starts = _start_actions(actions)
+        assert len(starts) == 1
+        assert starts[0].startNewOrchestration.name == "my_orch"
+        assert h.current.last_run_at is not None
+
+        # Re-arm signal should carry a future scheduled time.
+        signals = _signal_actions(actions)
+        assert len(signals) == 1
+        assert signals[0].sendSignal.HasField("scheduledTime")
+
+    def test_ignores_stale_token(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+        _, actions = h.run("run_schedule", "stale-token")
+        assert len(_start_actions(actions)) == 0
+        assert len(_signal_actions(actions)) == 0
+
+    def test_future_start_does_not_run_yet(self):
+        h = Harness()
+        future = datetime.now(timezone.utc) + timedelta(days=1)
+        h.run("create_schedule", _creation_options(start_at=future))
+        _, actions = h.run("run_schedule", h.token)
+        assert len(_start_actions(actions)) == 0
+        # Still re-arms with a future scheduled signal.
+        signals = _signal_actions(actions)
+        assert len(signals) == 1
+        assert signals[0].sendSignal.HasField("scheduledTime")
+
+    def test_past_end_time_deletes(self):
+        h = Harness()
+        start = datetime.now(timezone.utc) - timedelta(hours=2)
+        end = datetime.now(timezone.utc) - timedelta(hours=1)
+        h.run("create_schedule", _creation_options(start_at=start, end_at=end))
+        _, actions = h.run("run_schedule", h.token)
+        delete_signals = [a for a in _signal_actions(actions) if a.sendSignal.name == "delete"]
+        assert len(delete_signals) == 1
+
+
+class TestDelete:
+    def test_delete_clears_state(self):
+        h = Harness()
+        h.run("create_schedule", _creation_options())
+        h.run("delete")
+        assert h.state() is None
diff --git a/tests/durabletask/scheduled/test_scheduled_e2e.py b/tests/durabletask/scheduled/test_scheduled_e2e.py
new file mode 100644
index 00000000..911cfa64
--- /dev/null
+++ b/tests/durabletask/scheduled/test_scheduled_e2e.py
@@ -0,0 +1,221 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""End-to-end tests for scheduled tasks against the in-memory backend."""
+
+import threading
+import time
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from durabletask import client, task, worker
+from durabletask.scheduled import (ScheduledTaskClient, ScheduleCreationOptions,
+                                   ScheduleQuery, ScheduleStatus,
+                                   ScheduleUpdateOptions,
+                                   configure_scheduled_tasks)
+from durabletask.testing import create_test_backend
+
+from tests.durabletask._port_utils import find_free_port
+
+PORT = find_free_port()
+HOST = f"localhost:{PORT}"
+
+
+@pytest.fixture(autouse=True)
+def backend():
+    """Create an in-memory backend for each test."""
+    b = create_test_backend(port=PORT)
+    yield b
+    b.stop()
+    b.reset()
+
+
+class _RunTracker:
+    """Thread-safe tracker that records each target orchestration run."""
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self.inputs: list[object] = []
+
+    def record(self, value: object) -> None:
+        with self._lock:
+            self.inputs.append(value)
+
+    @property
+    def count(self) -> int:
+        with self._lock:
+            return len(self.inputs)
+
+
+# A module-level tracker is used because orchestrators must be registered by
+# reference and run on worker threads.
+tracker = _RunTracker()
+
+
+def target_orchestrator(ctx: task.OrchestrationContext, value):
+    """The orchestration started on each schedule run."""
+    if not ctx.is_replaying:
+        tracker.record(value)
+    return value
+
+
+def _make_worker() -> worker.TaskHubGrpcWorker:
+    w = worker.TaskHubGrpcWorker(host_address=HOST)
+    w.add_orchestrator(target_orchestrator)
+    configure_scheduled_tasks(w)
+    return w
+
+
+def _wait_until(predicate, timeout: float = 15, interval: float = 0.2) -> bool:
+    deadline = time.monotonic() + timeout  # monotonic: unaffected by wall-clock adjustments
+    while time.monotonic() < deadline:
+        if predicate():
+            return True
+        time.sleep(interval)
+    return False
+
+
+def setup_function(_):
+    # Reset the shared tracker before each test.
+    tracker.inputs.clear()
+
+
+def test_create_describe_and_run():
+    with _make_worker() as w:
+        w.start()
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            scheduled = ScheduledTaskClient(c)
+            schedule = scheduled.create_schedule(ScheduleCreationOptions(
+                schedule_id="sched-run",
+                orchestration_name=task.get_name(target_orchestrator),
+                interval=timedelta(seconds=1),
+                orchestration_input="hello",
+                start_at=datetime.now(timezone.utc),
+                start_immediately_if_late=True,
+            ))
+
+            assert _wait_until(lambda: tracker.count >= 1), "target orchestration did not run"
+
+            description = schedule.describe()
+            assert description.schedule_id == "sched-run"
+            assert description.status == ScheduleStatus.ACTIVE
+            assert description.orchestration_name == task.get_name(target_orchestrator)
+            assert description.last_run_at is not None
+
+            schedule.delete()
+
+
+def test_get_nonexistent_returns_none():
+    with _make_worker() as w:
+        w.start()
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            scheduled = ScheduledTaskClient(c)
+            assert scheduled.get_schedule("does-not-exist") is None
+
+
+def test_pause_and_resume():
+    with _make_worker() as w:
+        w.start()
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            scheduled = ScheduledTaskClient(c)
+            schedule = scheduled.create_schedule(ScheduleCreationOptions(
+                schedule_id="sched-pause",
+                orchestration_name=task.get_name(target_orchestrator),
+                interval=timedelta(seconds=1),
+                start_at=datetime.now(timezone.utc),
+                start_immediately_if_late=True,
+            ))
+
+            assert _wait_until(lambda: tracker.count >= 1)
+
+            schedule.pause()
+            assert schedule.describe().status == ScheduleStatus.PAUSED
+
+            # After pausing, no further runs should occur.
+            time.sleep(2)
+            count_after_pause = tracker.count
+            time.sleep(2)
+            assert tracker.count == count_after_pause, "schedule kept running while paused"
+
+            schedule.resume()
+            assert schedule.describe().status == ScheduleStatus.ACTIVE
+            assert _wait_until(lambda: tracker.count > count_after_pause)
+
+            schedule.delete()
+
+
+def test_update_interval_and_input():
+    with _make_worker() as w:
+        w.start()
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            scheduled = ScheduledTaskClient(c)
+            schedule = scheduled.create_schedule(ScheduleCreationOptions(
+                schedule_id="sched-update",
+                orchestration_name=task.get_name(target_orchestrator),
+                interval=timedelta(seconds=30),
+                orchestration_input="before",
+                start_at=datetime.now(timezone.utc) + timedelta(hours=1),
+            ))
+
+            schedule.update(ScheduleUpdateOptions(
+                orchestration_input="after",
+                interval=timedelta(seconds=2),
+            ))
+
+            description = schedule.describe()
+            assert description.interval == timedelta(seconds=2)
+            assert description.orchestration_input == "after"
+
+            schedule.delete()
+
+
+def test_list_schedules_with_prefix_and_status():
+    with _make_worker() as w:
+        w.start()
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            scheduled = ScheduledTaskClient(c)
+            for i in range(3):
+                scheduled.create_schedule(ScheduleCreationOptions(
+                    schedule_id=f"group-a-{i}",
+                    orchestration_name=task.get_name(target_orchestrator),
+                    interval=timedelta(seconds=30),
+                    start_at=datetime.now(timezone.utc) + timedelta(hours=1),
+                ))
+            scheduled.create_schedule(ScheduleCreationOptions(
+                schedule_id="group-b-0",
+                orchestration_name=task.get_name(target_orchestrator),
+                interval=timedelta(seconds=30),
+                start_at=datetime.now(timezone.utc) + timedelta(hours=1),
+            ))
+
+            all_schedules = scheduled.list_schedules()
+            ids = {s.schedule_id for s in all_schedules}
+            assert {"group-a-0", "group-a-1", "group-a-2", "group-b-0"}.issubset(ids)
+
+            group_a = scheduled.list_schedules(ScheduleQuery(schedule_id_prefix="group-a-"))
+            assert {s.schedule_id for s in group_a} == {"group-a-0", "group-a-1", "group-a-2"}
+
+            active = scheduled.list_schedules(ScheduleQuery(status=ScheduleStatus.ACTIVE))
+            assert all(s.status == ScheduleStatus.ACTIVE for s in active)
+
+            for s in all_schedules:
+                scheduled.get_schedule_client(s.schedule_id).delete()
+
+
+def test_delete_removes_schedule():
+    with _make_worker() as w:
+        w.start()
+        with client.TaskHubGrpcClient(host_address=HOST) as c:
+            scheduled = ScheduledTaskClient(c)
+            schedule = scheduled.create_schedule(ScheduleCreationOptions(
+                schedule_id="sched-delete",
+                orchestration_name=task.get_name(target_orchestrator),
+                interval=timedelta(seconds=30),
+                start_at=datetime.now(timezone.utc) + timedelta(hours=1),
+            ))
+
+            assert scheduled.get_schedule("sched-delete") is not None
+
+            schedule.delete()
+            assert _wait_until(lambda: scheduled.get_schedule("sched-delete") is None)
diff --git a/tests/durabletask/scheduled/test_transitions.py b/tests/durabletask/scheduled/test_transitions.py
new file mode 100644
index 00000000..1a7f2c05
--- /dev/null
+++ b/tests/durabletask/scheduled/test_transitions.py
@@ -0,0 +1,58 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Unit tests for schedule state transition rules."""
+
+from durabletask.scheduled import transitions
+from durabletask.scheduled.schedule_status import ScheduleStatus
+
+
+class TestCreateTransitions:  # CREATE_SCHEDULE -> ACTIVE is expected valid from all three statuses below
+    def test_create_from_uninitialized_to_active_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.CREATE_SCHEDULE, ScheduleStatus.UNINITIALIZED, ScheduleStatus.ACTIVE)
+
+    def test_create_from_active_to_active_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.CREATE_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE)
+
+    def test_create_from_paused_to_active_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.CREATE_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.ACTIVE)
+
+
+class TestUpdateTransitions:  # UPDATE_SCHEDULE keeps the status; expected invalid from UNINITIALIZED
+    def test_update_active_to_active_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.UPDATE_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE)
+
+    def test_update_paused_to_paused_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.UPDATE_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.PAUSED)
+
+    def test_update_from_uninitialized_is_invalid(self):
+        assert not transitions.is_valid_transition(
+            transitions.UPDATE_SCHEDULE, ScheduleStatus.UNINITIALIZED, ScheduleStatus.UNINITIALIZED)
+
+
+class TestPauseResumeTransitions:  # pause/resume expected valid only from the opposite status
+    def test_pause_active_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.PAUSE_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.PAUSED)
+
+    def test_pause_paused_is_invalid(self):
+        assert not transitions.is_valid_transition(
+            transitions.PAUSE_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.PAUSED)
+
+    def test_resume_paused_is_valid(self):
+        assert transitions.is_valid_transition(
+            transitions.RESUME_SCHEDULE, ScheduleStatus.PAUSED, ScheduleStatus.ACTIVE)
+
+    def test_resume_active_is_invalid(self):
+        assert not transitions.is_valid_transition(
+            transitions.RESUME_SCHEDULE, ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE)
+
+
+def test_unknown_operation_is_invalid():
+    assert not transitions.is_valid_transition(
+        "unknown_op", ScheduleStatus.ACTIVE, ScheduleStatus.ACTIVE)
diff --git a/tests/durabletask/test_delayed_signal_normalization.py b/tests/durabletask/test_delayed_signal_normalization.py
new file mode 100644
index 00000000..c2d1cbfc
--- /dev/null
+++ b/tests/durabletask/test_delayed_signal_normalization.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Unit tests for delayed-signal datetime normalization and in-memory backend
+timer lifecycle (regression coverage for PR #160 review feedback)."""
+
+import time
+from datetime import datetime, timedelta, timezone
+
+import durabletask.internal.orchestrator_service_pb2 as pb
+from durabletask.entities import EntityInstanceId
+from durabletask.internal import helpers
+from durabletask.internal.client_helpers import build_signal_entity_req
+from durabletask.testing import create_test_backend
+
+from tests.durabletask._port_utils import find_free_port
+
+
+class TestSignalTimeNormalization:
+    def test_build_signal_entity_req_naive_matches_aware(self):
+        entity_id = EntityInstanceId("Recorder", "k")
+        naive = datetime(2030, 1, 1, 12, 0, 0)  # no tzinfo
+        aware = datetime(2030, 1, 1, 12, 0, 0, tzinfo=timezone.utc)  # same wall-clock fields, UTC
+
+        req_naive = build_signal_entity_req(entity_id, "ping", signal_time=naive)
+        req_aware = build_signal_entity_req(entity_id, "ping", signal_time=aware)
+
+        assert req_naive.scheduledTime.seconds == req_aware.scheduledTime.seconds  # whole seconds only
+
+    def test_new_signal_entity_action_naive_matches_aware(self):
+        entity_id = EntityInstanceId("Recorder", "k")
+        naive = datetime(2030, 1, 1, 12, 0, 0)  # no tzinfo
+        aware = datetime(2030, 1, 1, 12, 0, 0, tzinfo=timezone.utc)  # same wall-clock fields, UTC
+
+        a_naive = helpers.new_signal_entity_action(1, entity_id, "ping", None, "r1", naive)
+        a_aware = helpers.new_signal_entity_action(1, entity_id, "ping", None, "r2", aware)
+
+        naive_ts = a_naive.sendEntityMessage.entityOperationSignaled.scheduledTime
+        aware_ts = a_aware.sendEntityMessage.entityOperationSignaled.scheduledTime
+        assert naive_ts.seconds == aware_ts.seconds  # whole seconds only
+
+
+class TestDelayedTimerLifecycle:
+    def test_delayed_operation_does_not_fire_into_reset_backend(self):
+        backend = create_test_backend(port=find_free_port())
+        try:
+            entity_id = "@recorder@k"
+            event = pb.HistoryEvent(
+                eventId=-1,
+                entityOperationSignaled=pb.EntityOperationSignaledEvent(operation="ping"),
+            )
+            # Schedule the op to fire shortly in the future, then reset before it
+            # fires so the timer wakes into a reset (new-generation) backend.
+            fire_at = datetime.now(timezone.utc) + timedelta(seconds=0.3)
+            backend._schedule_delayed_entity_operation(  # pyright: ignore[reportPrivateUsage]
+                entity_id, event, fire_at)
+            backend.reset()
+
+            # Give the timer thread time to wake; it must observe the new
+            # generation and refuse to recreate state in the reset backend.
+            time.sleep(0.6)  # twice the 0.3 s delay used for fire_at
+            assert backend._entities == {}  # pyright: ignore[reportPrivateUsage]
+        finally:
+            backend.stop()
+            backend.reset()
diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py
index ae797c2f..02861dcd 100644
--- a/tests/durabletask/test_serialization.py
+++ b/tests/durabletask/test_serialization.py
@@ -12,7 +12,7 @@ from collections import namedtuple
 from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import List, Optional, Union, get_args
+from typing import Any, List, Optional, Union, get_args
 
 import pytest
 
 
@@ -422,6 +422,33 @@ def test_coerce_dict_values_recursively():
     assert isinstance(result["home"], Address)
 
 
+def test_coerce_bare_any_returns_value_unchanged():
+    # ``Any`` carries no type info; the parsed value is already the result.
+    value = {"k": "v"}
+    assert coerce_to_type(value, Any) is value  # identity, not just equality
+
+
+def test_coerce_optional_any_returns_value_unchanged():
+    # ``Any | None`` must pass the value through rather than raising on the
+    # ``isinstance(value, Any)`` check inside the union loop.
+    value = {"k": "v"}
+    assert coerce_to_type(value, Optional[Any]) == {"k": "v"}  # equality only here
+
+
+def test_coerce_dataclass_with_any_field_round_trips():
+    # A plain dataclass whose field is annotated ``Any | None`` must reconstruct
+    # with that field left as the raw parsed JSON, not crash.
+    @dataclass
+    class Envelope:
+        name: str
+        payload: Any | None = None
+
+    restored = from_json(to_json(Envelope("x", {"nested": [1, 2]})), Envelope)
+    assert isinstance(restored, Envelope)
+    assert restored.name == "x"
+    assert restored.payload == {"nested": [1, 2]}
+
+
+# ----- from_json converter hook (PR #154 follow-up) -----