diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..2eef7210d 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -229,9 +229,43 @@ def std( backend.std = std +def _patch_jax_dot_out_facade() -> None: + """Make public and raw JAX ``dot`` accept NumPy's ``out`` argument.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "jax": + return + + import jax.numpy as _jnp # pylint: disable=import-outside-toplevel + import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel + + original_dot = backend.dot + + def dot(a, b, out=None): + result = original_dot(a, b) + if out is None: + return result + + out_array = _jnp.asarray(out) + result_shape = tuple(getattr(result, "shape", ())) + if tuple(out_array.shape) != result_shape: + raise ValueError( + "output array has wrong shape: " + f"expected {result_shape}, got {tuple(out_array.shape)}" + ) + return out_array.at[...].set(result) + + dot.__name__ = getattr(original_dot, "__name__", "dot") + dot.__doc__ = getattr(original_dot, "__doc__", None) + backend.dot = dot + jax_backend.dot = dot + + _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() _patch_jax_std_out_facade() +_patch_jax_dot_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 backend_support, diff --git a/tests/backend_support/test_jax_dot_out_contract.py b/tests/backend_support/test_jax_dot_out_contract.py new file mode 100644 index 000000000..4fb937711 --- /dev/null +++ b/tests/backend_support/test_jax_dot_out_contract.py @@ -0,0 +1,47 @@ +import importlib.util +import os +import subprocess +import sys + +import pytest + + +def test_jax_dot_accepts_out_keyword_and_validates_shape(): + if importlib.util.find_spec("jax") is None: + pytest.skip("jax is not installed") + + env = os.environ.copy() + env["PYRECEST_BACKEND"] = "jax" + src_path = os.path.abspath("src") + env["PYTHONPATH"] = ( + src_path + if not env.get("PYTHONPATH") + else os.pathsep.join([src_path, env["PYTHONPATH"]]) + ) + + code = """ +import jax.numpy as jnp +import pyrecest.backend as backend +import pyrecest._backend.jax as raw_jax + +scalar_out = jnp.zeros(()) +scalar_result = backend.dot([1.0, 2.0], [3.0, 4.0], out=scalar_out) +assert scalar_result.shape == () +assert float(scalar_result) == 11.0 + +vector_out = jnp.zeros((2,)) +vector_result = backend.dot([[1.0, 2.0], [3.0, 4.0]], [5.0, 6.0], out=vector_out) +assert vector_result.shape == (2,) +assert backend.to_numpy(vector_result).tolist() == [17.0, 39.0] + +raw_result = raw_jax.dot([1.0, 2.0], [3.0, 4.0], out=scalar_out) +assert float(raw_result) == 11.0 + +try: + backend.dot([1.0, 2.0], [3.0, 4.0], out=jnp.zeros((2,))) +except ValueError as exc: + assert "wrong shape" in str(exc) +else: + raise AssertionError("JAX dot accepted an incompatible output shape") +""" + subprocess.run([sys.executable, "-c", code], check=True, env=env)