diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..97fef7162 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -91,8 +91,26 @@ def squeeze(x, axis=None): backend.squeeze = squeeze +def _patch_set_diag_arraylike_facade() -> None: + """Make public set_diag accept array-like matrix inputs.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + original_set_diag = backend.set_diag + + def set_diag(x, new_diag): + if not backend.is_array(x): + x = backend.array(x) + return original_set_diag(x, new_diag) + + set_diag.__name__ = getattr(original_set_diag, "__name__", "set_diag") + set_diag.__doc__ = getattr(original_set_diag, "__doc__", None) + backend.set_diag = set_diag + + _patch_shared_numpy_copy_facade() _patch_shared_numpy_squeeze_facade() +_patch_set_diag_arraylike_facade() def _patch_pytorch_comparison_facade() -> None: diff --git a/tests/test_backend_set_diag_contract.py b/tests/test_backend_set_diag_contract.py new file mode 100644 index 000000000..266810524 --- /dev/null +++ b/tests/test_backend_set_diag_contract.py @@ -0,0 +1,14 @@ +import pyrecest.backend as backend + + +def _to_python(value): + value = backend.to_numpy(value) + if hasattr(value, "tolist"): + return value.tolist() + return value + + +def test_set_diag_accepts_array_like_matrix_inputs(): + result = backend.set_diag([[1, 2], [3, 4]], [9, 8]) + + assert _to_python(result) == [[9, 2], [3, 8]]