From 1c9440fe165c60eea33746748bcee485e4f7018a Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:18:17 +0200 Subject: [PATCH 1/3] Patch raw PyTorch trace contract --- src/pyrecest/__init__.py | 41 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..e2a5f4e3d 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,46 @@ def tile(x, reps): backend.tile = tile +def _patch_raw_pytorch_trace_facade() -> None: + """Make raw PyTorch ``trace`` follow NumPy's trace signature.""" + + try: + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - only relevant without PyTorch + return + + original_trace = getattr(pytorch_backend, "trace", None) + + def trace(a, offset=0, axis1=-2, axis2=-1, dtype=None, out=None): + values = pytorch_backend.array(a) + if values.ndim < 2: + raise ValueError("diag requires an array of at least two dimensions") + diagonal = pytorch_backend.diagonal( + values, + offset=_operator_index(offset), + axis1=_operator_index(axis1), + axis2=_operator_index(axis2), + ) + result = pytorch_backend.sum(diagonal, axis=-1, dtype=dtype) + if out is not None: + copy_ = getattr(out, "copy_", None) + if copy_ is not None: + copy_(result) + return out + out[...] = result + return out + return result + + trace.__name__ = getattr(original_trace, "__name__", "trace") + trace.__doc__ = getattr(original_trace, "__doc__", None) + pytorch_backend.trace = trace + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.trace = trace + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +271,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_raw_pytorch_trace_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 From beffcc40915afd7b258ae3323d923ba5f810ce70 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:18:45 +0200 Subject: [PATCH 2/3] Add raw PyTorch trace regression test --- .../test_pytorch_trace_contract.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/backend_support/test_pytorch_trace_contract.py diff --git a/tests/backend_support/test_pytorch_trace_contract.py b/tests/backend_support/test_pytorch_trace_contract.py new file mode 100644 index 000000000..547d765a2 --- /dev/null +++ b/tests/backend_support/test_pytorch_trace_contract.py @@ -0,0 +1,34 @@ +import importlib.util + +import pytest + + +def test_raw_pytorch_trace_accepts_numpy_signature(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + import pyrecest # noqa: F401 # triggers raw-backend compatibility patches + import torch + import pyrecest._backend.pytorch as raw_pytorch_backend + + values = torch.arange(12.0, dtype=torch.float64).reshape(2, 2, 3) + expected = torch.tensor([6.0, 18.0], dtype=torch.float64) + + result = raw_pytorch_backend.trace(values, offset=1, axis1=-2, axis2=-1) + assert torch.allclose(result, expected) + + out = torch.empty(2, dtype=torch.float64) + returned = raw_pytorch_backend.trace( + values, + offset=1, + axis1=-2, + axis2=-1, + dtype=torch.float64, + out=out, + ) + assert returned is out + assert torch.allclose(out, expected) + + scalar_result = raw_pytorch_backend.trace([[1.0, 2.0], [3.0, 4.0]]) + assert tuple(scalar_result.shape) == () + assert float(scalar_result) == pytest.approx(5.0) From 23e1ab3692351a6e1483b08e7d4cfb03441ce9a1 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:20:38 +0200 Subject: [PATCH 3/3] Isolate raw PyTorch trace regression test --- .../test_pytorch_trace_contract.py | 67 ++++++++++++------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/tests/backend_support/test_pytorch_trace_contract.py b/tests/backend_support/test_pytorch_trace_contract.py index 547d765a2..6c6b152b4 100644 --- a/tests/backend_support/test_pytorch_trace_contract.py +++ b/tests/backend_support/test_pytorch_trace_contract.py @@ -1,4 +1,7 @@ import importlib.util +import os +import subprocess +import sys import pytest @@ -7,28 +10,46 @@ def test_raw_pytorch_trace_accepts_numpy_signature(): if importlib.util.find_spec("torch") is None: pytest.skip("PyTorch is not installed") - import pyrecest # noqa: F401 # triggers raw-backend compatibility patches - import torch - import pyrecest._backend.pytorch as raw_pytorch_backend - - values = torch.arange(12.0, dtype=torch.float64).reshape(2, 2, 3) - expected = torch.tensor([6.0, 18.0], dtype=torch.float64) - - result = raw_pytorch_backend.trace(values, offset=1, axis1=-2, axis2=-1) - assert torch.allclose(result, expected) - - out = torch.empty(2, dtype=torch.float64) - returned = raw_pytorch_backend.trace( - values, - offset=1, - axis1=-2, - axis2=-1, - dtype=torch.float64, - out=out, + env = os.environ.copy() + env.pop("PYRECEST_BACKEND", None) + completed = subprocess.run( + [ + sys.executable, + "-c", + """ +import pyrecest # noqa: F401 # triggers raw-backend compatibility patches +import torch +import pyrecest._backend.pytorch as raw_pytorch_backend + +values = torch.arange(12.0, dtype=torch.float64).reshape(2, 2, 3) +expected = torch.tensor([6.0, 18.0], dtype=torch.float64) + +result = raw_pytorch_backend.trace(values, offset=1, axis1=-2, axis2=-1) +assert torch.allclose(result, expected) + +out = torch.empty(2, dtype=torch.float64) +returned = raw_pytorch_backend.trace( + values, + offset=1, + axis1=-2, + axis2=-1, + dtype=torch.float64, + out=out, +) +assert returned is out +assert torch.allclose(out, expected) + +scalar_result = raw_pytorch_backend.trace([[1.0, 2.0], [3.0, 4.0]]) +assert tuple(scalar_result.shape) == () +assert float(scalar_result) == 5.0 +print("ok") +""", + ], + capture_output=True, + env=env, + text=True, + timeout=30.0, ) - assert returned is out - assert torch.allclose(out, expected) - scalar_result = raw_pytorch_backend.trace([[1.0, 2.0], [3.0, 4.0]]) - assert tuple(scalar_result.shape) == () - assert float(scalar_result) == pytest.approx(5.0) + assert completed.returncode == 0, completed.stderr + assert "ok" in completed.stdout