Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,43 @@ def std(
backend.std = std


def _patch_jax_dot_out_facade() -> None:
"""Make public and raw JAX ``dot`` accept NumPy's ``out`` argument."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "jax":
return

import jax.numpy as _jnp # pylint: disable=import-outside-toplevel
import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel

original_dot = backend.dot

def dot(a, b, out=None):
result = original_dot(a, b)
if out is None:
return result

out_array = _jnp.asarray(out)
result_shape = tuple(getattr(result, "shape", ()))
if tuple(out_array.shape) != result_shape:
raise ValueError(
"output array has wrong shape: "
f"expected {result_shape}, got {tuple(out_array.shape)}"
)
return out_array.at[...].set(result)

dot.__name__ = getattr(original_dot, "__name__", "dot")
dot.__doc__ = getattr(original_dot, "__doc__", None)
backend.dot = dot
jax_backend.dot = dot


_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_jax_std_out_facade()
_patch_jax_dot_out_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
backend_support,
Expand Down
47 changes: 47 additions & 0 deletions tests/backend_support/test_jax_dot_out_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import importlib.util
import os
import subprocess
import sys

import pytest


def test_jax_dot_accepts_out_keyword_and_validates_shape():
if importlib.util.find_spec("jax") is None:
pytest.skip("jax is not installed")

env = os.environ.copy()
env["PYRECEST_BACKEND"] = "jax"
src_path = os.path.abspath("src")
env["PYTHONPATH"] = (
src_path
if not env.get("PYTHONPATH")
else os.pathsep.join([src_path, env["PYTHONPATH"]])
)

code = """
import jax.numpy as jnp
import pyrecest.backend as backend
import pyrecest._backend.jax as raw_jax

scalar_out = jnp.zeros(())
scalar_result = backend.dot([1.0, 2.0], [3.0, 4.0], out=scalar_out)
assert scalar_result.shape == ()
assert float(scalar_result) == 11.0

vector_out = jnp.zeros((2,))
vector_result = backend.dot([[1.0, 2.0], [3.0, 4.0]], [5.0, 6.0], out=vector_out)
assert vector_result.shape == (2,)
assert backend.to_numpy(vector_result).tolist() == [17.0, 39.0]

raw_result = raw_jax.dot([1.0, 2.0], [3.0, 4.0], out=scalar_out)
assert float(raw_result) == 11.0

try:
backend.dot([1.0, 2.0], [3.0, 4.0], out=jnp.zeros((2,)))
except ValueError as exc:
assert "wrong shape" in str(exc)
else:
raise AssertionError("JAX dot accepted an incompatible output shape")
"""
subprocess.run([sys.executable, "-c", code], check=True, env=env)
Loading