diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..e2a5f4e3d 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,46 @@ def tile(x, reps): backend.tile = tile +def _patch_raw_pytorch_trace_facade() -> None: + """Make raw PyTorch ``trace`` follow NumPy's trace signature.""" + + try: + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - only relevant without PyTorch + return + + original_trace = getattr(pytorch_backend, "trace", None) + + def trace(a, offset=0, axis1=-2, axis2=-1, dtype=None, out=None): + values = pytorch_backend.array(a) + if values.ndim < 2: + raise ValueError("diag requires an array of at least two dimensions") + diagonal = pytorch_backend.diagonal( + values, + offset=_operator_index(offset), + axis1=_operator_index(axis1), + axis2=_operator_index(axis2), + ) + result = pytorch_backend.sum(diagonal, axis=-1, dtype=dtype) + if out is not None: + copy_ = getattr(out, "copy_", None) + if copy_ is not None: + copy_(result) + return out + out[...] = result + return out + return result + + trace.__name__ = getattr(original_trace, "__name__", "trace") + trace.__doc__ = getattr(original_trace, "__doc__", None) + pytorch_backend.trace = trace + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.trace = trace + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +271,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_raw_pytorch_trace_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 diff --git a/tests/backend_support/test_pytorch_trace_contract.py b/tests/backend_support/test_pytorch_trace_contract.py new file mode 100644 index 000000000..6c6b152b4 --- /dev/null +++ b/tests/backend_support/test_pytorch_trace_contract.py @@ -0,0 +1,55 @@ +import importlib.util +import os +import subprocess +import sys + +import pytest + + +def test_raw_pytorch_trace_accepts_numpy_signature(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + env = os.environ.copy() + env.pop("PYRECEST_BACKEND", None) + completed = subprocess.run( + [ + sys.executable, + "-c", + """ +import pyrecest # noqa: F401 # triggers raw-backend compatibility patches +import torch +import pyrecest._backend.pytorch as raw_pytorch_backend + +values = torch.arange(12.0, dtype=torch.float64).reshape(2, 2, 3) +expected = torch.tensor([6.0, 18.0], dtype=torch.float64) + +result = raw_pytorch_backend.trace(values, offset=1, axis1=-2, axis2=-1) +assert torch.allclose(result, expected) + +out = torch.empty(2, dtype=torch.float64) +returned = raw_pytorch_backend.trace( + values, + offset=1, + axis1=-2, + axis2=-1, + dtype=torch.float64, + out=out, +) +assert returned is out +assert torch.allclose(out, expected) + +scalar_result = raw_pytorch_backend.trace([[1.0, 2.0], [3.0, 4.0]]) +assert tuple(scalar_result.shape) == () +assert float(scalar_result) == 5.0 +print("ok") +""", + ], + capture_output=True, + env=env, + text=True, + timeout=30.0, + ) + + assert completed.returncode == 0, completed.stderr + assert "ok" in completed.stdout