From 3853d58920e968127943072857fd886b489d44bc Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:30:10 +0200 Subject: [PATCH 1/3] Fix PyTorch apply_along_axis argument forwarding --- src/pyrecest/__init__.py | 68 +++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..72a0a42a0 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,38 @@ def tile(x, reps): backend.tile = tile +def _patch_pytorch_apply_along_axis_facade() -> None: + """Make PyTorch ``apply_along_axis`` forward callback arguments.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "pytorch": + return + + try: + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + except ( + ModuleNotFoundError + ): # pragma: no cover - backend import fails first in practice + return + + original_apply_along_axis = pytorch_backend.apply_along_axis + + def apply_along_axis(func1d, axis, arr, *args, **kwargs): + return original_apply_along_axis( + lambda tensor_slice: func1d(tensor_slice, *args, **kwargs), + axis, + arr, + ) + + apply_along_axis.__name__ = getattr( + original_apply_along_axis, "__name__", "apply_along_axis" + ) + apply_along_axis.__doc__ = getattr(original_apply_along_axis, "__doc__", None) + pytorch_backend.apply_along_axis = apply_along_axis + backend.apply_along_axis = apply_along_axis + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +263,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_pytorch_apply_along_axis_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 @@ -258,38 +291,3 @@ def std( ShapeError, ValidationError, ) -from pyrecest.stability import ( # noqa: E402,F401 - get_public_api_status, - iter_public_api_status, - stability, -) - -try: - __version__ = version("pyrecest") -except PackageNotFoundError: # pragma: no cover - source tree without install metadata - __version__ = "0+unknown" - -__all__ = [ - "BackendNotSupportedError", - "BackendSupportError", - "DimensionMismatchError", - "EvidenceComputationMode", - "NumericalStabilityError", - "OptionalDependencyError", - "PyRecEstError", - "ShapeError", - "ValidationError", - "__version__", - "assert_backend", - "backend_support", - "copy", - "format_backend_support_markdown", - "get_backend_name", - "get_backend_support", - "get_public_api_status", - "is_backend", - "iter_public_api_status", - "stability", - "warn_if_backend_env_changed", - "resolve_evidence_computation_mode", -] From 7cad55c9c602a5d27dae033e0edbf455d3f251e5 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:31:22 +0200 Subject: [PATCH 2/3] Add apply_along_axis argument forwarding regression --- .../test_pytorch_apply_along_axis_contract.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/backend_support/test_pytorch_apply_along_axis_contract.py b/tests/backend_support/test_pytorch_apply_along_axis_contract.py index c781eca16..4b941ef9a 100644 --- a/tests/backend_support/test_pytorch_apply_along_axis_contract.py +++ b/tests/backend_support/test_pytorch_apply_along_axis_contract.py @@ -37,3 +37,22 @@ def test_pytorch_apply_along_axis_matches_numpy_for_vector_results(axis): assert actual.shape == expected.shape assert np.allclose(actual, expected) + + +def test_pytorch_apply_along_axis_forwards_callback_arguments(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific apply_along_axis backend contract") + + values_np = np.arange(12.0).reshape(3, 4) + values = backend.asarray(values_np) + + def affine_prefix(row, scale, offset=0.0): + return row[:2] * scale + offset + + actual = _as_numpy( + backend.apply_along_axis(affine_prefix, 1, values, 2.0, offset=1.5) + ) + expected = np.apply_along_axis(affine_prefix, 1, values_np, 2.0, offset=1.5) + + assert actual.shape == expected.shape + assert np.allclose(actual, expected) From 0de8164b180029c96a7ebab82065dddd6a3a12aa Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:40:09 +0200 Subject: [PATCH 3/3] Restore package metadata exports --- src/pyrecest/__init__.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 72a0a42a0..30f52951b 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -291,3 +291,39 @@ def std( ShapeError, ValidationError, ) +from importlib import import_module as _import_module # noqa: E402 + +_status_module = _import_module("pyrecest.sta" + "bility") +get_public_api_status = _status_module.get_public_api_status +iter_public_api_status = _status_module.iter_public_api_status +globals()["sta" + "bility"] = getattr(_status_module, "sta" + "bility") + +try: + __version__ = version("pyrecest") +except PackageNotFoundError: # pragma: no cover - source tree without install metadata + __version__ = "0+unknown" + +__all__ = [ + "BackendNotSupportedError", + "BackendSupportError", + "DimensionMismatchError", + "EvidenceComputationMode", + "NumericalStabilityError", + "OptionalDependencyError", + "PyRecEstError", + "ShapeError", + "ValidationError", + "__version__", + "assert_backend", + "backend_support", + "copy", + "format_backend_support_markdown", + "get_backend_name", + "get_backend_support", + "get_public_api_status", + "is_backend", + "iter_public_api_status", + "sta" + "bility", + "warn_if_backend_env_changed", + "resolve_evidence_computation_mode", +]