Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,25 @@ def tile(x, reps):


def _patch_jax_std_out_facade() -> None:
"""Make public JAX ``std`` accept NumPy's ``out`` argument."""
"""Make public and raw JAX ``std`` accept NumPy's ``out`` argument."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "jax":
return

try:
import pyrecest._backend.jax as raw_jax # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - backend import fails first
raw_jax = None

original_std = backend.std
original_raw_std = getattr(raw_jax, "std", None) if raw_jax is not None else None

def _return_or_store_out(result, out):
if out is None:
return result
return backend.asarray(out).at[...].set(result)

def std(
a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, correction=0
Expand All @@ -220,14 +231,34 @@ def std(
keepdims=keepdims,
correction=correction,
)
if out is None:
return result
return backend.asarray(out).at[...].set(result)
return _return_or_store_out(result, out)

std.__name__ = getattr(original_std, "__name__", "std")
std.__doc__ = getattr(original_std, "__doc__", None)
backend.std = std

if original_raw_std is None:
return

def raw_std(
a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, correction=0
):
kwargs = {
"axis": axis,
"dtype": dtype,
"out": None,
"ddof": ddof,
"keepdims": keepdims,
}
if correction != 0:
kwargs["correction"] = correction
result = original_raw_std(a, **kwargs)
return _return_or_store_out(result, out)

raw_std.__name__ = getattr(original_raw_std, "__name__", "std")
raw_std.__doc__ = getattr(original_raw_std, "__doc__", None)
raw_jax.std = raw_std


_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
Expand Down
29 changes: 25 additions & 4 deletions tests/test_backend_std_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,34 @@ def test_jax_std_accepts_out_argument():
"jax",
"""
import pyrecest.backend as backend
import pyrecest._backend.jax as raw_jax

values = backend.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
out = backend.zeros((1, 3), dtype=backend.float64)
result = backend.std(values, axis=0, dtype=backend.float64, out=out, ddof=1, keepdims=True)
expected = backend.array([[2.1213203435596424, 2.1213203435596424, 2.1213203435596424]])
assert tuple(result.shape) == (1, 3)
assert backend.allclose(result, expected)

public_out = backend.zeros((1, 3), dtype=backend.float64)
public_result = backend.std(
values,
axis=0,
dtype=backend.float64,
out=public_out,
ddof=1,
keepdims=True,
)
assert tuple(public_result.shape) == (1, 3)
assert backend.allclose(public_result, expected)

raw_out = backend.zeros((1, 3), dtype=backend.float64)
raw_result = raw_jax.std(
values,
axis=0,
dtype=backend.float64,
out=raw_out,
ddof=1,
keepdims=True,
)
assert tuple(raw_result.shape) == (1, 3)
assert backend.allclose(raw_result, expected)
print("ok")
""",
)
Expand Down
Loading