diff --git a/README.md b/README.md index ba22a0e..eedc645 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,17 @@ failed run can be resumed by re-running the same command. | `1` | Completed with failures — see the printed report. | | `2` | Could not run (setup error or `--on-name-conflict=abort` collision). | +#### Compatibility + +`unstract clone` is capability-probed: each phase checks for its endpoint on the +source and target, and clones only what both orgs support. A capability missing on +either side is reported and skipped — the run never fails because of a version +difference. Cloning a newer source into an older target therefore drops the entity +types the target lacks (listed in the end-of-run report). + +- Run the source and target on the same (or a newer-target) Unstract build. +- Use `unstract-client >= 1.4.0`, the first release that ships `unstract clone`. + ## Questions and Feedback On Slack, [join great conversations](https://join-slack.unstract.com/) around LLMs, their ecosystem and leveraging them to automate the previously unautomatable! diff --git a/src/unstract/clone/client.py b/src/unstract/clone/client.py index 44ced51..ed8a70b 100644 --- a/src/unstract/clone/client.py +++ b/src/unstract/clone/client.py @@ -94,12 +94,12 @@ def _request( return resp.json() def get_post_schema(self, entity_path: str) -> frozenset[str]: - """Return the set of fields the backend's POST serializer accepts. + """Return the set of fields the backend's POST accepts. - Reads it from a DRF ``OPTIONS`` response (``actions.POST``) once - per path and caches the result. DRF ``SimpleMetadata`` already - excludes ``read_only`` fields from ``actions.POST``, so the - returned set is exactly the writable subset. + Reads it from an ``OPTIONS`` response (``actions.POST``) once per + path and caches the result. Read-only fields are already absent + from ``actions.POST``, so the returned set is exactly the writable + subset. """ cached = self._post_schema_cache.get(entity_path) if cached is not None: @@ -113,6 +113,21 @@ def get_post_schema(self, entity_path: str) -> frozenset[str]: self._post_schema_cache[entity_path] = writable return writable + def probe(self, path: str) -> bool: + """Capability probe: is this feature's route installed on this deployment? + + GET ``path`` and return True on 200, False on 404 (route absent = + feature not built into this deployment). Any other status / transport + error re-raises — a real failure must not look like "feature missing". + """ + try: + self._request("GET", path) + except PlatformAPIError as e: + if e.status_code == 404: + return False + raise + return True + # ----- org users & groups ----- def list_users(self) -> list[dict[str, Any]]: @@ -215,21 +230,20 @@ def list_custom_tools(self) -> list[dict[str, Any]]: return result if isinstance(result, list) else result.get("results", []) def get_custom_tool(self, tool_id: str) -> dict[str, Any]: - """Fetch a single prompt-studio project (full serializer). + """Fetch a single prompt-studio project. - Returns ``fields = "__all__"`` per ``CustomToolSerializer`` — - notably includes ``output`` (the default DocumentManager id the - FE binds to ``selectedDoc`` on load). + Notably includes ``output`` (the default document id the UI + selects on load). """ return self._request("GET", f"prompt-studio/{tool_id}/") def update_custom_tool(self, tool_id: str, body: dict[str, Any]) -> dict[str, Any]: """PATCH a prompt-studio project. Used to set ``output`` (the - default doc id) after the files phase populates DM rows.""" + default doc id) after the files phase uploads documents.""" return self._request("PATCH", f"prompt-studio/{tool_id}/", json=body) def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: - """List ProfileManager rows for a tool. + """List the adapter profiles for a tool. The clone reads this on the source only — to discover the default profile's adapter UUIDs so they can be remapped to @@ -238,6 +252,18 @@ def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: result = self._request("GET", f"prompt-studio/prompt-studio-profile/{tool_id}/") return result if isinstance(result, list) else result.get("results", []) + def list_prompts(self, tool_id: str) -> list[dict[str, Any]]: + """List a tool's prompts (``prompt_id`` + ``prompt_key`` per row). + + Used to map source prompt ids to the target prompts created by + ``import_project`` / ``sync_prompts`` (matched by ``prompt_key``), + so prompt-scoped cloud config can remap its FKs. + """ + result = self._request( + "GET", "prompt-studio/prompt/", params={"tool_id": tool_id} + ) + return result if isinstance(result, list) else result.get("results", []) + def export_project(self, tool_id: str) -> dict[str, Any]: """Export a prompt-studio project as a portable JSON blob. @@ -255,17 +281,15 @@ def import_project( ) -> dict[str, Any]: """Import a prompt-studio project from an export blob. - Backend creates the tool, builds the default ProfileManager from - the supplied target-org adapter ids, and imports all prompts in - one call. On name collision the backend silently uniquifies the - new tool's name — callers should pre-check via - ``list_custom_tools`` to avoid that. + Creates the tool, the default adapter profile from the supplied + target-org adapter ids, and all prompts in one call. On name + collision the new tool comes back with a uniquified name — callers + should pre-check via ``list_custom_tools`` to avoid that. - ``adapter_ids`` keys are the backend's form fields: - ``llm_adapter_id``, ``vector_db_adapter_id``, - ``embedding_adapter_id``, ``x2text_adapter_id``. All four - required to wire the profile; otherwise backend falls back to - a profile without adapters and flags ``needs_adapter_config``. + ``adapter_ids`` keys are the form fields ``llm_adapter_id``, + ``vector_db_adapter_id``, ``embedding_adapter_id``, + ``x2text_adapter_id``. All four required to wire the profile; + otherwise the response flags ``needs_adapter_config``. """ tool_name = export_data.get("tool_metadata", {}).get("tool_name") or "export" content = json_lib.dumps(export_data).encode() @@ -308,12 +332,11 @@ def sync_prompts( ) def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]: - """List DocumentManager rows for a tool. + """List a tool's prompt documents. Used by FilesPhase for target-side idempotency and source-side enumeration. Response items carry ``document_id``, - ``document_name``, and ``tool`` (per the serializer's - ``to_representation`` filter). + ``document_name``, and ``tool``. """ result = self._request( "GET", "prompt-studio/prompt-document/", params={"tool_id": tool_id} @@ -321,13 +344,13 @@ def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]: return result if isinstance(result, list) else result.get("results", []) def download_prompt_file(self, tool_id: str, document_id: str) -> dict[str, Any]: - """GET a Prompt Studio document by tool + DM row id. + """GET a Prompt Studio document by tool + document id. - ``fetch_contents_ide`` resolves the filename internally from the - DocumentManager row, so the SDK passes the ``document_id`` it - already has from ``list_prompt_documents`` rather than reposting - the filename. Returns ``{"data": ..., "mime_type": ...}`` — - PDFs base64, text/csv utf-8, Excel placeholder. + The endpoint resolves the filename from the document id, so the + SDK passes the ``document_id`` it already has from + ``list_prompt_documents`` rather than reposting the filename. + Returns ``{"data": ..., "mime_type": ...}`` — PDFs base64, + text/csv utf-8, Excel placeholder. """ return self._request( "GET", @@ -344,20 +367,19 @@ def upload_prompt_file( ) -> dict[str, Any]: """Upload a file into a target Prompt Studio tool. - Backend writes bytes to storage and creates a ``DocumentManager`` - row. The DM model has ``UniqueConstraint(document_name, tool)``, - so callers must pre-check via ``list_prompt_documents`` to avoid - an IntegrityError → 500 on re-runs. + Filenames are unique per tool, so callers must pre-check via + ``list_prompt_documents`` to avoid a duplicate-name error on + re-runs. """ files = {"file": (file_name, data, mime_type)} return self._request("POST", f"prompt-studio/file/{tool_id}", files=files) def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: - """Republish ``PromptStudioRegistry`` from the tool's current state. + """Republish the tool's registry entry from its current state. - Called after import/sync so the registry row reflects the - freshly landed prompts. Required for ToolInstancePhase to find - a target registry id to remap. + Called after import/sync so the registry reflects the freshly + landed prompts. Required for ToolInstancePhase to find a target + registry id to remap. """ return self._request( "POST", @@ -391,7 +413,7 @@ def create_workflow(self, payload: dict[str, Any]) -> dict[str, Any]: def list_registries( self, *, custom_tool: str | None = None ) -> list[dict[str, Any]]: - """List PromptStudioRegistry rows. The list endpoint returns nothing + """List prompt-studio registry rows. The list endpoint returns nothing unless a filter is supplied; pass ``custom_tool`` to look up the registry id for a given tool. """ @@ -414,17 +436,17 @@ def list_tool_instances( return result if isinstance(result, list) else result.get("results", []) def create_tool_instance(self, payload: dict[str, Any]) -> dict[str, Any]: - """Create a tool instance (max 1 per workflow). The backend overwrites - the ``metadata`` field with tool defaults — caller must PATCH after - create to transfer source metadata. + """Create a tool instance (max 1 per workflow). The created row comes + back with default ``metadata`` — caller must PATCH after create to + transfer source metadata. """ return self._request("POST", "tool_instance/", json=payload) def update_tool_instance_metadata( self, instance_id: str, metadata: dict[str, Any] ) -> dict[str, Any]: - """PATCH a tool instance's metadata. Backend resolves adapter names - in the payload to local UUIDs via ``update_instance_metadata``. + """PATCH a tool instance's metadata. Adapter names in the payload are + resolved to local UUIDs server-side. """ return self._request( "PATCH", f"tool_instance/{instance_id}/", json={"metadata": metadata} @@ -475,8 +497,8 @@ def get_pipeline(self, pipeline_id: str) -> dict[str, Any]: return self._request("GET", f"pipeline/{pipeline_id}/") def create_pipeline(self, payload: dict[str, Any]) -> dict[str, Any]: - """Create a pipeline. Backend force-sets ``active=True`` and auto-creates - a single active API key on the new pipeline. + """Create a pipeline. The new pipeline comes back active with a single + active API key auto-provisioned. """ return self._request("POST", "pipeline/", json=payload) @@ -533,3 +555,369 @@ def create_api_key(self, payload: dict[str, Any]) -> dict[str, Any]: and cannot be carried over from source. """ return self._request("POST", "api/keys/api/", json=payload) + + # ----- lookups (cloud-only) ----- + + def list_lookup_definitions(self) -> list[dict[str, Any]]: + """List lookup definitions in this org. Also the capability-probe path.""" + result = self._request("GET", "lookups/definitions/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def get_lookup_definition(self, lookup_id: str) -> dict[str, Any]: + """Fetch a lookup definition's detail. + + Detail inlines the draft content: ``prompt_template``, + ``draft_version_id``, ``input_vars``, and ``adapters`` (a dict with + ``llm`` / ``x2text`` adapter UUIDs, either possibly ``None``). + """ + return self._request("GET", f"lookups/definitions/{lookup_id}/") + + def create_lookup_definition(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a lookup definition. The new definition comes with an empty + DRAFT version and default adapters; populate it via the + draft/adapters/file endpoints below. + """ + return self._request("POST", "lookups/definitions/", json=payload) + + def update_lookup_draft_template( + self, lookup_id: str, prompt_template: str + ) -> dict[str, Any]: + """Set the draft version's prompt template.""" + return self._request( + "PATCH", + f"lookups/definitions/{lookup_id}/draft/", + json={"prompt_template": prompt_template}, + ) + + def update_lookup_draft_adapters( + self, lookup_id: str, adapters: dict[str, str] + ) -> dict[str, Any]: + """Set the draft version's LLM and/or X2Text adapters by target UUID. + + ``adapters`` may carry either or both of ``llm`` / ``x2text``; absent + keys leave the existing draft adapter untouched. + """ + return self._request( + "PATCH", + f"lookups/definitions/{lookup_id}/adapters/", + json=adapters, + ) + + def list_lookup_files(self, lookup_id: str) -> list[dict[str, Any]]: + """List a lookup's draft reference files (rows carry ``file_id``, + ``file_name``, ``file_size``). + """ + result = self._request("GET", f"lookups/definitions/{lookup_id}/files/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def download_lookup_file(self, lookup_id: str, file_id: str) -> bytes: + """Download a reference file's original bytes. + + Returns raw bytes — the content route serves an ``HttpResponse`` body + (not a JSON envelope), so this bypasses the JSON-decoding request path. + """ + url = self._url(f"lookups/definitions/{lookup_id}/files/{file_id}/content/") + logger.debug("GET %s", url) + resp = self._session.get(url, timeout=self.timeout, verify=self.verify) + if not 200 <= resp.status_code < 300: + raise PlatformAPIError( + f"GET lookups/definitions/{lookup_id}/files/{file_id}/content/ " + f"returned {resp.status_code}", + status_code=resp.status_code, + body=resp.text[:2000], + ) + return resp.content + + def upload_lookup_file( + self, lookup_id: str, file_name: str, data: bytes, mime_type: str + ) -> dict[str, Any]: + """Upload a reference file into a lookup's draft version. + + Filenames are unique per draft version, so callers pre-check via + ``list_lookup_files`` to avoid a 409. + """ + files = {"file": (file_name, data, mime_type)} + return self._request( + "POST", f"lookups/definitions/{lookup_id}/files/", files=files + ) + + def list_lookup_assignments(self) -> list[dict[str, Any]]: + """List prompt-lookup assignment rows in this org. + + Each row carries ``assignment_id``, ``prompt`` (source prompt uuid), + ``version`` (source lookup-version uuid), ``lookup_definition`` + (source lookup_id), ``is_draft_version``, and ``variable_mappings``. + """ + result = self._request("GET", "lookups/assignments/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_lookup_assignment(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a prompt-lookup assignment. + + Writable: ``prompt``, ``lookup_definition`` (required), ``version``, + ``variable_mappings``. At most one assignment per prompt, so callers + pre-check target assignments. + """ + return self._request("POST", "lookups/assignments/", json=payload) + + def update_lookup_share( + self, lookup_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + """Replicate share state onto a lookup via its detail PATCH. + + ``payload`` carries ``shared_to_org`` + ``shared_users`` (target user + pks). Lookups expose no group-sharing axis, so no ``shared_groups``. + """ + return self._request( + "PATCH", f"lookups/definitions/{lookup_id}/", json=payload + ) + + def list_lookup_versions(self, lookup_id: str) -> list[dict[str, Any]]: + """List a lookup's versions (draft + published). + + Rows carry ``version_id``, ``is_draft``, ``version_number``, + ``version_name``; the detail (``get_lookup_version``) inlines content. + """ + result = self._request( + "GET", f"lookups/definitions/{lookup_id}/versions/" + ) + if isinstance(result, list): + return result + # This endpoint wraps rows as {"versions": [...], "next_version_number"}. + return (result or {}).get("versions", (result or {}).get("results", [])) + + def get_lookup_version( + self, lookup_id: str, version_id: str + ) -> dict[str, Any]: + """Fetch a version's detail (``prompt_template``, adapters, files).""" + return self._request( + "GET", f"lookups/definitions/{lookup_id}/versions/{version_id}/" + ) + + def download_lookup_version_file( + self, lookup_id: str, version_id: str, file_id: str + ) -> bytes: + """Download a published version's reference-file bytes (raw body).""" + path = ( + f"lookups/definitions/{lookup_id}/versions/{version_id}/" + f"files/{file_id}/content/" + ) + url = self._url(path) + logger.debug("GET %s", url) + resp = self._session.get(url, timeout=self.timeout, verify=self.verify) + if not 200 <= resp.status_code < 300: + raise PlatformAPIError( + f"GET {path} returned {resp.status_code}", + status_code=resp.status_code, + body=resp.text[:2000], + ) + return resp.content + + def publish_lookup_version( + self, lookup_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + """Freeze the current draft into a published version. + + ``payload`` carries ``version_name`` (+ optional ``rebind_assignments``). + Returns the new published version (``version_id``). Used to replay a + source lookup's published-version history onto the target. + """ + return self._request( + "POST", f"lookups/definitions/{lookup_id}/versions/", json=payload + ) + + # ----- manual review / HITL (cloud-only) ----- + # + # Each workflow holds one review-rule row per ``rule_type`` (DB / API) + # and one HITL-settings row. The workflow-scoped GET routes take the + # workflow id in the URL path and wrap the row in ``{"data": ...}``; + # they 404 (rules) / 500 (settings) when none exists — callers treat a + # missing row as "nothing to clone", not an error. + + MR_RULE_TYPES: tuple[str, ...] = ("DB", "API") + + def get_review_rule( + self, workflow_id: str, rule_type: str + ) -> dict[str, Any] | None: + """Fetch a workflow's review rule for one ``rule_type``. + + Returns the rule dict (with nested ``confidence_filters``) or ``None`` + when no rule of that type exists (backend answers 404). + """ + try: + body = self._request( + "GET", + f"manual_review/rule_engine/workflow/{workflow_id}/", + params={"rule_type": rule_type}, + ) + except PlatformAPIError as e: + if e.status_code == 404: + return None + raise + return (body or {}).get("data") + + def create_review_rule(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a review rule (+ nested ``confidence_filters``). + + Writable: ``workflow`` (required), ``rule_type``, ``percentage``, + ``rule_string``, ``rule_json``, ``rule_logic``, ``confidence_filters``. + Unique per workflow + ``rule_type`` within the org. + """ + return self._request("POST", "manual_review/rule_engine/", json=payload) + + def get_review_settings(self, workflow_id: str) -> dict[str, Any] | None: + """Fetch a workflow's review settings, or ``None`` if absent. + + The route answers 500 (not 404) when no row exists, so only a 500 is + treated as "no settings to clone". Other errors (401/403/429) must + surface — suppressing them would silently drop configured settings. + """ + try: + body = self._request( + "GET", f"manual_review/settings/workflow/{workflow_id}/" + ) + except PlatformAPIError as e: + if e.status_code == 500: + return None + raise + return (body or {}).get("data") + + def create_review_settings(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a review settings row. + + Writable: ``workflow`` (one per workflow, required), ``sync_with``, + ``ttl_hours``. + """ + return self._request("POST", "manual_review/settings/", json=payload) + + def list_auto_approval_settings(self) -> list[dict[str, Any]]: + """List org-level auto-approval settings (0 or 1 per org). + + Returns 200 bare with no query params, so it doubles as the + manual-review capability probe path. + """ + result = self._request("GET", "manual_review/auto_approval_settings/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_auto_approval_settings( + self, payload: dict[str, Any] + ) -> dict[str, Any]: + """Create org-level auto-approval settings. + + Writable: ``auto_approved_document_classes``, ``auto_approved_users``. + ``organization`` is server-set. Unique per org. + """ + return self._request( + "POST", "manual_review/auto_approval_settings/", json=payload + ) + + def list_review_api_keys(self) -> list[dict[str, Any]]: + """List review API keys in this org.""" + result = self._request("GET", "manual_review/api/keys/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_review_api_key(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a review API key. The ``api_key`` secret is server-minted + and cannot be carried over from source. + Writable: ``class_name``, ``description``, ``is_active``. + """ + return self._request("POST", "manual_review/api/key/", json=payload) + + # ----- agentic studio (cloud-only) ----- + + def list_agentic_projects(self) -> list[dict[str, Any]]: + """List agentic projects in this org. Also the capability-probe path. + + Rows carry ``id``, ``name``, ``description``, the four adapter FK ids + (``llm_connector_id`` / ``agent_llm_connector_id`` / + ``lightweight_llm_connector_id`` / ``text_extractor_connector_id``), + and ``canary_fields``. + """ + result = self._request("GET", "agentic/projects/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_agentic_project(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an agentic project. Returns the created row (carries ``id``).""" + return self._request("POST", "agentic/projects/", json=payload) + + def list_agentic_prompt_versions( + self, *, project_id: str | None = None + ) -> list[dict[str, Any]]: + """List agentic prompt versions, optionally scoped to a project. + + Rows carry ``id``, ``project``, ``version``, ``prompt_text``, + ``accuracy``, ``is_active``, and the self-FK ``parent_version``. + """ + params: dict[str, Any] = {} + if project_id is not None: + params["project_id"] = project_id + result = self._request("GET", "agentic/prompt-versions/", params=params) + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_agentic_prompt_version(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an agentic prompt version (flat endpoint, ``project`` in body).""" + return self._request("POST", "agentic/prompt-versions/", json=payload) + + def list_agentic_schemas( + self, *, project_id: str | None = None + ) -> list[dict[str, Any]]: + """List agentic schemas, optionally scoped to a project. + + Rows carry ``id``, ``project``, ``json_schema``, ``version``, + ``is_active``. + """ + params: dict[str, Any] = {} + if project_id is not None: + params["project_id"] = project_id + result = self._request("GET", "agentic/schemas/", params=params) + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_agentic_schema(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an agentic schema (flat endpoint, ``project`` in body).""" + return self._request("POST", "agentic/schemas/", json=payload) + + def list_agentic_settings(self) -> list[dict[str, Any]]: + """List agentic settings. Org-wide key/value rows (no project FK).""" + result = self._request("GET", "agentic/settings/") + return result if isinstance(result, list) else (result or {}).get("results", []) + + def create_agentic_setting(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an org-wide agentic setting.""" + return self._request("POST", "agentic/settings/", json=payload) + + def update_agentic_setting( + self, setting_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + """PATCH an existing agentic setting by id.""" + return self._request("PATCH", f"agentic/settings/{setting_id}/", json=payload) + + def export_agentic_project(self, project_id: str, *, force: bool = True) -> Any: + """Republish the project's registry entry from its active schema + + prompt. Mirror of ``export_custom_tool``. + + Requires an active schema and active prompt; ``force_export`` + bypasses the completion check. Caller re-reads the registry to learn + the new id. + """ + return self._request( + "POST", + f"agentic/projects/{project_id}/export/", + json={ + "is_shared_with_org": False, + "user_ids": [], + "force_export": force, + }, + ) + + def list_agentic_registries( + self, *, agentic_project: str | None = None + ) -> list[dict[str, Any]]: + """List agentic registry rows. The list endpoint returns nothing + unless a filter is supplied; pass ``agentic_project`` to look up the + registry id for a given project. + """ + params: dict[str, Any] = {} + if agentic_project is not None: + params["agentic_project"] = agentic_project + result = self._request("GET", "agentic-studio-registry/", params=params) + return result if isinstance(result, list) else (result or {}).get("results", []) diff --git a/src/unstract/clone/context.py b/src/unstract/clone/context.py index 238e837..aeef4c3 100644 --- a/src/unstract/clone/context.py +++ b/src/unstract/clone/context.py @@ -145,3 +145,20 @@ class CloneContext: # touches them once per endpoint, never per resource). share_cache: dict[str, Any] = field(default_factory=dict) share_cache_lock: threading.Lock = field(default_factory=threading.Lock) + # Capability-probe memo: (id(client), feature_path) -> present?. Probed + # once per (deployment, feature) so cloud-phase gating costs one GET total. + probe_cache: dict[tuple[int, str], bool] = field(default_factory=dict) + + def feature_present(self, client: "PlatformClient", path: str) -> bool: + """Is ``path`` (a feature's list endpoint) installed on ``client``? + + Memoised per run. Plain dict, no lock — probing runs in the + single-threaded orchestrator loop, before any parallel_map fan-out. + """ + key = (id(client), path) + cached = self.probe_cache.get(key) + if cached is not None: + return cached + present = client.probe(path) + self.probe_cache[key] = present + return present diff --git a/src/unstract/clone/orchestrator.py b/src/unstract/clone/orchestrator.py index 85df6ea..c516254 100644 --- a/src/unstract/clone/orchestrator.py +++ b/src/unstract/clone/orchestrator.py @@ -19,11 +19,14 @@ from unstract.clone.exceptions import CloneError from unstract.clone.phases import ( AdapterPhase, + AgenticStudioPhase, APIDeploymentPhase, ConnectorPhase, CustomToolPhase, FilesPhase, GroupPhase, + LookupsPhase, + ManualReviewPhase, PipelinePhase, TagPhase, ToolInstancePhase, @@ -46,11 +49,20 @@ PHASES: list[tuple[str, type[Phase]]] = [ ("group", GroupPhase), ("adapter", AdapterPhase), + # Cloud-only; standalone (own project + registry) and FKs four adapters. + # Probe-gated: auto-skips on OSS deployments via ``probe_path``. + ("agentic_studio", AgenticStudioPhase), ("connector", ConnectorPhase), ("tag", TagPhase), ("custom_tool", CustomToolPhase), ("files", FilesPhase), + # Cloud-only; consumes custom_tool's prompt + adapter remaps. Probe-gated: + # auto-skips on OSS deployments via ``probe_path``. + ("lookups", LookupsPhase), ("workflow", WorkflowPhase), + # Cloud-only; review rules and settings bind to the workflow. + # Probe-gated: auto-skips on OSS deployments via ``probe_path``. + ("manual_review", ManualReviewPhase), ("tool_instance", ToolInstancePhase), ("workflow_endpoint", WorkflowEndpointPhase), ("pipeline", PipelinePhase), @@ -58,6 +70,37 @@ ] +def _cloud_phase_runnable( + ctx: CloneContext, report: CloneReport, name: str, probe_path: str +) -> bool: + """Decide whether a cloud-only phase should run on this deployment pair. + + Probe source first; only probe target if source has the feature. A probe + failure (unexpected status / transport) must not abort an otherwise-fine + run — treat it like target-absent: warn + skip, never raise. + """ + try: + if not ctx.feature_present(ctx.source, probe_path): + # OSS source: behave exactly as if this phase didn't exist. + logger.debug("Phase '%s' skipped: feature absent on source", name) + return False + target_present = ctx.feature_present(ctx.target, probe_path) + except Exception as e: + msg = f"Phase '{name}' skipped: capability probe failed ({e})" + logger.warning(msg) + report.warnings.append(msg) + return False + if not target_present: + msg = ( + f"Phase '{name}' skipped: feature present on source but not on " + "target deployment" + ) + logger.warning(msg) + report.warnings.append(msg) + return False + return True + + def clone( source: OrgEndpoint, target: OrgEndpoint, @@ -93,6 +136,11 @@ def clone( report.skipped_phases.append(name) logger.info("Phase '%s' skipped (excluded)", name) continue + probe_path = getattr(phase_cls, "probe_path", None) + if probe_path is not None and not _cloud_phase_runnable( + ctx, report, name, probe_path + ): + continue logger.info("=== Phase: %s ===", name) phase_started = time.perf_counter() try: diff --git a/src/unstract/clone/phases/__init__.py b/src/unstract/clone/phases/__init__.py index 0c3a9a6..ffff94f 100644 --- a/src/unstract/clone/phases/__init__.py +++ b/src/unstract/clone/phases/__init__.py @@ -8,12 +8,15 @@ """ from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.phases.agentic_studio import AgenticStudioPhase from unstract.clone.phases.api_deployment import APIDeploymentPhase from unstract.clone.phases.base import Phase from unstract.clone.phases.connector import ConnectorPhase from unstract.clone.phases.custom_tool import CustomToolPhase from unstract.clone.phases.files import FilesPhase from unstract.clone.phases.group import GroupPhase +from unstract.clone.phases.lookups import LookupsPhase +from unstract.clone.phases.manual_review import ManualReviewPhase from unstract.clone.phases.pipeline import PipelinePhase from unstract.clone.phases.tag import TagPhase from unstract.clone.phases.tool_instance import ToolInstancePhase @@ -23,10 +26,13 @@ __all__ = [ "APIDeploymentPhase", "AdapterPhase", + "AgenticStudioPhase", "ConnectorPhase", "CustomToolPhase", "FilesPhase", "GroupPhase", + "LookupsPhase", + "ManualReviewPhase", "Phase", "PipelinePhase", "TagPhase", diff --git a/src/unstract/clone/phases/adapter.py b/src/unstract/clone/phases/adapter.py index 873fa13..6717cec 100644 --- a/src/unstract/clone/phases/adapter.py +++ b/src/unstract/clone/phases/adapter.py @@ -4,9 +4,8 @@ against target, POST create if missing, record source->target UUID in the remap table for downstream phases. -Frictionless onboarding adapters are excluded — the backend's -service-account queryset already filters them out, so clone never -sees them. +Frictionless onboarding adapters are excluded — they're already filtered +out of what this org's Platform key can list, so clone never sees them. """ from __future__ import annotations diff --git a/src/unstract/clone/phases/agentic_studio.py b/src/unstract/clone/phases/agentic_studio.py new file mode 100644 index 0000000..6f51311 --- /dev/null +++ b/src/unstract/clone/phases/agentic_studio.py @@ -0,0 +1,587 @@ +"""Migrate cloud-only Agentic ("Agentic Prompt Studio") projects + children. + +Cloud-only: gated by ``probe_path`` — the orchestrator probes +``agentic/projects/`` on source/target and skips the phase entirely on an +OSS deployment. Runs after ``adapter`` (a project FKs four adapters). + +Standalone: owns its own project + registry, independent of ``custom_tool``. + +Per source project: + +1. Adopt-by-name if the target already has a project with the same ``name`` + (respecting ``on_name_conflict``), else create fresh from the OPTIONS + schema, remapping the four adapter FKs + (``llm_connector_id`` / ``agent_llm_connector_id`` / + ``lightweight_llm_connector_id`` / ``text_extractor_connector_id``) via the + ``adapter`` remap. An adapter that doesn't resolve is omitted with a warning + (the slot stays unset, like a draft lookup adapter). Records an + ``agentic_project`` remap. +2. In dependency order under the target project: + - **prompt-versions**: ``parent_version`` is a self-FK, so parents (whose + ``parent_version`` is ``None``) clone first; each new id is recorded in a + per-run ``agentic_prompt_version`` table and the child's ``parent_version`` + resolves through it. ``project`` is bound to the target id. + - **schemas**: bound to the target ``project`` and created. +3. **Registry**: if the project has an active schema + prompt, republish its + registry entry via the ``export`` action (mirror of custom_tool) and + record an ``agentic_studio_registry`` remap. Projects with no source + registry are left unexported. + +Org-level agentic-setting rows (global key/value, no project FK) are cloned +once as a flat adopt-by-key / create pass — they are org singletons, not +per-project config. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.sharing import apply_share_state + +logger = logging.getLogger(__name__) + +AGENTIC_PROJECTS_PATH = "agentic/projects/" + +# Project adapter FK slots: serializer field name -> ``adapter`` remap source. +_ADAPTER_SLOTS: tuple[str, ...] = ( + "llm_connector_id", + "agent_llm_connector_id", + "lightweight_llm_connector_id", + "text_extractor_connector_id", +) + + +class AgenticStudioPhase(Phase): + name = "agentic_studio" + probe_path = AGENTIC_PROJECTS_PATH + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(AGENTIC_PROJECTS_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for agentic: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS agentic projects: {e}") + return result + try: + src_projects = self.ctx.source.list_agentic_projects() + except Exception as e: + logger.exception("Failed to list source agentic projects: %s", e) + result.failed += 1 + result.errors.append(f"list source agentic projects: {e}") + return result + try: + tgt_projects = self.ctx.target.list_agentic_projects() + except Exception as e: + logger.exception("Failed to list target agentic projects: %s", e) + result.failed += 1 + result.errors.append(f"list target agentic projects: {e}") + return result + + logger.info("Found %d agentic project(s) in source org", len(src_projects)) + # Updated under lock on fresh create so same-name source rows adopt. + target_by_name: dict[str, dict[str, Any]] = { + p["name"]: p for p in tgt_projects + } + + self.parallel_map( + src_projects, + lambda src, lock: self._clone_project(src, target_by_name, result, lock), + ) + + # Org-global settings: a flat pass, not tied to any project. + self._clone_settings(result) + return result + + # ----- projects ----- + + def _clone_project( + self, + src: dict[str, Any], + target_by_name: dict[str, dict[str, Any]], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + name = src["name"] + src_project_id = src["id"] + + with lock: + match = target_by_name.get(name) + + if match is not None: + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"agentic project '{name}' already exists in target " + f"as {match['id']}" + ) + tgt_project_id = match["id"] + with lock: + result.adopted += 1 + self.ctx.remap.record("agentic_project", src_project_id, tgt_project_id) + logger.info( + "adopted agentic project '%s' src=%s -> tgt=%s", + name, + src_project_id, + tgt_project_id, + ) + elif self.ctx.options.dry_run: + with lock: + result.created += 1 + self.ctx.remap.record_planned("agentic_project", src_project_id) + logger.info( + "[dry-run] would create agentic project '%s' src=%s", + name, + src_project_id, + ) + # Plan child ids so downstream plan-counts stay consistent. + self._plan_children(src_project_id, result, lock) + return + else: + payload = self._build_project_payload(src, name, result, lock) + try: + tgt = self.ctx.target.create_agentic_project(payload) + except Exception as e: + logger.exception("Failed to create agentic project '%s': %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create agentic project {name}: {e}") + return + tgt_project_id = tgt["id"] + with lock: + result.created += 1 + target_by_name[name] = {"id": tgt_project_id, "name": name} + self.ctx.remap.record("agentic_project", src_project_id, tgt_project_id) + logger.info( + "created agentic project '%s' src=%s -> tgt=%s", + name, + src_project_id, + tgt_project_id, + ) + + # Children + registry + share write to the real target only. + if self.ctx.options.dry_run: + return + self._replicate_share(src, name, tgt_project_id, result, lock) + self._clone_prompt_versions(name, src_project_id, tgt_project_id, result, lock) + self._clone_schemas(name, src_project_id, tgt_project_id, result, lock) + self._republish_registry(name, src_project_id, tgt_project_id, result, lock) + + def _build_project_payload( + self, + src: dict[str, Any], + name: str, + result: PhaseResult, + lock: threading.Lock, + ) -> dict[str, Any]: + payload = build_post_payload(src, self._writable) + # Remap each adapter FK; omit any slot that doesn't resolve (leaves it + # unset — the operator wires it on target later). + for slot in _ADAPTER_SLOTS: + src_adapter_id = src.get(slot) + if not src_adapter_id: + payload.pop(slot, None) + continue + tgt_adapter_id = self.ctx.remap.resolve("adapter", str(src_adapter_id)) + if tgt_adapter_id is None: + payload.pop(slot, None) + logger.warning( + "agentic project '%s': %s adapter %s has no target mapping — " + "leaving it unset", + name, + slot, + src_adapter_id, + ) + with lock: + result.warnings.append( + f"agentic project {name}: {slot} adapter not remapped — " + "left unset" + ) + continue + payload[slot] = tgt_adapter_id + return payload + + # ----- share state ----- + + def _replicate_share( + self, + src: dict[str, Any], + name: str, + tgt_project_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + # The list row already carries shared_users / shared_groups / + # shared_to_org, so no detail fetch needed. Share axes are + # write-protected on the detail endpoint (a detail PATCH is a silent + # no-op); they're written via the dedicated share action, which handles + # the group axis too — so groups replicate like every other shared + # resource. + apply_share_state( + self.ctx, + share_path=f"agentic/projects/{tgt_project_id}/share/", + entity_label=f"agentic project '{name}'", + src=src, + result=result, + lock=lock, + ) + + # ----- prompt versions ----- + + def _clone_prompt_versions( + self, + name: str, + src_project_id: str, + tgt_project_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + src_versions = self.ctx.source.list_agentic_prompt_versions( + project_id=src_project_id + ) + except Exception as e: + logger.exception("agentic '%s': prompt-version listing failed: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"agentic {name} list prompt-versions: {e}") + return + + # Adopt versions already on target (keyed by version number) so a + # re-run against the same pair doesn't re-create duplicates. + try: + tgt_versions = self.ctx.target.list_agentic_prompt_versions( + project_id=tgt_project_id + ) + except Exception as e: + logger.warning( + "agentic '%s': target prompt-version listing failed " + "(re-run may duplicate): %s", + name, + e, + ) + tgt_versions = [] + tgt_by_version = {v.get("version"): v for v in tgt_versions} + + # parent_version is a self-FK: clone roots (no parent) first so a child's + # parent already resolves. Sort by version ascending as a stable order. + ordered = sorted( + src_versions, + key=lambda v: (v.get("parent_version") is not None, v.get("version") or 0), + ) + for src in ordered: + self._clone_one_prompt_version( + name, src, tgt_project_id, tgt_by_version, result, lock + ) + + def _clone_one_prompt_version( + self, + name: str, + src: dict[str, Any], + tgt_project_id: str, + tgt_by_version: dict[Any, dict[str, Any]], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + src_vid = src["id"] + + existing = tgt_by_version.get(src.get("version")) + if existing is not None: + with lock: + result.adopted += 1 + self.ctx.remap.record( + "agentic_prompt_version", src_vid, existing["id"] + ) + return + payload = { + k: v + for k, v in src.items() + if k + in { + "version", + "short_desc", + "long_desc", + "prompt_text", + "accuracy", + "is_active", + "created_by_agent", + } + and v is not None + } + payload["project"] = tgt_project_id + + src_parent = src.get("parent_version") + if src_parent is not None: + tgt_parent = self.ctx.remap.resolve( + "agentic_prompt_version", str(src_parent) + ) + if tgt_parent is not None: + payload["parent_version"] = tgt_parent + else: + # Root cloned first should always resolve; warn but keep the row. + logger.warning( + "agentic '%s': prompt v%s parent %s unresolved — left unset", + name, + src.get("version"), + src_parent, + ) + with lock: + result.warnings.append( + f"agentic {name}: prompt version {src.get('version')} " + "parent not remapped — left unset" + ) + + try: + tgt = self.ctx.target.create_agentic_prompt_version(payload) + except Exception as e: + logger.exception( + "agentic '%s': prompt v%s create failed: %s", + name, + src.get("version"), + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"agentic {name} create prompt v{src.get('version')}: {e}" + ) + return + with lock: + result.created += 1 + self.ctx.remap.record("agentic_prompt_version", src_vid, tgt["id"]) + + # ----- schemas ----- + + def _clone_schemas( + self, + name: str, + src_project_id: str, + tgt_project_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + src_schemas = self.ctx.source.list_agentic_schemas( + project_id=src_project_id + ) + except Exception as e: + logger.exception("agentic '%s': schema listing failed: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"agentic {name} list schemas: {e}") + return + + # Adopt schemas already on target (keyed by version) so a re-run + # doesn't re-create duplicates. + try: + tgt_schemas = self.ctx.target.list_agentic_schemas( + project_id=tgt_project_id + ) + except Exception as e: + logger.warning( + "agentic '%s': target schema listing failed (re-run may " + "duplicate): %s", + name, + e, + ) + tgt_schemas = [] + existing_versions = {s.get("version") for s in tgt_schemas} + + for src in src_schemas: + if src.get("version") in existing_versions: + with lock: + result.adopted += 1 + continue + payload = { + k: v + for k, v in src.items() + if k in {"json_schema", "version", "is_active", "created_by_agent"} + and v is not None + } + payload["project"] = tgt_project_id + try: + self.ctx.target.create_agentic_schema(payload) + except Exception as e: + logger.exception( + "agentic '%s': schema v%s create failed: %s", + name, + src.get("version"), + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"agentic {name} create schema v{src.get('version')}: {e}" + ) + continue + with lock: + result.created += 1 + + # ----- registry ----- + + def _republish_registry( + self, + name: str, + src_project_id: str, + tgt_project_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + # Projects never exported on source (no active schema/prompt) have no + # registry row; republishing would fail the same backend guard. + try: + src_regs = self.ctx.source.list_agentic_registries( + agentic_project=src_project_id + ) + except Exception as e: + logger.warning( + "agentic '%s': source registry lookup failed — registry not " + "republished: %s", + name, + e, + ) + with lock: + result.warnings.append( + f"agentic {name}: source registry lookup failed — " + "registry not republished" + ) + return + + if not src_regs: + # Nothing to republish — project was never exported. + return + + try: + self.ctx.target.export_agentic_project(tgt_project_id) + except Exception as e: + # Export needs an active schema + prompt on target; a partial clone + # can leave it unexportable. Non-fatal: warn and move on. + logger.warning( + "agentic '%s': registry republish failed (export) tgt=%s: %s", + name, + tgt_project_id, + e, + ) + with lock: + result.warnings.append( + f"agentic {name}: registry not republished in v1 " + f"(export failed: {e})" + ) + return + + try: + tgt_regs = self.ctx.target.list_agentic_registries( + agentic_project=tgt_project_id + ) + except Exception as e: + logger.warning( + "agentic '%s': target registry lookup failed after export: %s", + name, + e, + ) + with lock: + result.warnings.append( + f"agentic {name}: target registry id not recorded after export" + ) + return + + if src_regs and tgt_regs: + with lock: + self.ctx.remap.record( + "agentic_studio_registry", + src_regs[0]["registry_id"], + tgt_regs[0]["registry_id"], + ) + logger.info("republished agentic registry for project '%s'", name) + + # ----- dry-run planning ----- + + def _plan_children( + self, + src_project_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + """Dry-run: count source prompt versions + schemas as planned and + record planned prompt-version ids so downstream resolves don't miss. + """ + try: + src_versions = self.ctx.source.list_agentic_prompt_versions( + project_id=src_project_id + ) + except Exception: + src_versions = [] + try: + src_schemas = self.ctx.source.list_agentic_schemas( + project_id=src_project_id + ) + except Exception: + src_schemas = [] + with lock: + for v in src_versions: + self.ctx.remap.record_planned("agentic_prompt_version", v["id"]) + result.created += 1 + result.created += len(src_schemas) + + # ----- settings ----- + + def _clone_settings(self, result: PhaseResult) -> None: + if self.ctx.options.dry_run: + try: + src_settings = self.ctx.source.list_agentic_settings() + except Exception: + src_settings = [] + result.created += len(src_settings) + return + + try: + src_settings = self.ctx.source.list_agentic_settings() + except Exception as e: + logger.exception("Failed to list source agentic settings: %s", e) + result.failed += 1 + result.errors.append(f"list source agentic settings: {e}") + return + if not src_settings: + return + try: + tgt_by_key = { + s["key"]: s for s in self.ctx.target.list_agentic_settings() + } + except Exception as e: + logger.exception("Failed to list target agentic settings: %s", e) + result.failed += 1 + result.errors.append(f"list target agentic settings: {e}") + return + + for src in src_settings: + key = src.get("key") + if not key: + continue + payload = { + k: v + for k, v in src.items() + if k in {"key", "value", "description"} and v is not None + } + existing = tgt_by_key.get(key) + try: + if existing is not None: + self.ctx.target.update_agentic_setting(existing["id"], payload) + result.adopted += 1 + else: + self.ctx.target.create_agentic_setting(payload) + result.created += 1 + except Exception as e: + # An agentic-setting key can collide with a row this org's + # listing doesn't surface — not data loss the clone can + # resolve. Warn, don't fail. + logger.warning("agentic setting '%s' not replicated: %s", key, e) + result.skipped += 1 + result.warnings.append( + f"agentic setting {key}: not replicated " + f"(org-global key may already exist elsewhere): {e}" + ) diff --git a/src/unstract/clone/phases/base.py b/src/unstract/clone/phases/base.py index c7710d3..cc45b4c 100644 --- a/src/unstract/clone/phases/base.py +++ b/src/unstract/clone/phases/base.py @@ -18,13 +18,12 @@ logger = logging.getLogger(__name__) -# DRF OPTIONS reports any ModelSerializer FK/M2M as writable, but the -# backend's perform_create overrides these server-side. Posting them is -# either noise (silently overwritten) or a 400 (when a source-org value -# doesn't validate against the target org). Strip them universally — -# the phase OPTIONS schema covers the entity-specific writable subset. -# ``shared_users`` stays stripped on create — share state is replicated -# post-create instead (see sharing.py). +# OPTIONS reports any FK/M2M as writable, but the backend overrides these +# server-side on create. Posting them is either noise (silently overwritten) +# or a 400 (when a source-org value doesn't validate against the target org). +# Strip them universally — the phase OPTIONS schema covers the entity-specific +# writable subset. ``shared_users`` stays stripped on create — share state is +# replicated post-create instead (see sharing.py). SERVER_MANAGED: frozenset[str] = frozenset( { "id", @@ -61,6 +60,11 @@ class Phase(ABC): # Share endpoint template for shareable resource types, e.g. # "adapter/{id}/share/" ({id} = target pk). None = not shareable. share_path_template: str | None = None + # Capability-gate for cloud-only phases. When set, the orchestrator probes + # this list endpoint on source/target before running and applies the skip + # matrix (source absent → silent skip; target absent → warn + skip). Core + # OSS phases leave it None and always run (no probe call at all). + probe_path: str | None = None def __init__(self, ctx: CloneContext): self.ctx = ctx diff --git a/src/unstract/clone/phases/connector.py b/src/unstract/clone/phases/connector.py index 9f9c8ca..c593e42 100644 --- a/src/unstract/clone/phases/connector.py +++ b/src/unstract/clone/phases/connector.py @@ -3,18 +3,17 @@ Same list -> per-id GET -> POST/adopt pattern as AdapterPhase. Two connector-specific wrinkles: -1. **Connectors with redacted metadata are skipped.** The backend - serializer strips ``connector_metadata`` for auto-provisioned rows - (e.g. Unstract Cloud Storage), so the SDK cannot reconstruct them - on the target. We detect this by inspecting the source GET response: +1. **Connectors with redacted metadata are skipped.** Auto-provisioned + rows (e.g. Unstract Cloud Storage) come back without + ``connector_metadata``, so the SDK cannot reconstruct them on the + target. We detect this by inspecting the source GET response: a falsy ``connector_metadata`` means the operator must rely on the target's own provisioning (or re-create the row manually) — the remap table records no entry for these. -2. **OAuth ``connector_auth`` is stripped from responses.** Tokens are - stored in a sibling ``ConnectorAuth`` row that the public API never - exposes, so OAuth-backed connectors land on the target without - refresh tokens. Operator must re-authorise on target. +2. **OAuth ``connector_auth`` is stripped from responses.** OAuth refresh + tokens are never returned by the API, so OAuth-backed connectors land + on the target without them. Operator must re-authorise on target. """ from __future__ import annotations @@ -31,9 +30,9 @@ CONNECTOR_PATH = "connector/" -# Backend POST serializer trips on these keys (connector_v2/serializers.py) -# by trying to refresh against the source user's social auth — guaranteed -# OAuthTimeOut on target. Detect here and skip ahead of POST. +# A POST carrying these OAuth token keys triggers a token refresh against +# the source user's credentials — guaranteed to fail on the target. Detect +# here and skip ahead of POST. _OAUTH_TOKEN_KEYS: frozenset[str] = frozenset({"access_token", "refresh_token"}) diff --git a/src/unstract/clone/phases/custom_tool.py b/src/unstract/clone/phases/custom_tool.py index 372c3fd..84d3f6d 100644 --- a/src/unstract/clone/phases/custom_tool.py +++ b/src/unstract/clone/phases/custom_tool.py @@ -6,7 +6,7 @@ portable JSON blob (tool_metadata, tool_settings, default_profile_settings, prompts, export_metadata). 2. Decides fresh vs adopt by looking up the target tool by name. -3. **Fresh path**: reads source's default ProfileManager to learn the +3. **Fresh path**: reads the source's default adapter profile to learn the adapter UUIDs the profile is bound to, remaps each via the running ``adapter`` remap table, and POSTs the import as a multipart upload with target-org adapter ids on the form. Backend creates the tool, @@ -15,16 +15,20 @@ Backend rip-and-replaces prompts + ``tool_settings`` and leaves the target's locally-configured profiles + adapters untouched (which is what the operator wants — they may have rewired adapters on target). -5. Republishes ``PromptStudioRegistry`` via the export action and +5. Republishes the tool's registry entry via the export action and records the ``custom_tool`` + ``prompt_studio_registry`` remaps so - downstream ToolInstancePhase can rewrite ``ToolInstance.tool_id``. + downstream ToolInstancePhase can rewrite the tool instance's tool id. Skipped for tools with no source registry entry (never exported — e.g. empty projects, which the backend refuses to export). -Adapter id discovery for the fresh path needs all four of LLM, -vector_db, embedding, x2text. If any source adapter can't be resolved -via the adapter remap, the tool is failed cleanly — we never want to -land a half-wired profile. +Adapter id discovery for the fresh path resolves each of LLM, +vector_db, embedding, x2text via the adapter remap on a best-effort +basis. Any that can't be resolved are left unconfigured — the backend +imports the tool with a partial/empty profile and flags +``needs_adapter_config`` for the operator to finish wiring on target +and re-run. Frictionless-bound tools (adapters not even visible to the +source org's Platform key) are the exception: cloud-only with no target +equivalent, so they are skipped + cascade. """ from __future__ import annotations @@ -82,7 +86,7 @@ def run(self, report: CloneReport) -> PhaseResult: result.errors.append(f"list target tools: {e}") return result - # Source's service-account view hides frictionless adapters; a + # The source's visible adapter set hides frictionless adapters; a # profile-referenced name missing here flags a tool we can't migrate. try: self._src_adapter_names = { @@ -167,8 +171,15 @@ def _clone_one( # needs a prompt_studio_registry remap to plan-count. Mirror it # with a planned id derived from the source registry (read-only). self._record_planned_registry(src_tool_id, tool_name, lock) + self._record_planned_prompts(src_tool_id, lock) return + # Map source prompt ids -> target prompt ids by prompt_key so + # prompt-scoped phases (e.g. lookup assignments) can rewrite their + # prompt FKs. Target prompts already exist here (created by + # import_project on fresh, sync_prompts on adopt). + self._remap_prompts(src_tool_id, tgt_tool_id, tool_name, lock) + # Tools never exported on source (e.g. empty projects — backend # blocks their export) have no registry entry and no workflow # references; republishing would fail the same backend guard. @@ -199,10 +210,16 @@ def _clone_one( "republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id ) except Exception as e: - logger.exception("Registry republish failed for tool %s: %s", tool_name, e) + # Republish can 500 on incomplete/stale source registries (e.g. + # empty run prompts). The tool itself cloned fine; only its + # registry entry is missing, so downstream tool_instances + # cascade-skip. Warn rather than fail the whole tool. + logger.warning("Registry republish failed for tool '%s': %s", tool_name, e) with lock: - result.failed += 1 - result.errors.append(f"export {tool_name}: {e}") + result.warnings.append( + f"republish {tool_name}: skipped ({e}) — downstream tool " + "instances will cascade-skip until re-published" + ) return try: @@ -251,6 +268,58 @@ def _record_planned_registry( "prompt_studio_registry", src_regs[0]["prompt_registry_id"] ) + def _remap_prompts( + self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + lock: threading.Lock, + ) -> None: + """Record source->target prompt-id remaps, matched by prompt_key. + + Best-effort: a prompt without a matching key on target is skipped + (the dependent phase counts it as unresolved), and a listing + failure leaves the remap empty rather than failing the tool. + """ + try: + src_prompts = self.ctx.source.list_prompts(src_tool_id) + tgt_prompts = self.ctx.target.list_prompts(tgt_tool_id) + except Exception as e: + logger.warning( + "prompt-id remap skipped for tool '%s' " + "(dependent prompt-scoped phases may under-resolve): %s", + tool_name, + e, + ) + return + # prompt_key is effectively unique per tool; first match wins. + tgt_by_key = {p["prompt_key"]: p["prompt_id"] for p in tgt_prompts} + with lock: + for sp in src_prompts: + tgt_pid = tgt_by_key.get(sp["prompt_key"]) + if tgt_pid: + self.ctx.remap.record("prompt", sp["prompt_id"], tgt_pid) + + def _record_planned_prompts( + self, src_tool_id: str, lock: threading.Lock + ) -> None: + """Dry-run: record a planned prompt remap per source prompt so + prompt-scoped phases can resolve their FK and plan-count. + """ + try: + src_prompts = self.ctx.source.list_prompts(src_tool_id) + except Exception as e: + logger.warning( + "[dry-run] source prompt listing failed for tool %s " + "(prompt-scoped plan may under-count): %s", + src_tool_id, + e, + ) + return + with lock: + for sp in src_prompts: + self.ctx.remap.record_planned("prompt", sp["prompt_id"]) + def _adopt( self, match: dict[str, Any], @@ -306,29 +375,26 @@ def _create_fresh( result: PhaseResult, lock: threading.Lock, ) -> str | None: - # Run the source-side validations even in dry-run — they decide - # whether a real run would create or frictionless-skip, so the plan - # counts must reflect them. Only the target-write steps are stubbed. + # Run the source-side checks even in dry-run — they decide whether a + # real run would create or frictionless-skip, so plan counts must + # reflect them. Only the target-write steps are stubbed. default_profile = self._source_default_profile(src_tool_id, tool_name) - if default_profile is None: - with lock: - result.failed += 1 - result.errors.append( - f"import {tool_name}: no default profile on source" - ) - return None - invisible = self._invisible_source_adapter_names(default_profile) - if invisible: - self._register_frictionless_skip( - src_tool_id, tool_name, invisible, result, lock - ) - return None + # Frictionless adapters are cloud-only with no target equivalent — + # skip + cascade. Only checkable when a default profile exists; a + # profile-less tool is mirrored unconfigured (below). + if default_profile is not None: + invisible = self._invisible_source_adapter_names(default_profile) + if invisible: + self._register_frictionless_skip( + src_tool_id, tool_name, invisible, result, lock + ) + return None if self.ctx.options.dry_run: # Target adapter resolution is skipped: adapters this run would - # create don't exist on target yet, so it can't resolve. The - # frictionless check above already caught the real skip cases. + # create don't exist on target yet. The frictionless check above + # already caught the real skip cases. with lock: result.created += 1 tgt_tool_id = self.ctx.remap.record_planned("custom_tool", src_tool_id) @@ -337,14 +403,15 @@ def _create_fresh( ) return tgt_tool_id - adapter_ids = self._resolve_target_adapter_ids(default_profile, tool_name) - if adapter_ids is None: - with lock: - result.failed += 1 - result.errors.append( - f"import {tool_name}: missing target adapter remap for default" - ) - return None + # Best-effort adapter wiring: resolve what maps, leave the rest + # unconfigured. The backend tolerates a partial/empty set and flags + # needs_adapter_config — mirror an incomplete source tool rather than + # fail the clone (operator finishes wiring on target and re-runs). + adapter_ids = ( + self._resolve_target_adapter_ids(default_profile, tool_name) + if default_profile is not None + else {} + ) try: tgt = self.ctx.target.import_project(export_data, adapter_ids=adapter_ids) @@ -356,15 +423,21 @@ def _create_fresh( return None tgt_tool_id = tgt["tool_id"] + needs_cfg = tgt.get("needs_adapter_config") with lock: result.created += 1 self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + if needs_cfg: + result.warnings.append( + f"tool {tool_name}: imported without full adapter config — " + "wire adapters on target and re-run to complete downstream" + ) logger.info( "created tool '%s' src=%s -> tgt=%s (needs_adapter_config=%s)", tool_name, src_tool_id, tgt_tool_id, - tgt.get("needs_adapter_config"), + needs_cfg, ) return tgt_tool_id @@ -416,7 +489,7 @@ def _register_frictionless_skip( """ logger.warning( "skipping tool '%s' src=%s — default profile references adapters " - "not visible to the source service account (frictionless?): %s. " + "not visible to this org's adapter listing (frictionless?): %s. " "Wire equivalents on target and re-run.", tool_name, src_tool_id, @@ -441,14 +514,15 @@ def _register_frictionless_skip( def _resolve_target_adapter_ids( self, default_profile: dict[str, Any], tool_name: str - ) -> dict[str, str] | None: + ) -> dict[str, str]: """Source profile carries adapter NAMES (per serializer); resolve - each name to a target adapter UUID via ``list_adapters(name=...)``. + each to a target adapter UUID via ``list_adapters(name=...)``. - Returns ``None`` if any of the four required adapters can't be - found on target — caller fails the tool. AdapterPhase preserves - names across orgs so this lookup should always hit when the - adapter clone ran cleanly. + Best-effort: adapters that can't be resolved are omitted (not fatal). + The backend tolerates a partial/empty set and flags + ``needs_adapter_config``. AdapterPhase preserves names across orgs, + so a miss means the adapter wasn't cloned (frictionless, or a failed + adapter clone) — the operator wires it on target and re-runs. """ resolved: dict[str, str] = {} for src_field, form_field in _PROFILE_ADAPTER_FIELDS: @@ -459,7 +533,7 @@ def _resolve_target_adapter_ids( tool_name, src_field, ) - return None + continue try: matches = self.ctx.target.list_adapters(name=adapter_name) except Exception as e: @@ -469,14 +543,15 @@ def _resolve_target_adapter_ids( tool_name, e, ) - return None + continue if not matches: logger.warning( - "no target adapter named '%s' for field %s on tool '%s'", + "no target adapter named '%s' for field %s on tool '%s' — " + "left unconfigured", adapter_name, src_field, tool_name, ) - return None + continue resolved[form_field] = matches[0]["id"] return resolved diff --git a/src/unstract/clone/phases/lookups.py b/src/unstract/clone/phases/lookups.py new file mode 100644 index 0000000..8597356 --- /dev/null +++ b/src/unstract/clone/phases/lookups.py @@ -0,0 +1,869 @@ +"""Migrate cloud-only Lookup definitions + their prompt assignments. + +Cloud-only: gated by ``probe_path`` — the orchestrator probes +``lookups/definitions/`` on source/target and skips the phase entirely on an +OSS deployment. Runs after ``custom_tool`` (consumes its ``prompt`` and +``adapter`` remaps) and after ``files``. + +Two passes: + +1. **Definitions** — per source lookup, adopt-by-name or create fresh. + Creating a definition auto-spawns an empty DRAFT version with default + adapters; this phase then patches the draft's ``prompt_template`` and + remaps the draft adapters (LLM / X2Text) to target-org ids, and replays + the reference files into the draft (reusing the size-cap / file-strategy + semantics of the files phase). Records a ``lookup_definition`` remap. + +2. **Assignments** — after every lookup + prompt remap exists, replay the + prompt-lookup assignment rows. Each row's source prompt FK remaps via the + ``custom_tool`` phase's ``prompt`` table; its target version resolves via + the ``lookup_version`` remap (recorded by the version replay below — draft + pins fall back to the cloned lookup's ``draft_version_id``); + ``variable_mappings`` values that are source prompt UUIDs remap too. + +Published-version history is reproduced per definition (see ``_replay_versions``) +so published-pinned assignments resolve, before the target draft is restored to +the source's current draft state. +""" + +from __future__ import annotations + +import logging +import mimetypes +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.sharing import replicate_share + +logger = logging.getLogger(__name__) + +LOOKUP_DEFINITIONS_PATH = "lookups/definitions/" + +# Draft adapter slots; each maps a detail ``adapters`` key to the PATCH key. +_ADAPTER_SLOTS: tuple[str, ...] = ("llm", "x2text") + +_DEFAULT_MIME = "application/octet-stream" + + +class LookupsPhase(Phase): + name = "lookups" + probe_path = LOOKUP_DEFINITIONS_PATH + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(LOOKUP_DEFINITIONS_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for lookups: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS lookups: {e}") + return result + try: + src_lookups = self.ctx.source.list_lookup_definitions() + except Exception as e: + logger.exception("Failed to list source lookup definitions: %s", e) + result.failed += 1 + result.errors.append(f"list source lookups: {e}") + return result + try: + tgt_lookups = self.ctx.target.list_lookup_definitions() + except Exception as e: + logger.exception("Failed to list target lookup definitions: %s", e) + result.failed += 1 + result.errors.append(f"list target lookups: {e}") + return result + + logger.info("Found %d lookup definition(s) in source org", len(src_lookups)) + # Updated under lock on fresh create so same-name source rows adopt. + target_by_name: dict[str, dict[str, Any]] = { + lk["name"]: lk for lk in tgt_lookups + } + + # Pass 1: definitions (+ draft content + reference files). + self.parallel_map( + src_lookups, + lambda src, lock: self._clone_definition(src, target_by_name, result, lock), + ) + + # Pass 2: assignments — needs every lookup + prompt remap from above. + self._clone_assignments(result) + return result + + # ----- definitions ----- + + def _clone_definition( + self, + src: dict[str, Any], + target_by_name: dict[str, dict[str, Any]], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + name = src["name"] + src_lookup_id = src["lookup_id"] + + with lock: + match = target_by_name.get(name) + + if match is not None: + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"lookup '{name}' already exists in target as {match['lookup_id']}" + ) + tgt_lookup_id = match["lookup_id"] + with lock: + result.adopted += 1 + self.ctx.remap.record("lookup_definition", src_lookup_id, tgt_lookup_id) + logger.info( + "adopted lookup '%s' src=%s -> tgt=%s", + name, + src_lookup_id, + tgt_lookup_id, + ) + elif self.ctx.options.dry_run: + with lock: + result.created += 1 + self.ctx.remap.record_planned("lookup_definition", src_lookup_id) + logger.info( + "[dry-run] would create lookup '%s' src=%s", name, src_lookup_id + ) + self._plan_versions(name, src_lookup_id, lock) + return + else: + payload = build_post_payload(src, self._writable) + try: + tgt = self.ctx.target.create_lookup_definition(payload) + except Exception as e: + logger.exception("Failed to create lookup '%s': %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create lookup {name}: {e}") + return + tgt_lookup_id = tgt["lookup_id"] + with lock: + result.created += 1 + target_by_name[name] = {"lookup_id": tgt_lookup_id, "name": name} + self.ctx.remap.record("lookup_definition", src_lookup_id, tgt_lookup_id) + logger.info( + "created lookup '%s' src=%s -> tgt=%s", + name, + src_lookup_id, + tgt_lookup_id, + ) + + # Draft content + reference files write to the real target draft only. + if self.ctx.options.dry_run: + self._plan_versions(name, src_lookup_id, lock) + return + + try: + detail = self.ctx.source.get_lookup_definition(src_lookup_id) + except Exception as e: + logger.warning( + "lookup '%s': source detail fetch failed — draft content not " + "replicated: %s", + name, + e, + ) + with lock: + result.warnings.append( + f"lookup {name}: source detail fetch failed — " + f"draft template/adapters not replicated: {e}" + ) + detail = None + + if detail is not None: + # Share state is server-managed on create; mirror it post-create. + replicate_share( + self.ctx, + apply_fn=lambda p: self.ctx.target.update_lookup_share( + tgt_lookup_id, p + ), + entity_label=f"lookup '{name}'", + src=detail, + result=result, + lock=lock, + include_groups=False, # lookups model has no group sharing + ) + + # Reproduce published-version history first so its draft churn lands + # before the final draft restore (last publish spawns a fresh draft). + self._replay_versions(name, src_lookup_id, tgt_lookup_id, result, lock) + + # Restore target draft to the source's CURRENT draft — must run AFTER + # the replay so the final draft matches source, not the last published. + if detail is not None: + self._replicate_draft(name, tgt_lookup_id, detail, result, lock) + self._replicate_files(name, src_lookup_id, tgt_lookup_id, result, lock) + + # Draft pins resolve uniformly through the version remap too. + if detail is not None: + self._record_draft_version_remap(name, tgt_lookup_id, detail, lock) + + def _replicate_draft( + self, + name: str, + tgt_lookup_id: str, + detail: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + template = detail.get("prompt_template") or "" + if template: + try: + self.ctx.target.update_lookup_draft_template(tgt_lookup_id, template) + except Exception as e: + logger.exception( + "lookup '%s': draft template patch failed: %s", name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} draft template: {e}") + + tgt_adapters = self._remap_adapters( + name, detail.get("adapters") or {}, result, lock + ) + if tgt_adapters: + try: + self.ctx.target.update_lookup_draft_adapters( + tgt_lookup_id, tgt_adapters + ) + except Exception as e: + logger.exception( + "lookup '%s': draft adapters patch failed: %s", name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} draft adapters: {e}") + + def _remap_adapters( + self, + name: str, + src_adapters: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + ) -> dict[str, str]: + """Resolve a detail/version ``adapters`` dict to target ids; unresolved + slots are omitted with a warning so the draft keeps its default. + """ + tgt_adapters: dict[str, str] = {} + for slot in _ADAPTER_SLOTS: + src_adapter_id = src_adapters.get(slot) + if not src_adapter_id: + continue + tgt_adapter_id = self.ctx.remap.resolve("adapter", src_adapter_id) + if tgt_adapter_id is None: + logger.warning( + "lookup '%s': %s adapter %s has no target mapping — " + "leaving draft default", + name, + slot, + src_adapter_id, + ) + with lock: + result.warnings.append( + f"lookup {name}: {slot} adapter not remapped — " + "draft kept its default adapter" + ) + continue + tgt_adapters[slot] = tgt_adapter_id + return tgt_adapters + + # ----- version replay ----- + + def _plan_versions( + self, name: str, src_lookup_id: str, lock: threading.Lock + ) -> None: + """Dry-run: record a planned ``lookup_version`` remap per source version + so the assignment pass can plan-count published-pinned ones too. + """ + try: + versions = self.ctx.source.list_lookup_versions(src_lookup_id) + except Exception as e: + logger.warning( + "[dry-run] lookup '%s': version listing failed: %s", name, e + ) + return + with lock: + for v in versions: + vid = v.get("version_id") + if vid: + self.ctx.remap.record_planned("lookup_version", str(vid)) + + def _replay_versions( + self, + name: str, + src_lookup_id: str, + tgt_lookup_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + """Reproduce the source's published versions on the target in + ``version_number`` order: set the target draft to each version's + content + files, then publish. Records a ``lookup_version`` remap per + published version so its pinned assignments resolve. A target version + with the same ``version_name`` is adopted (no re-publish) so re-runs + and adopted definitions stay idempotent. + """ + try: + versions = self.ctx.source.list_lookup_versions(src_lookup_id) + except Exception as e: + logger.warning("lookup '%s': version listing failed: %s", name, e) + with lock: + result.warnings.append( + f"lookup {name}: version listing failed — published " + f"versions not replayed: {e}" + ) + return + + # Existing target versions — re-publishing a name that already exists + # would error or duplicate history; adopt those instead. + try: + tgt_versions = self.ctx.target.list_lookup_versions(tgt_lookup_id) + except Exception as e: + logger.warning( + "lookup '%s': target version listing failed: %s", name, e + ) + tgt_versions = [] + tgt_by_name: dict[str, str] = { + v["version_name"]: str(v["version_id"]) + for v in tgt_versions + if v.get("version_name") and v.get("version_id") + } + + published = sorted( + (v for v in versions if not v.get("is_draft")), + key=lambda v: v.get("version_number") or 0, + ) + for v in published: + existing_id = tgt_by_name.get(v.get("version_name")) + if existing_id is not None: + src_version_id = v.get("version_id") + with lock: + result.adopted += 1 + if src_version_id: + self.ctx.remap.record( + "lookup_version", str(src_version_id), existing_id + ) + logger.info( + "lookup '%s': adopted published version '%s' (already present)", + name, + v.get("version_name"), + ) + continue + self._replay_one_version( + name, src_lookup_id, tgt_lookup_id, v, result, lock + ) + + def _replay_one_version( + self, + name: str, + src_lookup_id: str, + tgt_lookup_id: str, + version_row: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + src_version_id = version_row.get("version_id") + version_name = version_row.get("version_name") or "" + try: + detail = self.ctx.source.get_lookup_version( + src_lookup_id, src_version_id + ) + except Exception as e: + logger.exception( + "lookup '%s': version %s detail fetch failed: %s", + name, + src_version_id, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"lookup {name} version {version_name}: detail fetch: {e}" + ) + return + + # Stage the version's content onto the target draft before publishing. + template = detail.get("prompt_template") or "" + if template: + try: + self.ctx.target.update_lookup_draft_template(tgt_lookup_id, template) + except Exception as e: + logger.exception( + "lookup '%s': version %s template patch failed: %s", + name, + version_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"lookup {name} version {version_name} template: {e}" + ) + # Don't publish — a stale draft would freeze wrong content + # into the named version. + return + tgt_adapters = self._remap_adapters( + name, detail.get("adapters") or {}, result, lock + ) + if tgt_adapters: + try: + self.ctx.target.update_lookup_draft_adapters( + tgt_lookup_id, tgt_adapters + ) + except Exception as e: + logger.exception( + "lookup '%s': version %s adapters patch failed: %s", + name, + version_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"lookup {name} version {version_name} adapters: {e}" + ) + # Don't publish a version with unmapped adapters. + return + + self._replay_version_files( + name, src_lookup_id, tgt_lookup_id, src_version_id, detail, result, lock + ) + + # Frozen assignment-value snapshots are captured at publish time from + # then-existing assignments; publishing happens before assignments are + # recreated, so historical snapshots may differ from source. + # Structure + pinning reproduce; frozen values are best-effort. + try: + published = self.ctx.target.publish_lookup_version( + tgt_lookup_id, + {"version_name": version_name, "rebind_assignments": False}, + ) + except Exception as e: + logger.exception( + "lookup '%s': publish of version %s failed: %s", + name, + version_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"lookup {name} publish version {version_name}: {e}" + ) + return + + tgt_version_id = published.get("version_id") + with lock: + if src_version_id and tgt_version_id: + self.ctx.remap.record( + "lookup_version", str(src_version_id), str(tgt_version_id) + ) + result.warnings.append( + f"lookup {name} version {version_name}: published; frozen " + "assignment-value snapshots are best-effort (assignments " + "recreated after publish)" + ) + logger.info( + "lookup '%s': replayed published version '%s' src=%s -> tgt=%s", + name, + version_name, + src_version_id, + tgt_version_id, + ) + + def _replay_version_files( + self, + name: str, + src_lookup_id: str, + tgt_lookup_id: str, + src_version_id: str, + version_detail: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + """Upload a published version's reference files onto the target draft, + deduped by filename against the current draft files. Honors the same + size-cap / file-strategy semantics as ``_clone_one_file``. + """ + if self.ctx.options.file_strategy == "skip": + return + src_files = version_detail.get("files") or [] + if not src_files: + return + try: + tgt_files = self.ctx.target.list_lookup_files(tgt_lookup_id) + except Exception as e: + logger.exception( + "lookup '%s': target file listing (version replay) failed: %s", + name, + e, + ) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} list target files: {e}") + return + target_names = {f.get("file_name") for f in tgt_files} + + for f in src_files: + file_name = f.get("file_name") + file_id = f.get("file_id") + if not file_name or not file_id or file_name in target_names: + continue + try: + raw = self.ctx.source.download_lookup_version_file( + src_lookup_id, src_version_id, file_id + ) + except Exception as e: + logger.exception( + "lookup '%s': version file '%s' download failed: %s", + name, + file_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"lookup {name} download version file {file_name}: {e}" + ) + continue + if len(raw) > self.ctx.options.max_file_size: + with lock: + result.skipped += 1 + result.warnings.append( + f"lookup {name}: version reference file '{file_name}' " + f"({len(raw)} bytes) exceeds cap " + f"{self.ctx.options.max_file_size} — re-upload via UI" + ) + continue + mime = mimetypes.guess_type(file_name)[0] or _DEFAULT_MIME + try: + self.ctx.target.upload_lookup_file( + tgt_lookup_id, file_name, raw, mime + ) + except Exception as e: + logger.exception( + "lookup '%s': version file '%s' upload failed: %s", + name, + file_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"lookup {name} upload version file {file_name}: {e}" + ) + continue + target_names.add(file_name) + + def _record_draft_version_remap( + self, + name: str, + tgt_lookup_id: str, + detail: dict[str, Any], + lock: threading.Lock, + ) -> None: + """Map the source DRAFT version id -> target draft_version_id so + draft-pinned assignments resolve uniformly via the version remap. + """ + src_draft_id = detail.get("draft_version_id") + if not src_draft_id: + return + try: + tgt_detail = self.ctx.target.get_lookup_definition(tgt_lookup_id) + tgt_draft_id = tgt_detail.get("draft_version_id") + except Exception as e: + logger.warning( + "lookup '%s': target draft id fetch failed: %s", name, e + ) + return + if tgt_draft_id: + with lock: + self.ctx.remap.record( + "lookup_version", str(src_draft_id), str(tgt_draft_id) + ) + + def _replicate_files( + self, + name: str, + src_lookup_id: str, + tgt_lookup_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + src_files = self.ctx.source.list_lookup_files(src_lookup_id) + except Exception as e: + logger.exception("lookup '%s': source file listing failed: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} list files: {e}") + return + + if self.ctx.options.file_strategy == "skip": + for f in src_files: + file_name = f.get("file_name") + if not file_name: + continue + with lock: + result.skipped += 1 + result.warnings.append( + f"lookup {name}: reference file '{file_name}' not cloned " + "(file_strategy=skip) — re-upload via UI" + ) + return + + try: + tgt_files = self.ctx.target.list_lookup_files(tgt_lookup_id) + except Exception as e: + logger.exception("lookup '%s': target file listing failed: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} list target files: {e}") + return + target_names = {f.get("file_name") for f in tgt_files} + + for f in src_files: + file_name = f.get("file_name") + file_id = f.get("file_id") + if not file_name or not file_id: + continue + if file_name in target_names: + with lock: + result.skipped += 1 + continue + self._clone_one_file( + name, src_lookup_id, tgt_lookup_id, file_name, file_id, result, lock + ) + + def _clone_one_file( + self, + name: str, + src_lookup_id: str, + tgt_lookup_id: str, + file_name: str, + file_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + raw = self.ctx.source.download_lookup_file(src_lookup_id, file_id) + except Exception as e: + logger.exception( + "lookup '%s': download of '%s' failed: %s", name, file_name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} download {file_name}: {e}") + return + + if len(raw) > self.ctx.options.max_file_size: + with lock: + result.skipped += 1 + result.warnings.append( + f"lookup {name}: reference file '{file_name}' " + f"({len(raw)} bytes) exceeds cap " + f"{self.ctx.options.max_file_size} — re-upload via UI" + ) + return + + mime = mimetypes.guess_type(file_name)[0] or _DEFAULT_MIME + try: + self.ctx.target.upload_lookup_file(tgt_lookup_id, file_name, raw, mime) + except Exception as e: + logger.exception( + "lookup '%s': upload of '%s' failed: %s", name, file_name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"lookup {name} upload {file_name}: {e}") + return + with lock: + result.created += 1 + logger.info("lookup '%s': uploaded reference file '%s'", name, file_name) + + # ----- assignments ----- + + def _clone_assignments(self, result: PhaseResult) -> None: + try: + src_assignments = self.ctx.source.list_lookup_assignments() + except Exception as e: + logger.exception("Failed to list source lookup assignments: %s", e) + result.failed += 1 + result.errors.append(f"list source lookup assignments: {e}") + return + + if not src_assignments: + return + + # Index existing target assignments by target prompt id (one per prompt) + # to honor that uniqueness on re-runs. + existing_by_prompt: dict[str, dict[str, Any]] = {} + if not self.ctx.options.dry_run: + try: + for a in self.ctx.target.list_lookup_assignments(): + pid = a.get("prompt") + if pid: + existing_by_prompt[str(pid)] = a + except Exception as e: + logger.warning( + "target assignment listing failed — re-run idempotency may " + "create duplicates: %s", + e, + ) + + # Cache target draft_version_id per target lookup id. + draft_cache: dict[str, str | None] = {} + + self.parallel_map( + src_assignments, + lambda a, lock: self._clone_one_assignment( + a, existing_by_prompt, draft_cache, result, lock + ), + ) + + def _clone_one_assignment( + self, + src: dict[str, Any], + existing_by_prompt: dict[str, dict[str, Any]], + draft_cache: dict[str, str | None], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + assignment_id = src.get("assignment_id") + src_prompt_id = src.get("prompt") + src_lookup_id = src.get("lookup_definition") + src_version_id = src.get("version") + is_draft_version = bool(src.get("is_draft_version")) + + tgt_prompt_id = ( + self.ctx.remap.resolve("prompt", str(src_prompt_id)) + if src_prompt_id is not None + else None + ) + if tgt_prompt_id is None: + with lock: + result.skipped += 1 + result.warnings.append( + f"assignment {assignment_id}: source prompt {src_prompt_id} " + "has no target mapping (its tool wasn't cloned) — skipped" + ) + return + + tgt_lookup_id = ( + self.ctx.remap.resolve("lookup_definition", str(src_lookup_id)) + if src_lookup_id is not None + else None + ) + if tgt_lookup_id is None: + with lock: + result.skipped += 1 + result.warnings.append( + f"assignment {assignment_id}: source lookup {src_lookup_id} " + "has no target mapping — skipped" + ) + return + + mappings = self._remap_mappings(src.get("variable_mappings")) + + if self.ctx.options.dry_run: + with lock: + result.created += 1 + logger.info( + "[dry-run] would create assignment for prompt %s -> lookup %s", + tgt_prompt_id, + tgt_lookup_id, + ) + return + + with lock: + if str(tgt_prompt_id) in existing_by_prompt: + result.adopted += 1 + logger.info( + "adopted assignment: target prompt %s already has a lookup " + "assignment", + tgt_prompt_id, + ) + return + + # Both draft- and published-pinned assignments resolve through the + # version remap recorded by the replay; draft pins fall back to the + # target's current draft when the remap is absent. + tgt_version_id = ( + self.ctx.remap.resolve("lookup_version", str(src_version_id)) + if src_version_id is not None + else None + ) + if tgt_version_id is None and is_draft_version: + tgt_version_id = self._target_draft_version( + tgt_lookup_id, draft_cache, lock + ) + if tgt_version_id is None: + with lock: + result.skipped += 1 + result.warnings.append( + f"assignment {assignment_id}: source version {src_version_id} " + "could not be resolved on target — skipped" + ) + return + + payload = { + "prompt": tgt_prompt_id, + "lookup_definition": tgt_lookup_id, + "version": tgt_version_id, + "variable_mappings": mappings, + } + try: + self.ctx.target.create_lookup_assignment(payload) + except Exception as e: + logger.exception( + "Failed to create assignment for prompt %s: %s", tgt_prompt_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"create assignment {assignment_id}: {e}") + return + with lock: + result.created += 1 + existing_by_prompt[str(tgt_prompt_id)] = payload + logger.info( + "created assignment: prompt %s -> lookup %s", tgt_prompt_id, tgt_lookup_id + ) + + def _remap_mappings(self, mappings: Any) -> Any: + """Deep-walk ``mappings``; any string value that is a source prompt + UUID (present in the ``prompt`` remap) rewrites to its target id. + Non-prompt strings pass through untouched. + """ + if isinstance(mappings, dict): + return {k: self._remap_mappings(v) for k, v in mappings.items()} + if isinstance(mappings, list): + return [self._remap_mappings(v) for v in mappings] + if isinstance(mappings, str): + return self.ctx.remap.resolve("prompt", mappings) or mappings + return mappings + + def _target_draft_version( + self, + tgt_lookup_id: str, + draft_cache: dict[str, str | None], + lock: threading.Lock, + ) -> str | None: + with lock: + if tgt_lookup_id in draft_cache: + return draft_cache[tgt_lookup_id] + try: + detail = self.ctx.target.get_lookup_definition(tgt_lookup_id) + draft_id = detail.get("draft_version_id") + except Exception as e: + logger.warning("target lookup %s draft fetch failed: %s", tgt_lookup_id, e) + draft_id = None + with lock: + # A racing peer may have already cached a valid id; the GET is + # read-only so a real id always wins over this thread's failure. + if draft_cache.get(tgt_lookup_id) is None: + draft_cache[tgt_lookup_id] = draft_id + return draft_cache[tgt_lookup_id] diff --git a/src/unstract/clone/phases/manual_review.py b/src/unstract/clone/phases/manual_review.py new file mode 100644 index 0000000..3b73b41 --- /dev/null +++ b/src/unstract/clone/phases/manual_review.py @@ -0,0 +1,454 @@ +"""Migrate cloud-only Manual Review (HITL) configuration. + +Cloud-only: gated by ``probe_path`` — the orchestrator probes +``manual_review/auto_approval_settings/`` on source/target and skips the +phase entirely on an OSS deployment. ``auto_approval_settings/`` returns 200 +with no query params, so it is the only MR GET route safe to probe; the +``rule_engine`` / ``settings`` routes are workflow-scoped and need a workflow +id in the URL. + +Runs after ``workflow`` — every review-rule and review-settings row FKs a +workflow, so the workflow remap must already exist. + +Config only. Runtime/queue data (review-queue packets, edited_data, +highlights, documents) is deliberately out of scope. + +Three passes: + +1. **Per-workflow** — for each source workflow that has a target mapping, + replay its review rules (one per ``rule_type``, with nested + ``confidence_filters``) and its review-settings row, rebinding ``workflow`` + to the target id. Adopt-by-presence on the target. +2. **Auto-approval settings** — org-level, cloned once. ``auto_approved_users`` + holds source-org user pks; remapped by email (same as share replication). + ``auto_approved_document_classes`` holds workflow/class-name strings with no + reliable cross-org remap; carried verbatim with a warning. + +These config rows are workflow- or org-scoped and inherit visibility from +there — no per-entity share replication. +3. **Review API keys** — recreated (the secret is server-minted and cannot be + copied); operator is warned to re-wire external consumers. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.sharing import ( + is_service_account, + source_user_by_id, + target_user_id_by_email, +) + +logger = logging.getLogger(__name__) + +AUTO_APPROVAL_PATH = "manual_review/auto_approval_settings/" + +# Review-rule create fields (server-managed id / created_by / modified_by +# excluded); ``confidence_filters`` is nested. +_RULE_FIELDS: tuple[str, ...] = ( + "rule_type", + "percentage", + "rule_string", + "rule_json", + "rule_logic", +) +# Confidence-filter create fields (id server-managed). +_FILTER_FIELDS: tuple[str, ...] = ("field_key", "confidence_threshold") +# HITL-settings writable fields (workflow rebound separately; rest +# server-managed). +_SETTINGS_FIELDS: tuple[str, ...] = ("sync_with", "ttl_hours") + + +class ManualReviewPhase(Phase): + name = "manual_review" + probe_path = AUTO_APPROVAL_PATH + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + src_workflows = self.ctx.source.list_workflows() + except Exception as e: + logger.exception("Failed to list source workflows for manual_review: %s", e) + result.failed += 1 + result.errors.append(f"list source workflows: {e}") + return result + + logger.info( + "manual_review: scanning %d source workflow(s) for HITL config", + len(src_workflows), + ) + self.parallel_map( + src_workflows, + lambda wf, lock: self._clone_workflow_config(wf, result, lock), + ) + + # Org-level entities — cloned once, outside the per-workflow fan-out. + self._clone_auto_approval(result) + self._clone_api_keys(result) + return result + + # ----- per-workflow rules + settings ----- + + def _clone_workflow_config( + self, src_wf: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + src_wf_id = src_wf["id"] + tgt_wf_id = self.ctx.remap.resolve("workflow", str(src_wf_id)) + if tgt_wf_id is None: + # Workflow wasn't cloned (e.g. its tool was skipped) — nothing to bind to. + logger.debug( + "manual_review: workflow %s has no target mapping — skipping its " + "HITL config", + src_wf_id, + ) + return + + self._clone_rules(src_wf_id, tgt_wf_id, result, lock) + self._clone_settings(src_wf_id, tgt_wf_id, result, lock) + + def _clone_rules( + self, + src_wf_id: str, + tgt_wf_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + for rule_type in self.ctx.source.MR_RULE_TYPES: + try: + src_rule = self.ctx.source.get_review_rule(src_wf_id, rule_type) + except Exception as e: + logger.exception( + "manual_review: failed to GET source %s rule for workflow %s: %s", + rule_type, + src_wf_id, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"GET source {rule_type} rule wf={src_wf_id}: {e}" + ) + continue + if not src_rule: + continue + + if self.ctx.options.dry_run: + with lock: + result.created += 1 + logger.info( + "[dry-run] would create %s rule for workflow %s", + rule_type, + tgt_wf_id, + ) + continue + + # Adopt if the target workflow already carries a rule of this type + # (unique per workflow+rule_type+org). + if not self.ctx.remap.is_planned(tgt_wf_id): + try: + existing = self.ctx.target.get_review_rule(tgt_wf_id, rule_type) + except Exception as e: + logger.exception( + "manual_review: failed to GET target %s rule for wf %s: %s", + rule_type, + tgt_wf_id, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"GET target {rule_type} rule wf={tgt_wf_id}: {e}" + ) + continue + if existing: + with lock: + result.adopted += 1 + logger.info( + "adopted %s rule on workflow %s (already present)", + rule_type, + tgt_wf_id, + ) + continue + + payload = self._rule_payload(src_rule, tgt_wf_id) + try: + self.ctx.target.create_review_rule(payload) + except Exception as e: + logger.exception( + "manual_review: failed to create %s rule for wf %s: %s", + rule_type, + tgt_wf_id, + e, + ) + with lock: + result.failed += 1 + result.errors.append( + f"create {rule_type} rule wf={tgt_wf_id}: {e}" + ) + continue + with lock: + result.created += 1 + logger.info("created %s rule for workflow %s", rule_type, tgt_wf_id) + + def _rule_payload( + self, src_rule: dict[str, Any], tgt_wf_id: str + ) -> dict[str, Any]: + payload: dict[str, Any] = {"workflow": tgt_wf_id} + for field in _RULE_FIELDS: + if field in src_rule and src_rule[field] is not None: + payload[field] = src_rule[field] + filters = [ + {f: cf[f] for f in _FILTER_FIELDS if f in cf} + for cf in src_rule.get("confidence_filters") or [] + ] + if filters: + payload["confidence_filters"] = filters + return payload + + def _clone_settings( + self, + src_wf_id: str, + tgt_wf_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + src_settings = self.ctx.source.get_review_settings(src_wf_id) + except Exception as e: + logger.exception( + "manual_review: failed to GET source settings for wf %s: %s", + src_wf_id, + e, + ) + with lock: + result.failed += 1 + result.errors.append(f"GET source settings wf={src_wf_id}: {e}") + return + if not src_settings: + return + + if self.ctx.options.dry_run: + with lock: + result.created += 1 + logger.info( + "[dry-run] would create HITL settings for workflow %s", tgt_wf_id + ) + return + + # Review settings are one-per-workflow — adopt if one already exists. + if not self.ctx.remap.is_planned(tgt_wf_id): + try: + existing = self.ctx.target.get_review_settings(tgt_wf_id) + except Exception as e: + logger.exception( + "manual_review: failed to GET target settings for wf %s: %s", + tgt_wf_id, + e, + ) + with lock: + result.failed += 1 + result.errors.append(f"GET target settings wf={tgt_wf_id}: {e}") + return + if existing: + with lock: + result.adopted += 1 + logger.info( + "adopted HITL settings on workflow %s (already present)", tgt_wf_id + ) + return + + payload: dict[str, Any] = {"workflow": tgt_wf_id} + for field in _SETTINGS_FIELDS: + if field in src_settings and src_settings[field] is not None: + payload[field] = src_settings[field] + try: + self.ctx.target.create_review_settings(payload) + except Exception as e: + logger.exception( + "manual_review: failed to create settings for wf %s: %s", tgt_wf_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"create settings wf={tgt_wf_id}: {e}") + return + with lock: + result.created += 1 + logger.info("created HITL settings for workflow %s", tgt_wf_id) + + # ----- org-level: auto-approval ----- + + def _clone_auto_approval(self, result: PhaseResult) -> None: + try: + src_rows = self.ctx.source.list_auto_approval_settings() + except Exception as e: + logger.exception( + "manual_review: failed to list source auto-approval: %s", e + ) + result.failed += 1 + result.errors.append(f"list source auto_approval: {e}") + return + if not src_rows: + return + src = src_rows[0] # Unique per org — at most one row. + + doc_classes = src.get("auto_approved_document_classes") or [] + # Mixed workflow-id / class-name strings with no reliable cross-org + # remap — carried verbatim and flagged for manual verification. + if doc_classes: + result.warnings.append( + "auto-approval cloned with source-org strings in " + "auto_approved_document_classes — these may need manual " + "verification on the target" + ) + + users = self._remap_auto_approved_users( + src.get("auto_approved_users") or [], result + ) + + if self.ctx.options.dry_run: + result.created += 1 + logger.info("[dry-run] would create org auto-approval settings") + return + + try: + existing = self.ctx.target.list_auto_approval_settings() + except Exception as e: + logger.exception( + "manual_review: failed to list target auto-approval: %s", e + ) + result.failed += 1 + result.errors.append(f"list target auto_approval: {e}") + return + if existing: + result.adopted += 1 + logger.info("adopted org auto-approval settings (already present)") + return + + payload = { + "auto_approved_document_classes": doc_classes, + "auto_approved_users": users, + } + try: + self.ctx.target.create_auto_approval_settings(payload) + except Exception as e: + logger.exception("manual_review: failed to create auto-approval: %s", e) + result.failed += 1 + result.errors.append(f"create auto_approval: {e}") + return + result.created += 1 + logger.info("created org auto-approval settings") + + def _remap_auto_approved_users( + self, src_user_ids: list[Any], result: PhaseResult + ) -> list[str]: + """Map source-org user pks to target pks by email (mirrors share + replication). Unmappable users are skipped with a warning; an + unavailable listing carries the field empty rather than failing.""" + if not src_user_ids: + return [] + src_users = source_user_by_id(self.ctx) + tgt_by_email = target_user_id_by_email(self.ctx) + if src_users is None or tgt_by_email is None: + result.warnings.append( + "auto-approval: users listing unavailable — " + f"{len(src_user_ids)} auto-approved user(s) not replicated" + ) + return [] + mapped: list[str] = [] + for uid in src_user_ids: + row = src_users.get(str(uid)) + if row is None: + result.warnings.append( + f"auto-approval: source user id {uid} not in source users " + "listing — skipped" + ) + continue + if is_service_account(row): + continue + email = row["email"] + tgt_uid = tgt_by_email.get(email.lower()) + if tgt_uid is None: + result.warnings.append( + f"auto-approval: user {email} not found in target org — skipped" + ) + continue + # The field holds string ids. + mapped.append(str(tgt_uid)) + return mapped + + # ----- org-level: review api keys ----- + + def _clone_api_keys(self, result: PhaseResult) -> None: + try: + src_keys = self.ctx.source.list_review_api_keys() + except Exception as e: + logger.exception( + "manual_review: failed to list source review api keys: %s", e + ) + result.failed += 1 + result.errors.append(f"list source review_api_keys: {e}") + return + if not src_keys: + return + + if self.ctx.options.dry_run: + result.created += len(src_keys) + logger.info("[dry-run] would recreate %d review api key(s)", len(src_keys)) + return + + # Review API keys have no cross-org natural key (api_key is + # server-minted), so adopt by (class_name, description) to stay + # idempotent on re-runs. + try: + tgt_keys = self.ctx.target.list_review_api_keys() + except Exception as e: + logger.exception( + "manual_review: failed to list target review api keys: %s", e + ) + result.failed += 1 + result.errors.append(f"list target review_api_keys: {e}") + return + existing = { + (k.get("class_name"), k.get("description")) for k in tgt_keys + } + + recreated = 0 + for src in src_keys: + if (src.get("class_name"), src.get("description")) in existing: + result.adopted += 1 + logger.info( + "adopted review api key class=%s (already present)", + src.get("class_name"), + ) + continue + payload = { + k: src[k] + for k in ("class_name", "description", "is_active") + if k in src and src[k] is not None + } + try: + self.ctx.target.create_review_api_key(payload) + except Exception as e: + logger.exception( + "manual_review: failed to create review api key: %s", e + ) + result.failed += 1 + result.errors.append(f"create review_api_key: {e}") + continue + result.created += 1 + recreated += 1 + + # The api_key secret is server-minted and non-copyable; newly created + # keys get a NEW secret, so external consumers must be re-wired. + if recreated: + result.warnings.append( + f"{recreated} review API key(s) recreated with freshly minted " + "secrets — the original key values cannot be copied; re-wire any " + "external consumers to the new keys" + ) + logger.info("recreated %d review api key(s) on target", recreated) diff --git a/src/unstract/clone/phases/pipeline.py b/src/unstract/clone/phases/pipeline.py index f1a93ea..7151d13 100644 --- a/src/unstract/clone/phases/pipeline.py +++ b/src/unstract/clone/phases/pipeline.py @@ -145,6 +145,10 @@ def _clone_one( logger.info( "created pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] ) + # Backend force-activates every pipeline on create; restore the + # source's disabled state so an inactive source pipeline doesn't + # start running (and failing) on the target's scheduler. + self._restore_active_state(full_src, tgt["id"], name) self._warn_if_extra_source_keys(src_id, name) with lock: @@ -159,6 +163,25 @@ def _clone_one( src_detail_fn=lambda: self.ctx.source.get_pipeline(src_id), ) + def _restore_active_state( + self, full_src: dict[str, Any], tgt_id: str, name: str + ) -> None: + # Only act when the source was inactive; created pipelines are already + # active. PATCHing active=False re-writes the scheduler job as disabled. + if full_src.get("active", True): + return + try: + self.ctx.target.update_pipeline(tgt_id, {"active": False}) + except Exception as e: + logger.warning( + "pipeline '%s': could not restore inactive state — it will run " + "on schedule until disabled in the UI: %s", + name, + e, + ) + return + logger.info("pipeline '%s': restored source inactive state on target", name) + def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: try: keys = self.ctx.source.list_pipeline_keys(src_pipeline_id) diff --git a/src/unstract/clone/phases/tool_instance.py b/src/unstract/clone/phases/tool_instance.py index 3234189..af20171 100644 --- a/src/unstract/clone/phases/tool_instance.py +++ b/src/unstract/clone/phases/tool_instance.py @@ -1,20 +1,17 @@ """Migrate ToolInstance rows from source org to target org. -Each workflow holds at most one ToolInstance, enforced server-side -(``tool_instance_v2/serializers.py`` raises if a workflow already has one). +Each workflow holds at most one ToolInstance, enforced server-side. The row carries: - ``workflow`` FK — remapped from the WorkflowPhase remap table. -- ``tool_id`` (CharField, not FK) — a ``prompt_registry_id`` UUID. The - target's registry was rebuilt in CustomToolPhase, so we remap via the - ``prompt_studio_registry`` table populated there. -- ``metadata`` JSON — backend's ``create()`` discards the POST metadata - and rebuilds it from tool defaults. So we POST a bare instance, then - PATCH the metadata afterwards. Source metadata stores adapter values - as NAMES (via to_representation in source GET); on PATCH the backend's - ``update_metadata_with_adapter_instances`` resolves those names to - the target's adapter UUIDs. Names match across orgs because - AdapterPhase preserved them. +- ``tool_id`` — a ``prompt_registry_id`` UUID. The target's registry was + rebuilt in CustomToolPhase, so we remap via the ``prompt_studio_registry`` + table populated there. +- ``metadata`` JSON — the create response carries default metadata rebuilt + from tool defaults, so we POST a bare instance then PATCH the metadata + afterwards. Source metadata stores adapter values as NAMES; on PATCH the + backend resolves those names to the target's adapter UUIDs. Names match + across orgs because AdapterPhase preserved them. """ from __future__ import annotations @@ -28,12 +25,11 @@ logger = logging.getLogger(__name__) -# Source backend's ToolInstanceSerializer.to_representation emits these -# sentinel strings when an adapter UUID/name in the stored metadata can -# no longer be resolved (deleted or renamed on source). Round-tripping -# them to target produces an AdapterNotFound on PATCH, so we detect and -# skip the metadata PATCH instead — the ToolInstance row exists with the -# backend's safe defaults and the operator can re-bind in the UI. +# The source response emits these sentinel strings when an adapter UUID/name +# in the stored metadata can no longer be resolved (deleted or renamed on +# source). Round-tripping them to target produces an AdapterNotFound on PATCH, +# so we detect and skip the metadata PATCH instead — the row exists with safe +# defaults and the operator can re-bind in the UI. _BROKEN_ADAPTER_SENTINELS: tuple[str, ...] = ( "NOT FOUND", "[DELETED ADAPTER", diff --git a/src/unstract/clone/phases/workflow.py b/src/unstract/clone/phases/workflow.py index d9256a1..e8b34b7 100644 --- a/src/unstract/clone/phases/workflow.py +++ b/src/unstract/clone/phases/workflow.py @@ -129,9 +129,9 @@ def _clone_one( logger.info("[dry-run] would create workflow '%s' src=%s", name, src_id) return else: - # List endpoints serve stripped payloads (e.g. AdapterListSerializer - # omits adapter_metadata_b); workflow detail carries the JSON blobs - # source_settings / destination_settings that embed connector UUIDs. + # List endpoints serve stripped payloads; the workflow detail + # carries the JSON blobs source_settings / destination_settings + # that embed connector UUIDs. try: src_detail = self.ctx.source.get_workflow(src_id) except Exception as e: diff --git a/src/unstract/clone/phases/workflow_endpoint.py b/src/unstract/clone/phases/workflow_endpoint.py index 1101ec3..bd4130e 100644 --- a/src/unstract/clone/phases/workflow_endpoint.py +++ b/src/unstract/clone/phases/workflow_endpoint.py @@ -1,19 +1,17 @@ """Migrate WorkflowEndpoint rows from source org to target org. The backend auto-creates one SOURCE and one DESTINATION endpoint per -workflow on workflow create (``perform_create`` in WorkflowViewSet), so -there's nothing to POST — we only PATCH the target's existing endpoints -with the source's connection_type, connector_instance, and configuration. +workflow on workflow create, so there's nothing to POST — we only PATCH +the target's existing endpoints with the source's connection_type, +connector_instance, and configuration. Notes: -- ``workflow`` and ``endpoint_type`` are ``editable=False`` server-side - and aren't writable on PATCH. +- ``workflow`` and ``endpoint_type`` aren't writable on PATCH. - ``connector_instance`` FK is nullable; we remap via the connector remap table populated in ConnectorPhase. - ``configuration`` is a JSON blob that may embed connector UUIDs; walker pass remaps them before PATCH. -- Source ``connector_instance`` arrives as a nested dict (per - ``WorkflowEndpointSerializer.connector_instance``); we extract its +- Source ``connector_instance`` arrives as a nested dict; we extract its ``id`` and remap. """ @@ -141,30 +139,27 @@ def _patch_endpoint( tgt_ep_id = tgt_ep["id"] etype = src_ep["endpoint_type"] - # Resolve the connector before the dry-run gate so the plan mirrors - # the real run's unmapped-connector skip (an unmapped connector is - # left out of both counts, not predicted as a patch). src_conn_id = _extract_connector_id(src_ep) tgt_conn_id: str | None = None + connector_unmapped = False if src_conn_id: with lock: tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id) if not tgt_conn_id: + # Connector wasn't cloned (e.g. OAuth). Still patch the + # connection_type so the endpoint is valid and the operator + # only needs to re-bind the connector — skipping the whole + # patch would leave connection_type empty and fail runs with + # "Invalid source connection type". + connector_unmapped = True logger.warning( - "skipping %s endpoint src=%s tgt=%s — source connector %s " - "has no target remap; would silently unset connector", + "%s endpoint src=%s tgt=%s: source connector %s has no " + "target remap — setting type only, connector needs UI config", etype, src_ep_id, tgt_ep_id, src_conn_id, ) - with lock: - result.skipped += 1 - result.errors.append( - f"unmapped connector on {etype} endpoint {src_ep_id}: " - f"src_connector={src_conn_id}" - ) - return if self.ctx.options.dry_run: with lock: @@ -175,7 +170,7 @@ def _patch_endpoint( etype, src_ep_id, tgt_ep_id, - tgt_conn_id, + tgt_conn_id or "", ) return @@ -183,11 +178,15 @@ def _patch_endpoint( "configuration": remap_uuids( src_ep.get("configuration") or {}, self.ctx.remap ), - "connector_instance_id": tgt_conn_id, } src_connection_type = src_ep.get("connection_type") if src_connection_type is not None: payload["connection_type"] = src_connection_type + # Omit connector_instance_id only when the source connector couldn't be + # remapped, so the target keeps its connector for re-binding. A source + # with no connector (e.g. API) still patches null to clear any stale one. + if not connector_unmapped: + payload["connector_instance_id"] = tgt_conn_id try: self.ctx.target.update_workflow_endpoint(tgt_ep_id, payload) @@ -203,10 +202,16 @@ def _patch_endpoint( with lock: result.created += 1 self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id) + if connector_unmapped: + result.warnings.append( + f"{etype} endpoint {src_ep_id}: connector not cloned " + f"(src_connector={src_conn_id}) — connection_type set, " + "configure the connector in the UI" + ) logger.info( "patched %s endpoint src=%s -> tgt=%s (connector %s)", etype, src_ep_id, tgt_ep_id, - tgt_conn_id, + tgt_conn_id or "", ) diff --git a/src/unstract/clone/report.py b/src/unstract/clone/report.py index 1066585..98fcadb 100644 --- a/src/unstract/clone/report.py +++ b/src/unstract/clone/report.py @@ -43,6 +43,9 @@ class CloneReport: dry_run: bool = False phases: list[PhaseResult] = field(default_factory=list) skipped_phases: list[str] = field(default_factory=list) + # Run-level non-fatal warnings not tied to any one phase (e.g. a cloud + # phase skipped because the target deployment lacks the feature). + warnings: list[str] = field(default_factory=list) remap_snapshot: dict[str, dict[str, str]] = field(default_factory=dict) aborted: bool = False abort_reason: str | None = None @@ -224,6 +227,7 @@ def as_dict(self) -> dict[str, Any]: for p in self.phases ], "skipped_phases": list(self.skipped_phases), + "warnings": list(self.warnings), "remap_snapshot": self.remap_snapshot, "aborted": self.aborted, "abort_reason": self.abort_reason, @@ -336,6 +340,8 @@ def _render_failures_summary(self, console_print: Any, rich: bool) -> None: def _render_warnings_summary(self, console_print: Any, rich: bool) -> None: rows: list[tuple[str, str]] = [] + for warning in self.warnings: + rows.append(("run", warning)) for p in self.phases: for warning in p.warnings: rows.append((p.name, warning)) diff --git a/src/unstract/clone/sharing.py b/src/unstract/clone/sharing.py index 0ffde1b..e1dc883 100644 --- a/src/unstract/clone/sharing.py +++ b/src/unstract/clone/sharing.py @@ -23,8 +23,6 @@ logger = logging.getLogger(__name__) SHARE_AXES: tuple[str, ...] = ("shared_users", "shared_groups", "shared_to_org") -# Platform-key identities; they exist per-org and never map across orgs. -SERVICE_ACCOUNT_EMAIL_SUFFIX = "@platform.internal" _FETCH_FAILED = object() # cache sentinel so a failing listing isn't re-hit @@ -32,14 +30,10 @@ def is_service_account(row: dict[str, Any]) -> bool: """True if a user/member listing row is a service account. - Email-suffix fallback covers older backends without the flag; - mis-classification is benign — a service-account email never matches - across orgs, so worst case is a spurious skip-warning. + These identities are per-org and never map across orgs, so they are + skipped during share replication. """ - flag = row.get("is_service_account") - if flag is not None: - return bool(flag) - return (row.get("email") or "").lower().endswith(SERVICE_ACCOUNT_EMAIL_SUFFIX) + return bool(row.get("is_service_account")) def _cached(ctx: CloneContext, key: str, build: Callable[[], Any]) -> Any: @@ -89,27 +83,21 @@ def target_user_id_by_email(ctx: CloneContext) -> dict[str, int] | None: return None if value is _FETCH_FAILED else value -def apply_share_state( +def _resolve_share_src( ctx: CloneContext, - *, - share_path: str, - entity_label: str, src: dict[str, Any], + entity_label: str, result: PhaseResult, lock: threading.Lock, - src_detail_fn: Callable[[], dict[str, Any]] | None = None, -) -> None: - """Mirror ``src``'s share state onto the target resource at ``share_path``. - - ``src`` may be a stripped list-row; when any share axis is missing and - ``src_detail_fn`` is given, the source detail is fetched once. No-ops - when the effective share state is empty. Never raises — failures land - in ``result.errors`` (counted) and skips in ``result.warnings``. + src_detail_fn: Callable[[], dict[str, Any]] | None, +) -> dict[str, Any] | None: + """Return the source row carrying share axes, fetching the detail once + when ``src`` is a stripped list-row. ``None`` signals a fetch failure + (already warned) — caller should abort. """ - share_src = src - if src_detail_fn is not None and not all(k in share_src for k in SHARE_AXES): + if src_detail_fn is not None and not all(k in src for k in SHARE_AXES): try: - share_src = src_detail_fn() + return src_detail_fn() except Exception as e: logger.warning("share %s: source detail fetch failed: %s", entity_label, e) with lock: @@ -117,8 +105,24 @@ def apply_share_state( f"share {entity_label}: source detail fetch failed — " f"share state not replicated: {e}" ) - return + return None + return src + +def _build_share_payload( + ctx: CloneContext, + share_src: dict[str, Any], + entity_label: str, + result: PhaseResult, + lock: threading.Lock, + *, + include_groups: bool = True, +) -> dict[str, Any] | None: + """Map ``share_src``'s axes to target ids (groups via remap, users by + email). Returns the share payload, or ``None`` when there is nothing to + replicate. ``include_groups=False`` omits the group axis entirely (for + entities whose serializer has no writable ``shared_groups``). + """ shared_to_org = bool(share_src.get("shared_to_org")) src_group_ids = list(share_src.get("shared_groups") or []) src_user_ids = list(share_src.get("shared_users") or []) @@ -127,25 +131,34 @@ def apply_share_state( payload: dict[str, Any] = {"shared_to_org": shared_to_org} group_warnings: list[str] = [] - if src_group_ids and not ctx.options.includes("group"): - # Axis omitted entirely so the target's group shares are untouched. + if include_groups: + if src_group_ids and not ctx.options.includes("group"): + # Axis omitted entirely so the target's group shares are untouched. + group_warnings.append( + f"share {entity_label}: group phase excluded — " + f"{len(src_group_ids)} group share(s) not replicated" + ) + else: + mapped_groups: list[Any] = [] + for gid in src_group_ids: + tgt_gid = ctx.remap.resolve("group", str(gid)) + if tgt_gid is None: + group_warnings.append( + f"share {entity_label}: source group id {gid} has no " + "target mapping — skipped" + ) + else: + # Real group pks are ints; dry-run planned remaps are + # synthetic uuids (never POSTed) — keep those as-is. + mapped_groups.append( + int(tgt_gid) if str(tgt_gid).isdigit() else tgt_gid + ) + payload["shared_groups"] = mapped_groups + elif src_group_ids: group_warnings.append( - f"share {entity_label}: group phase excluded — " - f"{len(src_group_ids)} group share(s) not replicated" + f"share {entity_label}: {len(src_group_ids)} group share(s) not " + "supported by this entity — not replicated" ) - mapped_groups: list[int] | None = None - else: - mapped_groups = [] - for gid in src_group_ids: - tgt_gid = ctx.remap.resolve("group", str(gid)) - if tgt_gid is None: - group_warnings.append( - f"share {entity_label}: source group id {gid} has no " - "target mapping — skipped" - ) - else: - mapped_groups.append(int(tgt_gid)) - payload["shared_groups"] = mapped_groups user_warnings: list[str] = [] mapped_users: list[int] = [] @@ -187,30 +200,68 @@ def apply_share_state( if not mapped_users and not payload.get("shared_groups") and not shared_to_org: logger.debug("share %s: nothing to replicate", entity_label) - return + return None + return payload + +def replicate_share( + ctx: CloneContext, + *, + apply_fn: Callable[[dict[str, Any]], Any], + entity_label: str, + src: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + src_detail_fn: Callable[[], dict[str, Any]] | None = None, + include_groups: bool = True, +) -> None: + """Map ``src``'s share state and hand the payload to ``apply_fn`` (the + entity-specific write — a ``/share/`` POST or a detail PATCH). Generic + over the write mechanism so PATCH-shared cloud entities reuse the same + user/group mapping. Never raises. + """ + share_src = _resolve_share_src(ctx, src, entity_label, result, lock, src_detail_fn) + if share_src is None: + return + payload = _build_share_payload( + ctx, share_src, entity_label, result, lock, include_groups=include_groups + ) + if payload is None: + return if ctx.options.dry_run: - logger.info( - "[dry-run] would share %s: users=%s groups=%s org=%s", - entity_label, - mapped_users, - payload.get("shared_groups"), - shared_to_org, - ) + logger.info("[dry-run] would share %s: %s", entity_label, payload) return - try: - ctx.target.share_resource(share_path, payload) + apply_fn(payload) except Exception as e: logger.exception("Failed to apply share state for %s: %s", entity_label, e) with lock: result.failed += 1 result.errors.append(f"share {entity_label}: {e}") return - logger.info( - "shared %s: users=%s groups=%s org=%s", - entity_label, - mapped_users, - payload.get("shared_groups"), - shared_to_org, + logger.info("shared %s: %s", entity_label, payload) + + +def apply_share_state( + ctx: CloneContext, + *, + share_path: str, + entity_label: str, + src: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + src_detail_fn: Callable[[], dict[str, Any]] | None = None, +) -> None: + """Mirror ``src``'s share state onto the target resource via its + ``/share/`` POST endpoint at ``share_path``. Thin wrapper over + ``replicate_share`` preserving the original POST behavior. + """ + replicate_share( + ctx, + apply_fn=lambda payload: ctx.target.share_resource(share_path, payload), + entity_label=entity_label, + src=src, + result=result, + lock=lock, + src_detail_fn=src_detail_fn, ) diff --git a/tests/clone/test_agentic_studio_phase.py b/tests/clone/test_agentic_studio_phase.py new file mode 100644 index 0000000..4521ef4 --- /dev/null +++ b/tests/clone/test_agentic_studio_phase.py @@ -0,0 +1,509 @@ +"""Tests for ``AgenticStudioPhase`` (cloud-only Agentic Prompt Studio). + +Covers: create-fresh project + four-adapter remap; an unresolved adapter +omitted with a warning; adopt-by-name; prompt-version parent-before-child +ordering with the self-FK remapped; schema clone bound to the target project; +registry republish via the export action recording a remap; dry-run records a +planned project (+ planned children) and writes nothing. + +A single scripted fake plays both source and target; the target side records +every write so assertions read off ``created_*`` lists. +""" + +from __future__ import annotations + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.phases.agentic_studio import AgenticStudioPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + POST_SCHEMA = frozenset( + { + "name", + "description", + "canary_fields", + "llm_connector_id", + "agent_llm_connector_id", + "lightweight_llm_connector_id", + "text_extractor_connector_id", + } + ) + + def __init__( + self, + *, + projects=None, + versions=None, + schemas=None, + settings=None, + registries=None, + users=None, + ): + self.projects = list(projects or []) + self.users = list(users or []) + self.shared_projects: list[tuple[str, dict]] = [] + # project_id -> list of prompt-version rows + self.versions = {k: list(v) for k, v in (versions or {}).items()} + # project_id -> list of schema rows + self.schemas = {k: list(v) for k, v in (schemas or {}).items()} + self.settings = list(settings or []) + # project_id -> list of registry rows + self.registries = {k: list(v) for k, v in (registries or {}).items()} + + self.created_projects: list[dict] = [] + self.created_versions: list[dict] = [] + self.created_schemas: list[dict] = [] + self.created_settings: list[dict] = [] + self.updated_settings: list[tuple[str, dict]] = [] + self.exported_projects: list[str] = [] + self._next_id = 1 + + def _mint(self, prefix: str) -> str: + out = f"{prefix}-{self._next_id:04d}" + self._next_id += 1 + return out + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + # ----- projects ----- + + def list_agentic_projects(self): + return list(self.projects) + + def create_agentic_project(self, payload): + new = dict(payload) + new["id"] = self._mint("tgt-proj") + self.projects.append(new) + self.created_projects.append(new) + return new + + def share_resource(self, share_path, payload): + self.shared_projects.append((share_path, payload)) + return payload + + # ----- users (share replication) ----- + + def list_users(self): + return list(self.users) + + # ----- prompt versions ----- + + def list_agentic_prompt_versions(self, *, project_id=None): + return list(self.versions.get(project_id, [])) + + def create_agentic_prompt_version(self, payload): + new = dict(payload) + new["id"] = self._mint("tgt-ver") + self.created_versions.append(new) + self.versions.setdefault(new["project"], []).append(new) + return new + + # ----- schemas ----- + + def list_agentic_schemas(self, *, project_id=None): + return list(self.schemas.get(project_id, [])) + + def create_agentic_schema(self, payload): + new = dict(payload) + new["id"] = self._mint("tgt-schema") + self.created_schemas.append(new) + self.schemas.setdefault(new["project"], []).append(new) + return new + + # ----- settings ----- + + def list_agentic_settings(self): + return list(self.settings) + + def create_agentic_setting(self, payload): + new = dict(payload) + new["id"] = self._mint("tgt-setting") + self.created_settings.append(new) + self.settings.append(new) + return new + + def update_agentic_setting(self, setting_id, payload): + self.updated_settings.append((setting_id, payload)) + return {"id": setting_id, **payload} + + # ----- registry ----- + + def export_agentic_project(self, project_id, *, force=True): + self.exported_projects.append(project_id) + # Export auto-creates the registry row on the target. + self.registries.setdefault(project_id, []).append( + {"registry_id": f"tgt-reg-{project_id}"} + ) + return {"message": "ok"} + + def list_agentic_registries(self, *, agentic_project=None): + return list(self.registries.get(agentic_project, [])) + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _src_project(pid, name, **adapters): + base = {"id": pid, "name": name, "description": f"{name} desc"} + base.update(adapters) + return base + + +def test_create_fresh_with_four_adapter_remap(): + src = FakeClient( + projects=[ + _src_project( + "src-p", + "Receipts", + llm_connector_id="src-llm", + agent_llm_connector_id="src-agent", + lightweight_llm_connector_id="src-light", + text_extractor_connector_id="src-x2t", + ) + ] + ) + tgt = FakeClient() + remap = RemapTable() + for s, t in [ + ("src-llm", "tgt-llm"), + ("src-agent", "tgt-agent"), + ("src-light", "tgt-light"), + ("src-x2t", "tgt-x2t"), + ]: + remap.record("adapter", s, t) + ctx = _ctx(src, tgt, remap=remap) + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert len(tgt.created_projects) == 1 + payload = tgt.created_projects[0] + assert payload["llm_connector_id"] == "tgt-llm" + assert payload["agent_llm_connector_id"] == "tgt-agent" + assert payload["lightweight_llm_connector_id"] == "tgt-light" + assert payload["text_extractor_connector_id"] == "tgt-x2t" + new_id = payload["id"] + assert remap.resolve("agentic_project", "src-p") == new_id + + +def test_unresolved_adapter_omitted_with_warning(): + src = FakeClient( + projects=[ + _src_project( + "src-p", + "Receipts", + llm_connector_id="src-llm", + agent_llm_connector_id="src-agent", + ) + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("adapter", "src-llm", "tgt-llm") # agent intentionally absent + ctx = _ctx(src, tgt, remap=remap) + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + payload = tgt.created_projects[0] + assert payload["llm_connector_id"] == "tgt-llm" + assert "agent_llm_connector_id" not in payload + assert any( + "agent_llm_connector_id adapter not remapped" in w for w in result.warnings + ) + + +def test_adopt_by_name_records_remap_no_create(): + src = FakeClient(projects=[_src_project("src-p", "Receipts")]) + tgt = FakeClient(projects=[{"id": "tgt-existing", "name": "Receipts"}]) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.created_projects == [] + assert ctx.remap.resolve("agentic_project", "src-p") == "tgt-existing" + + +def test_prompt_version_parent_before_child_remap(): + # Child (v2) listed first to prove ordering sorts roots ahead of children. + src = FakeClient( + projects=[_src_project("src-p", "Receipts")], + versions={ + "src-p": [ + { + "id": "src-v2", + "project": "src-p", + "version": 2, + "prompt_text": "v2", + "parent_version": "src-v1", + }, + { + "id": "src-v1", + "project": "src-p", + "version": 1, + "prompt_text": "v1", + "parent_version": None, + }, + ] + }, + ) + tgt = FakeClient() + ctx = _ctx(src, tgt) + + AgenticStudioPhase(ctx).run(CloneReport()) + + assert len(tgt.created_versions) == 2 + # Root v1 cloned first, no parent. + first, second = tgt.created_versions + assert first["version"] == 1 + assert "parent_version" not in first + # Child v2 second, parent remapped to the freshly created root id. + assert second["version"] == 2 + assert second["parent_version"] == first["id"] + tgt_pid = tgt.created_projects[0]["id"] + assert first["project"] == tgt_pid and second["project"] == tgt_pid + + +def test_schema_clone_bound_to_target_project(): + src = FakeClient( + projects=[_src_project("src-p", "Receipts")], + schemas={ + "src-p": [ + { + "id": "src-s1", + "project": "src-p", + "json_schema": '{"type":"object"}', + "version": 1, + "is_active": True, + } + ] + }, + ) + tgt = FakeClient() + ctx = _ctx(src, tgt) + + AgenticStudioPhase(ctx).run(CloneReport()) + + assert len(tgt.created_schemas) == 1 + schema = tgt.created_schemas[0] + assert schema["project"] == tgt.created_projects[0]["id"] + assert schema["json_schema"] == '{"type":"object"}' + + +def test_rerun_adopts_existing_prompt_versions_and_schemas(): + # Re-run against a pair whose project + children already exist on target + # must adopt them (no duplicate child creates). + src = FakeClient( + projects=[_src_project("src-p", "Receipts")], + versions={ + "src-p": [ + { + "id": "src-v1", + "project": "src-p", + "version": 1, + "prompt_text": "v1", + "parent_version": None, + } + ] + }, + schemas={ + "src-p": [ + { + "id": "src-s1", + "project": "src-p", + "json_schema": '{"type":"object"}', + "version": 1, + } + ] + }, + ) + tgt = FakeClient( + projects=[{"id": "tgt-existing", "name": "Receipts"}], + versions={ + "tgt-existing": [ + {"id": "tgt-v1", "project": "tgt-existing", "version": 1} + ] + }, + schemas={ + "tgt-existing": [ + {"id": "tgt-s1", "project": "tgt-existing", "version": 1} + ] + }, + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + # No duplicate children created on the re-run. + assert tgt.created_versions == [] + assert tgt.created_schemas == [] + assert result.created == 0 + # Existing version adopted + remap recorded so child parents still resolve. + assert ctx.remap.resolve("agentic_prompt_version", "src-v1") == "tgt-v1" + + +def test_registry_republished_and_remapped(): + src = FakeClient( + projects=[_src_project("src-p", "Receipts")], + registries={"src-p": [{"registry_id": "src-reg"}]}, + ) + tgt = FakeClient() + ctx = _ctx(src, tgt) + + AgenticStudioPhase(ctx).run(CloneReport()) + + tgt_pid = tgt.created_projects[0]["id"] + assert tgt.exported_projects == [tgt_pid] + assert ( + ctx.remap.resolve("agentic_studio_registry", "src-reg") + == f"tgt-reg-{tgt_pid}" + ) + + +def test_no_source_registry_skips_export(): + src = FakeClient(projects=[_src_project("src-p", "Receipts")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + + AgenticStudioPhase(ctx).run(CloneReport()) + + assert tgt.exported_projects == [] + + +def test_settings_create_and_adopt(): + src = FakeClient( + projects=[], + settings=[ + {"id": "src-s1", "key": "model", "value": "gpt-4"}, + {"id": "src-s2", "key": "temp", "value": "0.2"}, + ], + ) + tgt = FakeClient(settings=[{"id": "tgt-s2", "key": "temp", "value": "0.9"}]) + ctx = _ctx(src, tgt) + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + # 'model' is new (create), 'temp' already exists (update/adopt). + assert [s["key"] for s in tgt.created_settings] == ["model"] + assert tgt.updated_settings and tgt.updated_settings[0][0] == "tgt-s2" + assert result.adopted == 1 + + +def test_setting_global_key_collision_skips_not_fails(): + # `key` is globally unique across orgs: a create can 500 on a key owned by + # another org that isn't in this org's listing. That's a warned skip, not a + # hard failure. + src = FakeClient(settings=[{"id": "src-s1", "key": "global-key", "value": "v"}]) + tgt = FakeClient() + + def _boom(payload): + raise RuntimeError("returned 500") + + tgt.create_agentic_setting = _boom + ctx = _ctx(src, tgt) + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + assert result.failed == 0 + assert result.skipped == 1 + assert any("global-key" in w for w in result.warnings) + + +def test_dry_run_plans_without_writing(): + src = FakeClient( + projects=[_src_project("src-p", "Receipts")], + versions={ + "src-p": [ + {"id": "src-v1", "project": "src-p", "version": 1, "prompt_text": "v1"} + ] + }, + schemas={ + "src-p": [ + {"id": "src-s1", "project": "src-p", "json_schema": "{}", "version": 1} + ] + }, + settings=[{"id": "src-set", "key": "model", "value": "x"}], + ) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + assert tgt.created_projects == [] + assert tgt.created_versions == [] + assert tgt.created_schemas == [] + assert tgt.created_settings == [] + assert tgt.exported_projects == [] + # 1 project + 1 version + 1 schema + 1 setting planned. + assert result.created == 4 + planned = ctx.remap.resolve("agentic_project", "src-p") + assert planned is not None and ctx.remap.is_planned(planned) + planned_v = ctx.remap.resolve("agentic_prompt_version", "src-v1") + assert planned_v is not None and ctx.remap.is_planned(planned_v) + + +def test_share_replicated_with_mapped_users_and_org_flag(): + src = FakeClient( + projects=[ + { + "id": "src-p", + "name": "Receipts", + "description": "d", + "created_by": 1, # owner — skipped, server-managed on target + "shared_to_org": True, + "shared_users": [1, 2], # pks from the project serializer + } + ], + users=[ + {"id": 1, "email": "owner@x.com"}, + {"id": 2, "email": "alice@x.com"}, + ], + ) + tgt = FakeClient(users=[{"id": 42, "email": "alice@x.com"}]) + ctx = _ctx(src, tgt) + + AgenticStudioPhase(ctx).run(CloneReport()) + + tgt_pid = tgt.created_projects[0]["id"] + assert len(tgt.shared_projects) == 1 + share_path, payload = tgt.shared_projects[0] + assert share_path == f"agentic/projects/{tgt_pid}/share/" + assert payload["shared_to_org"] is True + assert payload["shared_users"] == [42] # owner dropped, alice remapped + assert payload["shared_groups"] == [] # no source groups, axis still sent + + +def test_source_group_share_replicated_via_remap(): + # Group sharing IS supported on agentic projects (via the share action), + # so a source group share maps through the `group` remap onto the target. + src = FakeClient( + projects=[ + { + "id": "src-p", + "name": "Receipts", + "description": "d", + "created_by": 1, + "shared_groups": [7], + } + ], + users=[{"id": 1, "email": "owner@x.com"}], + ) + tgt = FakeClient() + ctx = _ctx(src, tgt) + ctx.remap.record("group", "7", "70") + + result = AgenticStudioPhase(ctx).run(CloneReport()) + + assert result.failed == 0 + assert len(tgt.shared_projects) == 1 + _, payload = tgt.shared_projects[0] + assert payload["shared_groups"] == [70] diff --git a/tests/clone/test_client.py b/tests/clone/test_client.py index 9fa3dca..561c91a 100644 --- a/tests/clone/test_client.py +++ b/tests/clone/test_client.py @@ -143,3 +143,17 @@ def test_options_response_with_null_body_still_yields_empty_schema(): # Some deployments return 200 with no body on OPTIONS. client, _ = _client_with_mock(payload=None, text="") assert client.get_post_schema("pipeline/") == frozenset() + + +def test_get_review_settings_500_treated_as_absent(): + # Backend raises DoesNotExist (-> 500) when no HITLSettings row exists. + client, _ = _client_with_mock(status=500, text="DoesNotExist") + assert client.get_review_settings("wf-1") is None + + +def test_get_review_settings_reraises_non_500(): + # Auth / rate-limit errors must surface, not masquerade as "no settings". + client, _ = _client_with_mock(status=403, text="forbidden") + with pytest.raises(PlatformAPIError) as exc_info: + client.get_review_settings("wf-1") + assert exc_info.value.status_code == 403 diff --git a/tests/clone/test_cloud_phase_gating.py b/tests/clone/test_cloud_phase_gating.py new file mode 100644 index 0000000..2250c12 --- /dev/null +++ b/tests/clone/test_cloud_phase_gating.py @@ -0,0 +1,85 @@ +"""Tests for capability-probe gating of cloud-only phases. + +A cloud phase declares ``probe_path``; the orchestrator probes source then +target before running it. Matrix: +- source absent → silent skip (no run, no report row, no warning). +- source present, target absent → warn + skip (one ``report.warnings`` entry). +- both present → run normally. + +Core OSS phases (``probe_path is None``) are never probed and always run. +""" + +from __future__ import annotations + +from unittest.mock import patch + +from unstract.clone import orchestrator +from unstract.clone.context import OrgEndpoint +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + + +class _CloudPhase(Phase): + """Dummy cloud phase; records that it ran on a shared list.""" + + invocations: list[str] = [] + name = "dummy_cloud" + probe_path = "dummy/" + + def run(self, report: CloneReport) -> PhaseResult: + _CloudPhase.invocations.append(self.name) + return report.get_phase(self.name) + + +def _src() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://src.example.com", + organization_id="src_org", + platform_key="src-key", + ) + + +def _tgt() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://tgt.example.com", + organization_id="tgt_org", + platform_key="tgt-key", + ) + + +def _run_with_probes(*, source_present: bool, target_present: bool) -> CloneReport: + """Drive clone() with a single cloud phase and scripted per-org probes.""" + _CloudPhase.invocations = [] + scripted = {"src_org": source_present, "tgt_org": target_present} + + def fake_probe(self, path: str) -> bool: + return scripted[self.endpoint.organization_id] + + with ( + patch.object(orchestrator, "PHASES", [("dummy_cloud", _CloudPhase)]), + patch.object(orchestrator.PlatformClient, "close"), + patch.object(orchestrator.PlatformClient, "probe", fake_probe), + ): + return orchestrator.clone(_src(), _tgt()) + + +def test_source_absent_skips_silently(): + report = _run_with_probes(source_present=False, target_present=False) + assert _CloudPhase.invocations == [] + # OSS source must look exactly like today: no phase row, no warning. + assert report.phases == [] + assert report.warnings == [] + assert "dummy_cloud" not in report.skipped_phases + + +def test_source_present_target_absent_warns_and_skips(): + report = _run_with_probes(source_present=True, target_present=False) + assert _CloudPhase.invocations == [] + assert len(report.warnings) == 1 + assert "dummy_cloud" in report.warnings[0] + + +def test_both_present_runs_phase(): + report = _run_with_probes(source_present=True, target_present=True) + assert _CloudPhase.invocations == ["dummy_cloud"] + assert report.warnings == [] diff --git a/tests/clone/test_custom_tool_phase.py b/tests/clone/test_custom_tool_phase.py index 98d56a4..d0fb317 100644 --- a/tests/clone/test_custom_tool_phase.py +++ b/tests/clone/test_custom_tool_phase.py @@ -9,7 +9,9 @@ - registry remap recorded after ``export_custom_tool``. - dry-run: no writes on either side. - abort on name conflict when option is set. -- missing target adapter fails the tool cleanly. +- incomplete source tools (missing target adapter / no profile) mirror + unconfigured instead of failing; frictionless adapters still skip. +- registry republish 500 warns, doesn't fail the tool. """ from __future__ import annotations @@ -49,6 +51,7 @@ def __init__(self) -> None: self.export_blobs: dict[str, dict] = {} self.registries_by_tool: dict[str, dict] = {} self.adapters_by_name: dict[str, dict] = {} + self.prompts_by_tool: dict[str, list[dict]] = {} # Call recorders. self.import_calls: list[tuple[dict, dict | None]] = [] self.sync_calls: list[tuple[str, dict, bool]] = [] @@ -90,7 +93,17 @@ def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: reg = self.registries_by_tool.get(custom_tool) return [reg] if reg else [] + def list_prompts(self, tool_id: str) -> list[dict]: + return list(self.prompts_by_tool.get(tool_id, [])) + # --- writes --- + _REQUIRED_ADAPTER_FIELDS = ( + "llm_adapter_id", + "vector_db_adapter_id", + "embedding_adapter_id", + "x2text_adapter_id", + ) + def import_project( self, export_data: dict, adapter_ids: dict | None = None ) -> dict: @@ -98,10 +111,14 @@ def import_project( tool_id = self._mint("tool") tool_name = export_data["tool_metadata"]["tool_name"] self.tools[tool_id] = {"tool_name": tool_name} + # Backend flags needs_adapter_config unless all four are wired. + fully_wired = bool(adapter_ids) and all( + adapter_ids.get(k) for k in self._REQUIRED_ADAPTER_FIELDS + ) return { "tool_id": tool_id, "message": f"Project imported successfully as '{tool_name}'", - "needs_adapter_config": adapter_ids is None, + "needs_adapter_config": not fully_wired, } def sync_prompts( @@ -223,7 +240,8 @@ def test_fresh_imports_with_name_resolved_adapter_ids_and_records_registry(): assert result.created == 1 assert result.failed == 0 - # Exactly one import_project call with the right export blob + name-resolved adapter ids. + # Exactly one import_project call with the right export blob + name-resolved + # adapter ids. assert len(tgt.import_calls) == 1 blob, adapter_ids = tgt.import_calls[0] assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" @@ -396,7 +414,11 @@ def test_never_exported_source_tool_skips_registry_republish(): assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) is None -def test_missing_target_adapter_fails_tool_cleanly(): +def test_missing_target_adapter_imports_unconfigured(): + """A source adapter with no target match isn't fatal: the tool is + mirrored with a partial adapter set, flagged needs_adapter_config, and + a warning tells the operator to wire it + re-run. + """ src = FakeClient() tgt = FakeClient() _preload_source_tool(src, "src-tool-x", "T") @@ -408,9 +430,106 @@ def test_missing_target_adapter_fails_tool_cleanly(): result = CustomToolPhase(ctx).run(CloneReport()) - assert result.failed == 1 - assert tgt.import_calls == [] - # Registry republish should NOT fire when the tool fails. - assert tgt.export_tool_calls == [] - # No custom_tool remap recorded. - assert ctx.remap.resolve("custom_tool", "src-tool-x") is None + assert result.created == 1 + assert result.failed == 0 + # Import fired with the 3 resolvable adapters; x2text omitted. + assert len(tgt.import_calls) == 1 + _, adapter_ids = tgt.import_calls[0] + assert adapter_ids == { + "llm_adapter_id": TGT_ADAPTER_IDS["gpt4"], + "vector_db_adapter_id": TGT_ADAPTER_IDS["pgvector"], + "embedding_adapter_id": TGT_ADAPTER_IDS["ada-embed"], + } + assert "x2text_adapter_id" not in adapter_ids + assert any("full adapter config" in w for w in result.warnings) + # Source had a registry → still republished + remap recorded. + assert tgt.export_tool_calls == [ctx.remap.resolve("custom_tool", "src-tool-x")] + + +def test_no_default_profile_imports_unconfigured(): + """A source tool with no profiles can't derive adapter ids; mirror it + anyway (backend auto-creates an unconfigured default profile) rather + than failing the whole clone. + """ + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Profileless") + src.profiles_by_tool["src-tool-x"] = [] # no default profile on source + _seed_source_adapters(src) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + # Imported with an empty adapter set (no profile to derive from). + assert tgt.import_calls == [(src.export_blobs["src-tool-x"], {})] + assert any("full adapter config" in w for w in result.warnings) + assert ctx.remap.resolve("custom_tool", "src-tool-x") is not None + + +def test_republish_failure_warns_not_fails(): + """A registry republish 500 (e.g. stale/empty source registry) must not + fail the whole tool — the tool itself cloned; downstream tool_instances + just cascade-skip. + """ + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Stale Registry Tool") + _seed_source_adapters(src) + _seed_target_adapters(tgt) + + def boom(tool_id, *, force=True): + raise RuntimeError("500 export failed: no run prompts") + + tgt.export_custom_tool = boom + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + # Tool cloned; republish failure is a warning, not a failure. + assert result.created == 1 + assert result.failed == 0 + assert any("republish" in w for w in result.warnings) + # Tool remap recorded; registry remap absent (republish never landed). + assert ctx.remap.resolve("custom_tool", "src-tool-x") is not None + assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) is None + + +def test_remap_prompts_maps_src_to_tgt_by_prompt_key(): + import threading + + src = FakeClient() + tgt = FakeClient() + src.prompts_by_tool["src-tool"] = [ + {"prompt_id": "sp1", "prompt_key": "k1"}, + {"prompt_id": "sp2", "prompt_key": "k2"}, + ] + tgt.prompts_by_tool["tgt-tool"] = [ + {"prompt_id": "tp1", "prompt_key": "k1"}, + {"prompt_id": "tp2", "prompt_key": "k2"}, + ] + ctx = _ctx(src, tgt) + + CustomToolPhase(ctx)._remap_prompts( + "src-tool", "tgt-tool", "T", threading.Lock() + ) + + assert ctx.remap.resolve("prompt", "sp1") == "tp1" + assert ctx.remap.resolve("prompt", "sp2") == "tp2" + + +def test_record_planned_prompts_records_synthetic_remaps(): + import threading + + src = FakeClient() + src.prompts_by_tool["src-tool"] = [ + {"prompt_id": "sp1", "prompt_key": "k1"}, + ] + ctx = _ctx(src, FakeClient(), dry_run=True) + + CustomToolPhase(ctx)._record_planned_prompts("src-tool", threading.Lock()) + + planned = ctx.remap.resolve("prompt", "sp1") + assert planned is not None and ctx.remap.is_planned(planned) diff --git a/tests/clone/test_group_phase.py b/tests/clone/test_group_phase.py index 8ce3d0c..c661f9a 100644 --- a/tests/clone/test_group_phase.py +++ b/tests/clone/test_group_phase.py @@ -129,9 +129,8 @@ def test_member_cloning_matches_by_email_and_skips_missing(): 1: [ {"user_id": 7, "email": "alice@x.com"}, {"user_id": 8, "email": "ghost@x.com"}, # not in target org - # service acct via email-suffix fallback (no flag in row) - {"user_id": 9, "email": "svc@platform.internal"}, - # service acct via backend flag (email alone wouldn't tell) + # service accts flagged by the backend; email alone wouldn't tell + {"user_id": 9, "email": "svc@x.com", "is_service_account": True}, {"user_id": 10, "email": "bot@x.com", "is_service_account": True}, ] }, @@ -151,7 +150,7 @@ def test_member_cloning_matches_by_email_and_skips_missing(): assert tgt.member_posts == [(tgt_group_id, [70])] assert any("ghost@x.com" in w for w in result.warnings) # service accounts are skipped silently, not warned about - assert not any("platform.internal" in w for w in result.warnings) + assert not any("svc@x.com" in w for w in result.warnings) assert not any("bot@x.com" in w for w in result.warnings) diff --git a/tests/clone/test_lookups_phase.py b/tests/clone/test_lookups_phase.py new file mode 100644 index 0000000..48f5446 --- /dev/null +++ b/tests/clone/test_lookups_phase.py @@ -0,0 +1,641 @@ +"""Tests for ``LookupsPhase`` (cloud-only Lookups feature). + +Covers: create-fresh definition + draft template patch + adapter remap; +adopt-by-name; a draft-pinned assignment remapped via the ``prompt`` + +``lookup_definition`` tables; share replication (PATCH with mapped users + +``shared_to_org``); published-version replay (publish in ``version_number`` +order + ``lookup_version`` remap recorded + draft restored after replay); a +published-pinned assignment resolved via the version remap; dry-run records a +planned remap and writes nothing. + +A single scripted fake plays both source and target; the target side records +every write so assertions read off ``posts`` / ``patches``. +""" + +from __future__ import annotations + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.phases.lookups import LookupsPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + POST_SCHEMA = frozenset({"name", "description", "shared_to_org", "shared_users"}) + + def __init__( + self, + *, + lookups=None, + details=None, + files=None, + file_blobs=None, + assignments=None, + versions=None, + version_details=None, + version_file_blobs=None, + users=None, + ): + self._users = list(users or []) + self.lookups = list(lookups or []) + # lookup_id -> detail dict (draft template + adapters + draft_version_id) + self.details = dict(details or {}) + # lookup_id -> list of file rows + self.files = {k: list(v) for k, v in (files or {}).items()} + # file_id -> bytes + self.file_blobs = dict(file_blobs or {}) + self.assignments = list(assignments or []) + # lookup_id -> list of version rows (draft + published) + self.versions = {k: list(v) for k, v in (versions or {}).items()} + # version_id -> version detail dict + self.version_details = dict(version_details or {}) + # (version_id, file_id) -> bytes + self.version_file_blobs = dict(version_file_blobs or {}) + + self.created_lookups: list[dict] = [] + self.draft_template_patches: list[tuple[str, str]] = [] + self.draft_adapter_patches: list[tuple[str, dict]] = [] + self.uploaded_files: list[tuple[str, str]] = [] + self.created_assignments: list[dict] = [] + self.share_patches: list[tuple[str, dict]] = [] + self.published_versions: list[tuple[str, dict]] = [] + self._next_id = 1 + + # ----- schema / definitions ----- + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_lookup_definitions(self): + return list(self.lookups) + + def get_lookup_definition(self, lookup_id): + return self.details[lookup_id] + + def create_lookup_definition(self, payload): + new = dict(payload) + lid = f"tgt-lookup-{self._next_id:04d}" + self._next_id += 1 + new["lookup_id"] = lid + self.lookups.append(new) + self.created_lookups.append(new) + # Fresh definition auto-spawns an empty draft with a default version id. + self.details[lid] = { + "prompt_template": "", + "draft_version_id": f"tgt-draft-{lid}", + "adapters": {"llm": None, "x2text": None}, + } + return new + + def update_lookup_draft_template(self, lookup_id, prompt_template): + self.draft_template_patches.append((lookup_id, prompt_template)) + self.details[lookup_id]["prompt_template"] = prompt_template + return self.details[lookup_id] + + def update_lookup_draft_adapters(self, lookup_id, adapters): + self.draft_adapter_patches.append((lookup_id, adapters)) + self.details[lookup_id]["adapters"].update(adapters) + return self.details[lookup_id]["adapters"] + + # ----- files ----- + + def list_lookup_files(self, lookup_id): + return list(self.files.get(lookup_id, [])) + + def download_lookup_file(self, lookup_id, file_id): + return self.file_blobs[file_id] + + def upload_lookup_file(self, lookup_id, file_name, data, mime_type): + self.uploaded_files.append((lookup_id, file_name)) + self.files.setdefault(lookup_id, []).append( + {"file_id": f"tgt-file-{self._next_id}", "file_name": file_name} + ) + self._next_id += 1 + return {"file_id": f"tgt-file-{file_name}"} + + # ----- share ----- + + def list_users(self): + return list(self._users) + + def update_lookup_share(self, lookup_id, payload): + self.share_patches.append((lookup_id, payload)) + return {"lookup_id": lookup_id, **payload} + + # ----- versions ----- + + def list_lookup_versions(self, lookup_id): + return list(self.versions.get(lookup_id, [])) + + def get_lookup_version(self, lookup_id, version_id): + return self.version_details[version_id] + + def download_lookup_version_file(self, lookup_id, version_id, file_id): + return self.version_file_blobs[(version_id, file_id)] + + def publish_lookup_version(self, lookup_id, payload): + """Freeze the current draft into a published version + spawn a fresh + draft (mirrors the backend's ``_publish_draft``). + """ + detail = self.details[lookup_id] + max_num = max( + (v.get("version_number") or 0 for v in self.versions.get(lookup_id, [])), + default=0, + ) + vid = f"tgt-ver-{self._next_id:04d}" + self._next_id += 1 + published = { + "version_id": vid, + "is_draft": False, + "version_name": payload.get("version_name") or f"v{max_num + 1}", + "version_number": max_num + 1, + } + self.versions.setdefault(lookup_id, []).append(published) + # New empty-ish draft: backend clones the published content into it, + # but the phase re-stages content per version anyway. + new_draft_id = f"tgt-draft-{vid}" + detail["draft_version_id"] = new_draft_id + self.published_versions.append((lookup_id, published)) + return published + + # ----- assignments ----- + + def list_lookup_assignments(self): + return list(self.assignments) + + def create_lookup_assignment(self, payload): + new = dict(payload) + new["assignment_id"] = f"tgt-asg-{self._next_id:04d}" + self._next_id += 1 + self.created_assignments.append(new) + self.assignments.append(new) + return new + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _src_lookup(lookup_id, name): + return {"lookup_id": lookup_id, "name": name, "description": f"{name} desc"} + + +def _src_detail( + template, + *, + llm=None, + x2text=None, + shared_to_org=False, + shared_users=None, + created_by=None, +): + return { + "prompt_template": template, + "draft_version_id": "src-draft", + "adapters": {"llm": llm, "x2text": x2text}, + "shared_to_org": shared_to_org, + "shared_users": list(shared_users or []), + "created_by": created_by, + } + + +def test_create_fresh_with_draft_template_and_adapter_remap(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={ + "src-lk": _src_detail("Find {{vendor}}", llm="src-llm", x2text="src-x2t") + }, + ) + tgt = FakeClient() + remap = RemapTable() + # AdapterPhase recorded these earlier in the run. + remap.record("adapter", "src-llm", "tgt-llm") + remap.record("adapter", "src-x2t", "tgt-x2t") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = LookupsPhase(ctx).run(report) + + assert result.created == 1 + assert len(tgt.created_lookups) == 1 + new_id = tgt.created_lookups[0]["lookup_id"] + assert remap.resolve("lookup_definition", "src-lk") == new_id + # Draft template replicated. + assert tgt.draft_template_patches == [(new_id, "Find {{vendor}}")] + # Both adapters remapped to target ids in one PATCH. + assert tgt.draft_adapter_patches == [ + (new_id, {"llm": "tgt-llm", "x2text": "tgt-x2t"}) + ] + + +def test_unresolved_adapter_is_skipped_with_warning(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("T", llm="src-llm", x2text="src-x2t")}, + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("adapter", "src-llm", "tgt-llm") # x2text intentionally absent + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = LookupsPhase(ctx).run(report) + + new_id = tgt.created_lookups[0]["lookup_id"] + assert tgt.draft_adapter_patches == [(new_id, {"llm": "tgt-llm"})] + assert any("x2text adapter not remapped" in w for w in result.warnings) + + +def test_adopt_by_name_records_remap_no_create(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("T")}, + ) + tgt = FakeClient( + lookups=[{"lookup_id": "tgt-existing", "name": "Vendors"}], + details={ + "tgt-existing": { + "prompt_template": "", + "draft_version_id": "tgt-existing-draft", + "adapters": {"llm": None, "x2text": None}, + } + }, + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = CloneReport() + + result = LookupsPhase(ctx).run(report) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.created_lookups == [] + assert ctx.remap.resolve("lookup_definition", "src-lk") == "tgt-existing" + + +def test_share_replication_patches_mapped_users_and_org_flag(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={ + "src-lk": _src_detail( + "T", + shared_to_org=True, + shared_users=[10, 20], + created_by=99, # owner — skipped from share payload + ) + }, + users=[ + {"id": 10, "email": "a@x.com"}, + {"id": 20, "email": "b@x.com"}, + {"id": 99, "email": "owner@x.com"}, + ], + ) + tgt = FakeClient( + users=[ + {"id": 110, "email": "a@x.com"}, + {"id": 120, "email": "b@x.com"}, + ], + ) + ctx = _ctx(src, tgt) + report = CloneReport() + + LookupsPhase(ctx).run(report) + + new_id = tgt.created_lookups[0]["lookup_id"] + assert len(tgt.share_patches) == 1 + lid, payload = tgt.share_patches[0] + assert lid == new_id + assert payload["shared_to_org"] is True + assert sorted(payload["shared_users"]) == [110, 120] + # Lookups have no group sharing — axis omitted entirely. + assert "shared_groups" not in payload + + +def test_draft_pinned_assignment_remapped(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("T")}, + assignments=[ + { + "assignment_id": "src-asg", + "prompt": "src-prompt", + "version": "src-draft", + "lookup_definition": "src-lk", + "is_draft_version": True, + "variable_mappings": {"vendor": "src-prompt"}, + } + ], + ) + tgt = FakeClient() + remap = RemapTable() + # custom_tool phase recorded the prompt remap. + remap.record("prompt", "src-prompt", "tgt-prompt") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + LookupsPhase(ctx).run(report) + + assert len(tgt.created_assignments) == 1 + asg = tgt.created_assignments[0] + assert asg["prompt"] == "tgt-prompt" + new_lookup = tgt.created_lookups[0]["lookup_id"] + assert asg["lookup_definition"] == new_lookup + assert asg["version"] == f"tgt-draft-{new_lookup}" + # Mapping value that is a source prompt uuid is remapped too. + assert asg["variable_mappings"] == {"vendor": "tgt-prompt"} + + +def _src_published_lookup(): + """A source lookup carrying one published version + a draft, with a + published-pinned assignment. Shared by the replay/resolution tests. + """ + return FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={ + "src-lk": _src_detail("Current draft", llm="src-llm") + }, + versions={ + "src-lk": [ + { + "version_id": "src-draft", + "is_draft": True, + "version_name": "", + "version_number": 0, + }, + { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + }, + ] + }, + version_details={ + "src-v1": { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + "prompt_template": "Frozen v1", + "adapters": {"llm": "src-llm", "x2text": None}, + "files": [], + } + }, + assignments=[ + { + "assignment_id": "src-asg", + "prompt": "src-prompt", + "version": "src-v1", + "lookup_definition": "src-lk", + "is_draft_version": False, + "variable_mappings": {}, + } + ], + ) + + +def test_published_version_replayed_publishes_in_order_and_records_remap(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("Current draft", llm="src-llm")}, + versions={ + "src-lk": [ + # Out-of-order on purpose: replay must sort by version_number. + { + "version_id": "src-v2", + "is_draft": False, + "version_name": "v2", + "version_number": 2, + }, + { + "version_id": "src-draft", + "is_draft": True, + "version_name": "", + "version_number": 0, + }, + { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + }, + ] + }, + version_details={ + "src-v1": { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + "prompt_template": "Frozen v1", + "adapters": {"llm": "src-llm", "x2text": None}, + "files": [], + }, + "src-v2": { + "version_id": "src-v2", + "is_draft": False, + "version_name": "v2", + "version_number": 2, + "prompt_template": "Frozen v2", + "adapters": {"llm": "src-llm", "x2text": None}, + "files": [], + }, + }, + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("adapter", "src-llm", "tgt-llm") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + LookupsPhase(ctx).run(report) + + # Published in version_number order. + assert [p[1]["version_name"] for p in tgt.published_versions] == ["v1", "v2"] + # A version remap recorded for each published version. + assert remap.resolve("lookup_version", "src-v1") is not None + assert remap.resolve("lookup_version", "src-v2") is not None + + +def test_staging_template_failure_skips_publish(): + # If staging a version's template onto the draft fails, the version must + # NOT be published — else stale content freezes into a named version. + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("Current draft", llm="src-llm")}, + versions={ + "src-lk": [ + { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + }, + ] + }, + version_details={ + "src-v1": { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + "prompt_template": "Frozen v1", + "adapters": {"llm": "src-llm", "x2text": None}, + "files": [], + }, + }, + ) + tgt = FakeClient() + + def _boom(lookup_id, prompt_template): + raise RuntimeError("template too long") + + tgt.update_lookup_draft_template = _boom + remap = RemapTable() + remap.record("adapter", "src-llm", "tgt-llm") + ctx = _ctx(src, tgt, remap=remap) + + result = LookupsPhase(ctx).run(CloneReport()) + + assert tgt.published_versions == [] # staging failed -> no publish + assert result.failed >= 1 + + +def test_published_version_adopted_on_rerun_no_republish(): + # Target already has a same-name definition with the same published version + # names: re-run must adopt them (no re-publish) and still record remaps. + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("Current draft", llm="src-llm")}, + versions={ + "src-lk": [ + { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + }, + ] + }, + version_details={ + "src-v1": { + "version_id": "src-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + "prompt_template": "Frozen v1", + "adapters": {"llm": "src-llm", "x2text": None}, + "files": [], + }, + }, + ) + tgt = FakeClient( + lookups=[{"lookup_id": "tgt-lk", "name": "Vendors"}], + details={ + "tgt-lk": { + "prompt_template": "", + "draft_version_id": "tgt-draft-lk", + "adapters": {"llm": None, "x2text": None}, + } + }, + versions={ + "tgt-lk": [ + { + "version_id": "tgt-v1", + "is_draft": False, + "version_name": "v1", + "version_number": 1, + }, + ] + }, + ) + remap = RemapTable() + remap.record("adapter", "src-llm", "tgt-llm") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + LookupsPhase(ctx).run(report) + + # Nothing re-published; remap points at the existing target version. + assert tgt.published_versions == [] + assert remap.resolve("lookup_version", "src-v1") == "tgt-v1" + + +def test_published_pinned_assignment_resolves_via_version_remap(): + src = _src_published_lookup() + tgt = FakeClient() + remap = RemapTable() + remap.record("prompt", "src-prompt", "tgt-prompt") + remap.record("adapter", "src-llm", "tgt-llm") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + LookupsPhase(ctx).run(report) + + # No longer skipped — the published pin resolves via the version remap. + assert len(tgt.created_assignments) == 1 + asg = tgt.created_assignments[0] + assert asg["prompt"] == "tgt-prompt" + tgt_v1 = remap.resolve("lookup_version", "src-v1") + assert tgt_v1 is not None + assert asg["version"] == tgt_v1 + + +def test_draft_restored_after_replay(): + src = _src_published_lookup() + tgt = FakeClient() + remap = RemapTable() + remap.record("prompt", "src-prompt", "tgt-prompt") + remap.record("adapter", "src-llm", "tgt-llm") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + LookupsPhase(ctx).run(report) + + new_id = tgt.created_lookups[0]["lookup_id"] + # The LAST template patch on the target draft is the source's CURRENT + # draft, not the frozen v1 content staged during replay. + last_template = [ + t for (lid, t) in tgt.draft_template_patches if lid == new_id + ][-1] + assert last_template == "Current draft" + # Source draft version id maps to the target's final draft id. + src_draft = src.details["src-lk"]["draft_version_id"] + tgt_draft = tgt.details[new_id]["draft_version_id"] + assert remap.resolve("lookup_version", src_draft) == tgt_draft + + +def test_dry_run_records_planned_and_writes_nothing(): + src = FakeClient( + lookups=[_src_lookup("src-lk", "Vendors")], + details={"src-lk": _src_detail("T", llm="src-llm")}, + assignments=[ + { + "assignment_id": "src-asg", + "prompt": "src-prompt", + "version": "src-draft", + "lookup_definition": "src-lk", + "is_draft_version": True, + "variable_mappings": {}, + } + ], + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("prompt", "src-prompt", "tgt-prompt") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + report = CloneReport() + + result = LookupsPhase(ctx).run(report) + + # One planned definition + one planned assignment. + assert result.created == 2 + assert tgt.created_lookups == [] + assert tgt.created_assignments == [] + assert tgt.draft_template_patches == [] + assert tgt.draft_adapter_patches == [] + planned = ctx.remap.resolve("lookup_definition", "src-lk") + assert planned is not None and ctx.remap.is_planned(planned) diff --git a/tests/clone/test_manual_review_phase.py b/tests/clone/test_manual_review_phase.py new file mode 100644 index 0000000..a11b46d --- /dev/null +++ b/tests/clone/test_manual_review_phase.py @@ -0,0 +1,298 @@ +"""Tests for ``ManualReviewPhase`` (cloud-only HITL feature). + +Covers: per-workflow RuleEngine + HITLSettings cloned with the workflow +remap (incl. nested confidence_filters); a workflow with no target mapping is +silently skipped; org-level AutoApprovalSettings cloned once with +auto_approved_users remapped by email and doc-classes carried verbatim; +ReviewApiKey recreation emits the re-wire warning; dry-run plans the creates +without writing. + +A single scripted fake plays both source and target; the target side records +every write so assertions read off the ``created_*`` lists. +""" + +from __future__ import annotations + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.phases.manual_review import ManualReviewPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + MR_RULE_TYPES = ("DB", "API") + + def __init__( + self, + *, + workflows=None, + rules=None, + settings=None, + auto_approval=None, + api_keys=None, + users=None, + ): + self.workflows = list(workflows or []) + # (workflow_id, rule_type) -> rule dict + self.rules = dict(rules or {}) + # workflow_id -> settings dict + self.settings = dict(settings or {}) + self.auto_approval = list(auto_approval or []) + self.api_keys = list(api_keys or []) + self.users = list(users or []) + + self.created_rules: list[dict] = [] + self.created_settings: list[dict] = [] + self.created_auto_approval: list[dict] = [] + self.created_api_keys: list[dict] = [] + + def list_workflows(self): + return list(self.workflows) + + def list_users(self): + return list(self.users) + + def get_review_rule(self, workflow_id, rule_type): + return self.rules.get((str(workflow_id), rule_type)) + + def create_review_rule(self, payload): + self.created_rules.append(payload) + self.rules[(str(payload["workflow"]), payload.get("rule_type", "DB"))] = payload + return payload + + def get_review_settings(self, workflow_id): + return self.settings.get(str(workflow_id)) + + def create_review_settings(self, payload): + self.created_settings.append(payload) + self.settings[str(payload["workflow"])] = payload + return payload + + def list_auto_approval_settings(self): + return list(self.auto_approval) + + def create_auto_approval_settings(self, payload): + self.created_auto_approval.append(payload) + self.auto_approval.append(payload) + return payload + + def list_review_api_keys(self): + return list(self.api_keys) + + def create_review_api_key(self, payload): + self.created_api_keys.append(payload) + self.api_keys.append(payload) + return payload + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _src_settings(workflow="src-wf"): + return {"workflow": workflow, "sync_with": "DB", "ttl_hours": 100} + + +def _src_rule(rule_type, **over): + rule = { + "id": f"src-rule-{rule_type}", + "workflow": "src-wf", + "rule_type": rule_type, + "percentage": 25, + "rule_string": "x > 1", + "rule_json": {"x": 1}, + "rule_logic": "OR", + "confidence_filters": [ + {"id": "cf1", "field_key": "amount", "confidence_threshold": 80} + ], + } + rule.update(over) + return rule + + +def test_rule_and_settings_cloned_with_workflow_remap(): + src = FakeClient( + workflows=[{"id": "src-wf", "workflow_name": "WF"}], + rules={("src-wf", "DB"): _src_rule("DB")}, + settings={"src-wf": _src_settings()}, + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "src-wf", "tgt-wf") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + # One DB rule + one settings row created (no API rule in source). + assert len(tgt.created_rules) == 1 + rule = tgt.created_rules[0] + assert rule["workflow"] == "tgt-wf" + assert rule["rule_type"] == "DB" + assert rule["percentage"] == 25 + # Nested filters carried, server-managed id stripped. + assert rule["confidence_filters"] == [ + {"field_key": "amount", "confidence_threshold": 80} + ] + assert "id" not in rule + + assert len(tgt.created_settings) == 1 + settings = tgt.created_settings[0] + assert settings == {"workflow": "tgt-wf", "sync_with": "DB", "ttl_hours": 100} + assert result.created == 2 + + +def test_workflow_without_target_mapping_skipped(): + src = FakeClient( + workflows=[{"id": "src-wf", "workflow_name": "WF"}], + rules={("src-wf", "DB"): _src_rule("DB")}, + settings={"src-wf": _src_settings()}, + ) + tgt = FakeClient() + # No workflow remap recorded — its tool/workflow wasn't cloned. + ctx = _ctx(src, tgt, remap=RemapTable()) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + assert tgt.created_rules == [] + assert tgt.created_settings == [] + assert result.created == 0 + assert result.failed == 0 + + +def test_auto_approval_remaps_users_by_email_and_carries_doc_classes(): + src = FakeClient( + workflows=[], + auto_approval=[ + { + "id": "aa1", + "auto_approved_document_classes": ["cls-1"], + "auto_approved_users": ["7"], + } + ], + users=[{"id": 7, "email": "Alice@x.com"}], + ) + tgt = FakeClient(users=[{"id": 42, "email": "alice@x.com"}]) + ctx = _ctx(src, tgt, remap=RemapTable()) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + assert len(tgt.created_auto_approval) == 1 + payload = tgt.created_auto_approval[0] + # Doc classes carried verbatim (no cross-org remap). + assert payload["auto_approved_document_classes"] == ["cls-1"] + # User src pk 7 -> target pk 42 via case-insensitive email match. + assert payload["auto_approved_users"] == ["42"] + assert "organization" not in payload + assert any("manual verification" in w for w in result.warnings) + assert result.created == 1 + + +def test_auto_approval_user_absent_on_target_skipped_with_warning(): + src = FakeClient( + workflows=[], + auto_approval=[ + { + "id": "aa1", + "auto_approved_document_classes": [], + "auto_approved_users": ["7", "8"], + } + ], + users=[ + {"id": 7, "email": "alice@x.com"}, + {"id": 8, "email": "bob@x.com"}, + ], + ) + # bob has no counterpart on the target org. + tgt = FakeClient(users=[{"id": 42, "email": "alice@x.com"}]) + ctx = _ctx(src, tgt, remap=RemapTable()) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + payload = tgt.created_auto_approval[0] + assert payload["auto_approved_users"] == ["42"] + assert any("bob@x.com not found in target org" in w for w in result.warnings) + + +def test_review_api_key_recreated_with_warning(): + src = FakeClient( + workflows=[], + api_keys=[ + { + "id": "k1", + "api_key": "secret-uuid", + "class_name": "invoices", + "description": "d", + "is_active": True, + } + ], + ) + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=RemapTable()) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + assert len(tgt.created_api_keys) == 1 + payload = tgt.created_api_keys[0] + # Secret + server-managed id NOT carried over. + assert "api_key" not in payload + assert "id" not in payload + assert payload == {"class_name": "invoices", "description": "d", "is_active": True} + assert any("re-wire any external consumers" in w for w in result.warnings) + + +def test_review_api_key_adopted_on_rerun(): + # Target already carries a key with the same (class_name, description): + # re-run must adopt, not create a duplicate or re-warn. + key = {"class_name": "invoices", "description": "d", "is_active": True} + src = FakeClient(workflows=[], api_keys=[{"id": "k1", "api_key": "s", **key}]) + tgt = FakeClient(api_keys=[{"id": "t1", "api_key": "other", **key}]) + ctx = _ctx(src, tgt, remap=RemapTable()) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + assert tgt.created_api_keys == [] + assert result.adopted >= 1 + assert not any("re-wire any external consumers" in w for w in result.warnings) + + +def test_dry_run_plans_without_writing(): + src = FakeClient( + workflows=[{"id": "src-wf", "workflow_name": "WF"}], + rules={ + ("src-wf", "DB"): _src_rule("DB"), + ("src-wf", "API"): _src_rule("API"), + }, + settings={"src-wf": _src_settings()}, + auto_approval=[ + { + "id": "aa", + "auto_approved_document_classes": [], + "auto_approved_users": [], + } + ], + api_keys=[{"id": "k1", "class_name": "c", "is_active": True}], + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "src-wf", "tgt-wf") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + report = CloneReport() + + result = ManualReviewPhase(ctx).run(report) + + # 2 rules + 1 settings + 1 auto-approval + 1 api key planned. + assert result.created == 5 + assert tgt.created_rules == [] + assert tgt.created_settings == [] + assert tgt.created_auto_approval == [] + assert tgt.created_api_keys == [] diff --git a/tests/clone/test_pipeline_phase.py b/tests/clone/test_pipeline_phase.py index f1c09fc..8b53725 100644 --- a/tests/clone/test_pipeline_phase.py +++ b/tests/clone/test_pipeline_phase.py @@ -45,6 +45,7 @@ class FakeClient: def __init__(self, pipelines: list[dict] | None = None): self.pipelines: list[dict] = list(pipelines or []) self.posts: list[dict] = [] + self.patches: list[tuple[str, dict]] = [] self.keys_by_pipeline: dict[str, list[dict]] = {} self._next = 1 @@ -75,6 +76,14 @@ def create_pipeline(self, payload: dict) -> dict: self.posts.append(new) return new + def update_pipeline(self, pipeline_id: str, payload: dict) -> dict: + self.patches.append((pipeline_id, payload)) + for p in self.pipelines: + if p["id"] == pipeline_id: + p.update(payload) + return dict(p) + raise KeyError(pipeline_id) + def list_pipeline_keys(self, pipeline_id: str) -> list[dict]: return list(self.keys_by_pipeline.get(pipeline_id, [])) @@ -172,6 +181,38 @@ def get_pipeline(self, pipeline_id): assert posted["cron_string"] == "0 5 * * *" +def test_inactive_source_pipeline_deactivated_on_target(): + # Backend force-activates on create; an inactive source pipeline must be + # patched back to inactive so its schedule doesn't run on the target. + pl = _src_pipeline("src-pl-1", "Disabled ETL", "wf-src-1", cron_string="0 5 * * *") + pl["active"] = False + src = FakeClient([pl]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + posted = tgt.posts[0] + assert tgt.patches == [(posted["id"], {"active": False})] + + +def test_active_source_pipeline_not_patched(): + src = FakeClient( + [_src_pipeline("src-pl-1", "Live ETL", "wf-src-1", cron_string="0 5 * * *")] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + PipelinePhase(ctx).run(CloneReport()) + + assert tgt.patches == [] + + def test_default_and_app_pipeline_types_are_skipped(): src = FakeClient( [ diff --git a/tests/clone/test_sharing.py b/tests/clone/test_sharing.py index 7a0a3af..531895a 100644 --- a/tests/clone/test_sharing.py +++ b/tests/clone/test_sharing.py @@ -55,9 +55,8 @@ def test_share_payload_maps_users_groups_and_org_flag(): src_client = FakeClient( users=[ {"id": "7", "email": "alice@x.com"}, - # service account via email-suffix fallback (no flag in row) - {"id": "8", "email": "svc@platform.internal"}, - # service account via backend flag (email alone wouldn't tell) + # service accounts flagged by the backend; email alone wouldn't tell + {"id": "8", "email": "svc@x.com", "is_service_account": True}, {"id": "9", "email": "bot@x.com", "is_service_account": True}, ] ) @@ -173,6 +172,23 @@ def test_share_dry_run_never_posts(): assert tgt_client.share_posts == [] +def test_share_dry_run_planned_group_remap_does_not_crash(): + # In dry-run the group remap resolves to a synthetic uuid (planned), not + # an int pk — building the payload must not int()-cast and blow up. + ctx = _ctx(FakeClient(), FakeClient(), dry_run=True) + ctx.remap.record_planned("group", "1") + planned = ctx.remap.resolve("group", "1") + + result = _apply( + ctx, + {"shared_users": [], "shared_groups": [1], "shared_to_org": False}, + ) + + assert ctx.target.share_posts == [] # dry-run never posts + assert not result.errors + assert ctx.remap.is_planned(planned) + + def test_share_fetches_source_detail_when_axes_missing_from_list_row(): ctx = _ctx(FakeClient(), FakeClient()) detail = { diff --git a/tests/clone/test_workflow_endpoint_phase.py b/tests/clone/test_workflow_endpoint_phase.py index 6ec8072..496f27c 100644 --- a/tests/clone/test_workflow_endpoint_phase.py +++ b/tests/clone/test_workflow_endpoint_phase.py @@ -5,7 +5,8 @@ - pairs source/target endpoints by ``endpoint_type``; - remaps the embedded ``connector_instance`` UUID; - walker-rewrites UUIDs nested in ``configuration``; -- silently leaves connector_instance_id null when no remap exists. +- sets connection_type even when a connector has no remap (connector left + unset for the operator to re-bind). """ from __future__ import annotations @@ -179,10 +180,12 @@ def test_endpoint_without_source_connector_patches_with_null(): assert payload["configuration"] == {"foo": "bar"} -def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): - """Source had a connector but its remap is missing — patching with - connector=None would silently detach the endpoint on target. Skip - the PATCH and record an operator-visible error entry instead. +def test_unmapped_connector_still_sets_connection_type_and_warns(): + """Source had a connector but its remap is missing (e.g. an OAuth + connector that couldn't be cloned). The phase still patches + connection_type so the endpoint is valid at runtime, omits the + connector (operator re-binds in the UI), and records a warning rather + than failing the run with "Invalid source connection type". """ src = FakeClient() src.endpoints[SRC_WF] = [ @@ -199,10 +202,14 @@ def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): result = WorkflowEndpointPhase(ctx).run(CloneReport()) - assert result.created == 0 - assert result.skipped == 1 - assert tgt.patch_calls == [] - assert any("unmapped connector" in e for e in result.errors) + assert result.created == 1 + assert result.skipped == 0 + assert result.failed == 0 + assert len(tgt.patch_calls) == 1 + _, payload = tgt.patch_calls[0] + assert payload["connection_type"] == "FILESYSTEM" + assert "connector_instance_id" not in payload + assert any("connector not cloned" in w for w in result.warnings) def test_missing_target_endpoint_fails_loudly():