Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,31 @@ def tile(x, reps):
backend.tile = tile


def _patch_pytorch_gamma_facade() -> None:
"""Make PyTorch ``gamma`` accept scalar and array-like inputs."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "pytorch":
return

try:
import torch as _torch # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
except (
ModuleNotFoundError
): # pragma: no cover - backend import fails first in practice
return

def gamma(a):
return _torch.exp(_torch.special.gammaln(backend.array(a)))

gamma.__name__ = "gamma"
gamma.__doc__ = "Gamma function for PyTorch tensors and array-like inputs."
backend.gamma = gamma
pytorch_backend.gamma = gamma


def _patch_jax_std_out_facade() -> None:
"""Make public JAX ``std`` accept NumPy's ``out`` argument."""

Expand Down Expand Up @@ -231,6 +256,7 @@ def std(

_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_pytorch_gamma_facade()
_patch_jax_std_out_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
Expand Down
36 changes: 36 additions & 0 deletions tests/backend_support/test_pytorch_gamma_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import importlib.util

import pytest
from tests.support.backend_runner import run_backend_code


@pytest.mark.backend_portable
def test_pytorch_gamma_accepts_array_like_inputs():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code(
"pytorch",
"""
import math
import torch
import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_pytorch_backend

scalar = backend.gamma(5.0)
assert torch.is_tensor(scalar)
assert tuple(scalar.shape) == ()
assert backend.allclose(scalar, backend.array(24.0))

values = backend.gamma([0.5, 1.0, 5.0])
expected = backend.array([math.sqrt(math.pi), 1.0, 24.0])
assert backend.allclose(values, expected)

raw_values = raw_pytorch_backend.gamma([1.0, 2.0, 3.0])
assert backend.allclose(raw_values, backend.array([1.0, 1.0, 2.0]))
print("ok")
""",
)

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading