From 1f251cd0076042f6eec6169ffd3105a69bf94863 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:46:50 +0200 Subject: [PATCH 1/2] Fix JAX matmul out contract --- src/pyrecest/__init__.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..8918b1d7e 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -229,9 +229,39 @@ def std( backend.std = std +def _patch_jax_matmul_out_facade() -> None: + """Make public and raw JAX ``matmul`` honor NumPy's ``out`` contract.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "jax": + return + + try: + import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel + except ( + ModuleNotFoundError + ): # pragma: no cover - backend import fails first in practice + return + + original_matmul = jax_backend.matmul + + def matmul(x, y, out=None): + result = original_matmul(x, y, out=None) + if out is None: + return result + return backend.asarray(out).at[...].set(result) + + matmul.__name__ = getattr(original_matmul, "__name__", "matmul") + matmul.__doc__ = getattr(original_matmul, "__doc__", None) + jax_backend.matmul = matmul + backend.matmul = matmul + + _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() _patch_jax_std_out_facade() +_patch_jax_matmul_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 backend_support, From 2af93087e3bff764dc277a9a6d25f01669546b93 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:47:04 +0200 Subject: [PATCH 2/2] Add JAX matmul out regression test --- .../test_jax_matmul_out_contract.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/backend_support/test_jax_matmul_out_contract.py diff --git a/tests/backend_support/test_jax_matmul_out_contract.py b/tests/backend_support/test_jax_matmul_out_contract.py new file mode 100644 index 000000000..dd0a92632 --- /dev/null +++ b/tests/backend_support/test_jax_matmul_out_contract.py @@ -0,0 +1,50 @@ +import importlib.util +import os +import subprocess +import sys + +import pytest + + +@pytest.mark.backend_portable +def test_jax_matmul_honors_out_shape_contract(): + if importlib.util.find_spec("jax") is None: + pytest.skip("jax is not installed") + + env = os.environ.copy() + env["PYRECEST_BACKEND"] = "jax" + src_path = os.path.abspath("src") + env["PYTHONPATH"] = ( + src_path + if not env.get("PYTHONPATH") + else os.pathsep.join([src_path, env["PYTHONPATH"]]) + ) + + code = """ +import pyrecest.backend as backend +import pyrecest._backend.jax as raw_jax_backend + +left = [[1.0, 2.0], [3.0, 4.0]] +right = [[1.0, 0.0], [0.0, 1.0]] + +out = backend.zeros((2, 2)) +returned = backend.matmul(left, right, out=out) +assert backend.to_numpy(returned).tolist() == [[1.0, 2.0], [3.0, 4.0]] + +bad_out = backend.zeros((1, 1)) +try: + backend.matmul(left, right, out=bad_out) +except (TypeError, ValueError): + pass +else: + raise AssertionError("JAX backend.matmul ignored incompatible out shape") + +raw_bad_out = raw_jax_backend.zeros((1, 1)) +try: + raw_jax_backend.matmul(left, right, out=raw_bad_out) +except (TypeError, ValueError): + pass +else: + raise AssertionError("raw JAX matmul ignored incompatible out shape") +""" + subprocess.run([sys.executable, "-c", code], check=True, env=env)