From 43c45f475dfbfa62735d3dabb892226948e86c69 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:31:28 +0200 Subject: [PATCH 1/2] Patch set_diag array-like facade --- src/pyrecest/__init__.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..97fef7162 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -91,8 +91,26 @@ def squeeze(x, axis=None): backend.squeeze = squeeze +def _patch_set_diag_arraylike_facade() -> None: + """Make public set_diag accept array-like matrix inputs.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + original_set_diag = backend.set_diag + + def set_diag(x, new_diag): + if not backend.is_array(x): + x = backend.array(x) + return original_set_diag(x, new_diag) + + set_diag.__name__ = getattr(original_set_diag, "__name__", "set_diag") + set_diag.__doc__ = getattr(original_set_diag, "__doc__", None) + backend.set_diag = set_diag + + _patch_shared_numpy_copy_facade() _patch_shared_numpy_squeeze_facade() +_patch_set_diag_arraylike_facade() def _patch_pytorch_comparison_facade() -> None: From 67402658a12e05353ea2262f87a3d133a8e5576b Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:31:40 +0200 Subject: [PATCH 2/2] Add set_diag array-like regression test --- tests/test_backend_set_diag_contract.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/test_backend_set_diag_contract.py diff --git a/tests/test_backend_set_diag_contract.py b/tests/test_backend_set_diag_contract.py new file mode 100644 index 000000000..266810524 --- /dev/null +++ b/tests/test_backend_set_diag_contract.py @@ -0,0 +1,14 @@ +import pyrecest.backend as backend + + +def _to_python(value): + value = backend.to_numpy(value) + if hasattr(value, "tolist"): + return value.tolist() + return value + + +def test_set_diag_accepts_array_like_matrix_inputs(): + result = backend.set_diag([[1, 2], [3, 4]], [9, 8]) + + assert _to_python(result) == [[9, 2], [3, 8]]