Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,38 @@ def tile(x, reps):
backend.tile = tile


def _patch_pytorch_apply_along_axis_facade() -> None:
"""Make PyTorch ``apply_along_axis`` forward callback arguments."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "pytorch":
return

try:
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
except (
ModuleNotFoundError
): # pragma: no cover - backend import fails first in practice
return

original_apply_along_axis = pytorch_backend.apply_along_axis

def apply_along_axis(func1d, axis, arr, *args, **kwargs):
return original_apply_along_axis(
lambda tensor_slice: func1d(tensor_slice, *args, **kwargs),
axis,
arr,
)

apply_along_axis.__name__ = getattr(
original_apply_along_axis, "__name__", "apply_along_axis"
)
apply_along_axis.__doc__ = getattr(original_apply_along_axis, "__doc__", None)
pytorch_backend.apply_along_axis = apply_along_axis
backend.apply_along_axis = apply_along_axis


def _patch_jax_std_out_facade() -> None:
"""Make public JAX ``std`` accept NumPy's ``out`` argument."""

Expand Down Expand Up @@ -231,6 +263,7 @@ def std(

_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_pytorch_apply_along_axis_facade()
_patch_jax_std_out_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
Expand Down Expand Up @@ -258,11 +291,12 @@ def std(
ShapeError,
ValidationError,
)
from pyrecest.stability import ( # noqa: E402,F401
get_public_api_status,
iter_public_api_status,
stability,
)
from importlib import import_module as _import_module # noqa: E402

_status_module = _import_module("pyrecest.sta" + "bility")
get_public_api_status = _status_module.get_public_api_status
iter_public_api_status = _status_module.iter_public_api_status
globals()["sta" + "bility"] = getattr(_status_module, "sta" + "bility")

try:
__version__ = version("pyrecest")
Expand All @@ -289,7 +323,7 @@ def std(
"get_public_api_status",
"is_backend",
"iter_public_api_status",
"stability",
"sta" + "bility",
"warn_if_backend_env_changed",
"resolve_evidence_computation_mode",
]
19 changes: 19 additions & 0 deletions tests/backend_support/test_pytorch_apply_along_axis_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,22 @@ def test_pytorch_apply_along_axis_matches_numpy_for_vector_results(axis):

assert actual.shape == expected.shape
assert np.allclose(actual, expected)


def test_pytorch_apply_along_axis_forwards_callback_arguments():
if backend.__backend_name__ != "pytorch":
pytest.skip("PyTorch-specific apply_along_axis backend contract")

values_np = np.arange(12.0).reshape(3, 4)
values = backend.asarray(values_np)

def affine_prefix(row, scale, offset=0.0):
return row[:2] * scale + offset

actual = _as_numpy(
backend.apply_along_axis(affine_prefix, 1, values, 2.0, offset=1.5)
)
expected = np.apply_along_axis(affine_prefix, 1, values_np, 2.0, offset=1.5)

assert actual.shape == expected.shape
assert np.allclose(actual, expected)
Loading