From 47faa0e05c2c29d1187e779508068620f8f09e62 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:30:49 +0200 Subject: [PATCH 1/2] Fix raw JAX std out handling --- src/pyrecest/__init__.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..5722d5d47 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -199,14 +199,25 @@ def tile(x, reps): def _patch_jax_std_out_facade() -> None: - """Make public JAX ``std`` accept NumPy's ``out`` argument.""" + """Make public and raw JAX ``std`` accept NumPy's ``out`` argument.""" import pyrecest.backend as backend # pylint: disable=import-outside-toplevel if getattr(backend, "__backend_name__", None) != "jax": return + try: + import pyrecest._backend.jax as raw_jax # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - backend import fails first + raw_jax = None + original_std = backend.std + original_raw_std = getattr(raw_jax, "std", None) if raw_jax is not None else None + + def _return_or_store_out(result, out): + if out is None: + return result + return backend.asarray(out).at[...].set(result) def std( a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, correction=0 @@ -220,14 +231,34 @@ def std( keepdims=keepdims, correction=correction, ) - if out is None: - return result - return backend.asarray(out).at[...].set(result) + return _return_or_store_out(result, out) std.__name__ = getattr(original_std, "__name__", "std") std.__doc__ = getattr(original_std, "__doc__", None) backend.std = std + if original_raw_std is None: + return + + def raw_std( + a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, correction=0 + ): + kwargs = { + "axis": axis, + "dtype": dtype, + "out": None, + "ddof": ddof, + "keepdims": keepdims, + } + if correction != 0: + kwargs["correction"] = correction + result = original_raw_std(a, **kwargs) + return _return_or_store_out(result, out) + + raw_std.__name__ = getattr(original_raw_std, "__name__", "std") + raw_std.__doc__ = getattr(original_raw_std, "__doc__", None) + raw_jax.std = raw_std + _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() From e40591cc363abab51964c00de3586a1ca9f90f5a Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:31:39 +0200 Subject: [PATCH 2/2] Add raw JAX std out regression --- tests/test_backend_std_contract.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/test_backend_std_contract.py b/tests/test_backend_std_contract.py index adc7d1fd6..728d41246 100644 --- a/tests/test_backend_std_contract.py +++ b/tests/test_backend_std_contract.py @@ -47,13 +47,34 @@ def test_jax_std_accepts_out_argument(): "jax", """ import pyrecest.backend as backend +import pyrecest._backend.jax as raw_jax values = backend.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) -out = backend.zeros((1, 3), dtype=backend.float64) -result = backend.std(values, axis=0, dtype=backend.float64, out=out, ddof=1, keepdims=True) expected = backend.array([[2.1213203435596424, 2.1213203435596424, 2.1213203435596424]]) -assert tuple(result.shape) == (1, 3) -assert backend.allclose(result, expected) + +public_out = backend.zeros((1, 3), dtype=backend.float64) +public_result = backend.std( + values, + axis=0, + dtype=backend.float64, + out=public_out, + ddof=1, + keepdims=True, +) +assert tuple(public_result.shape) == (1, 3) +assert backend.allclose(public_result, expected) + +raw_out = backend.zeros((1, 3), dtype=backend.float64) +raw_result = raw_jax.std( + values, + axis=0, + dtype=backend.float64, + out=raw_out, + ddof=1, + keepdims=True, +) +assert tuple(raw_result.shape) == (1, 3) +assert backend.allclose(raw_result, expected) print("ok") """, )