diff --git a/datamasque/client/runs.py b/datamasque/client/runs.py index 6b9a827..c25317e 100644 --- a/datamasque/client/runs.py +++ b/datamasque/client/runs.py @@ -1,5 +1,6 @@ import logging import re +from typing import Union from datamasque.client.base import BaseClient from datamasque.client.exceptions import ( @@ -43,9 +44,12 @@ def get_run_report(self, run_id: RunId) -> str: response = self.make_request("GET", f"api/runs/{run_id}/run-report/") return response.text - def get_db_discovery_result_report(self, run_id: RunId, include_selection_column: bool = True) -> str: + def get_db_discovery_result_report(self, run_id: RunId, include_selection_column: bool = True) -> Union[str, bytes]: """ - Returns the database-discovery result report for the specified run as CSV. + Returns the database-discovery result report for the specified run. + + Returns CSV text (`str`), or a zip of numbered CSV parts as `bytes` when the report is + large enough that the server splits it. When `include_selection_column` is true (the default), the CSV includes a `selected` column suitable for feeding back into ruleset generation. @@ -54,6 +58,9 @@ def get_db_discovery_result_report(self, run_id: RunId, include_selection_column url = f"api/runs/{run_id}/db-discovery-results/report/" params = None if include_selection_column else {"include_selection_column": "false"} response = self.make_request("GET", url, params=params) + + if response.headers.get("X-DM-Download-Format") == "zip": + return response.content return response.text def get_unfinished_runs(self) -> dict[str, UnfinishedRun]: diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 2830287..ff2b1ed 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -93,6 +93,17 @@ def test_get_db_discovery_result_report(client): assert result == "db discovery report without selection column" +def test_get_db_discovery_result_report_returns_zip_bytes_when_split(client): + run_id = RunId(1) + zip_bytes = b"PK\x03\x04 split report zip bytes" + with requests_mock.Mocker() as m: + url = f"http://test-server/api/runs/{run_id}/db-discovery-results/report/" + m.get(url, content=zip_bytes, headers={"X-DM-Download-Format": "zip"}, status_code=200) + result = client.get_db_discovery_result_report(run_id) + assert result == zip_bytes + assert isinstance(result, bytes) + + def test_poll_async_ruleset_generation(client): connection_id = ConnectionId("1") with requests_mock.Mocker() as m: