Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,39 @@ def std(
backend.std = std


def _patch_jax_matmul_out_facade() -> None:
"""Make public and raw JAX ``matmul`` honor NumPy's ``out`` contract."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "jax":
return

try:
import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel
except (
ModuleNotFoundError
): # pragma: no cover - backend import fails first in practice
return

original_matmul = jax_backend.matmul

def matmul(x, y, out=None):
result = original_matmul(x, y, out=None)
if out is None:
return result
return backend.asarray(out).at[...].set(result)

matmul.__name__ = getattr(original_matmul, "__name__", "matmul")
matmul.__doc__ = getattr(original_matmul, "__doc__", None)
jax_backend.matmul = matmul
backend.matmul = matmul


_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_jax_std_out_facade()
_patch_jax_matmul_out_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
backend_support,
Expand Down
50 changes: 50 additions & 0 deletions tests/backend_support/test_jax_matmul_out_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import importlib.util
import os
import subprocess
import sys

import pytest


@pytest.mark.backend_portable
def test_jax_matmul_honors_out_shape_contract():
if importlib.util.find_spec("jax") is None:
pytest.skip("jax is not installed")

env = os.environ.copy()
env["PYRECEST_BACKEND"] = "jax"
src_path = os.path.abspath("src")
env["PYTHONPATH"] = (
src_path
if not env.get("PYTHONPATH")
else os.pathsep.join([src_path, env["PYTHONPATH"]])
)

code = """
import pyrecest.backend as backend
import pyrecest._backend.jax as raw_jax_backend

left = [[1.0, 2.0], [3.0, 4.0]]
right = [[1.0, 0.0], [0.0, 1.0]]

out = backend.zeros((2, 2))
returned = backend.matmul(left, right, out=out)
assert backend.to_numpy(returned).tolist() == [[1.0, 2.0], [3.0, 4.0]]

bad_out = backend.zeros((1, 1))
try:
backend.matmul(left, right, out=bad_out)
except (TypeError, ValueError):
pass
else:
raise AssertionError("JAX backend.matmul ignored incompatible out shape")

raw_bad_out = raw_jax_backend.zeros((1, 1))
try:
raw_jax_backend.matmul(left, right, out=raw_bad_out)
except (TypeError, ValueError):
pass
else:
raise AssertionError("raw JAX matmul ignored incompatible out shape")
"""
subprocess.run([sys.executable, "-c", code], check=True, env=env)
Loading