diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..83d2c1339 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,31 @@ def tile(x, reps): backend.tile = tile +def _patch_pytorch_gamma_facade() -> None: + """Make PyTorch ``gamma`` accept scalar and array-like inputs.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "pytorch": + return + + try: + import torch as _torch # pylint: disable=import-outside-toplevel + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + except ( + ModuleNotFoundError + ): # pragma: no cover - backend import fails first in practice + return + + def gamma(a): + return _torch.exp(_torch.special.gammaln(backend.array(a))) + + gamma.__name__ = "gamma" + gamma.__doc__ = "Gamma function for PyTorch tensors and array-like inputs." + backend.gamma = gamma + pytorch_backend.gamma = gamma + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +256,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_pytorch_gamma_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 diff --git a/tests/backend_support/test_pytorch_gamma_contract.py b/tests/backend_support/test_pytorch_gamma_contract.py new file mode 100644 index 000000000..6c600f84a --- /dev/null +++ b/tests/backend_support/test_pytorch_gamma_contract.py @@ -0,0 +1,36 @@ +import importlib.util + +import pytest +from tests.support.backend_runner import run_backend_code + + +@pytest.mark.backend_portable +def test_pytorch_gamma_accepts_array_like_inputs(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "pytorch", + """ +import math +import torch +import pyrecest.backend as backend +import pyrecest._backend.pytorch as raw_pytorch_backend + +scalar = backend.gamma(5.0) +assert torch.is_tensor(scalar) +assert tuple(scalar.shape) == () +assert backend.allclose(scalar, backend.array(24.0)) + +values = backend.gamma([0.5, 1.0, 5.0]) +expected = backend.array([math.sqrt(math.pi), 1.0, 24.0]) +assert backend.allclose(values, expected) + +raw_values = raw_pytorch_backend.gamma([1.0, 2.0, 3.0]) +assert backend.allclose(raw_values, backend.array([1.0, 1.0, 2.0])) +print("ok") +""", + ) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout