Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/pyrecest/_backend_submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sys
from functools import wraps
from operator import index as _operator_index
from types import ModuleType

from pyrecest._backend import BACKEND_ATTRIBUTES
Expand Down Expand Up @@ -50,12 +51,85 @@ def _adapt_cumulative_out_contract(backend: ModuleType) -> None:
setattr(backend, attribute_name, _cumulative_with_out(cumulative))


def _adapt_pytorch_repeat_contract(backend: ModuleType) -> None:
"""Adapt PyTorch repeat to NumPy's ``axis`` keyword contract."""
if getattr(backend, "__backend_name__", None) != "pytorch":
return

current_repeat = getattr(backend, "repeat", None)
if current_repeat is not None and getattr(
current_repeat, "_pyrecest_axis_contract", False
):
return

try:
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
import torch as _torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - backend import fails first in practice
return

integer_dtypes = {
_torch.uint8,
_torch.int8,
_torch.int16,
_torch.int32,
_torch.int64,
}

def _repeat_count(repeats, *, device):
if _torch.is_tensor(repeats):
repeats_tensor = repeats.to(device=device)
else:
repeats_tensor = _torch.as_tensor(repeats, device=device)

if repeats_tensor.ndim == 0:
try:
return _operator_index(repeats_tensor.item())
except TypeError as exc:
raise TypeError("repeats must be integers") from exc

if repeats_tensor.dtype not in integer_dtypes:
raise TypeError("repeats must be integers")
return repeats_tensor.to(dtype=_torch.long)

def _repeat_axis(axis, ndim):
axis = _operator_index(axis)
if axis < 0:
axis += ndim
if axis < 0 or axis >= ndim:
raise IndexError(f"axis {axis} is out of bounds for array of dimension {ndim}")
return axis

def repeat(a, repeats, axis=None, *, dim=None):
if dim is not None:
if axis is not None and axis != dim:
raise TypeError("repeat() got both 'axis' and 'dim'")
axis = dim

values = backend.array(a)
if axis is None:
values = values.flatten()
axis = 0
else:
axis = _repeat_axis(axis, values.ndim)

repeat_count = _repeat_count(repeats, device=values.device)
return _torch.repeat_interleave(values, repeat_count, dim=axis)

repeat.__name__ = "repeat"
repeat.__doc__ = getattr(_torch.repeat_interleave, "__doc__", None)
repeat._pyrecest_axis_contract = True
backend.repeat = repeat
pytorch_backend.repeat = repeat


def register_backend_submodules(backend: ModuleType | None = None) -> None:
"""Register virtual backend submodules for standard import statements."""
if backend is None:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

_adapt_cumulative_out_contract(backend)
_adapt_pytorch_repeat_contract(backend)

backend.__path__ = getattr(backend, "__path__", [])
backend_spec = getattr(backend, "__spec__", None)
Expand Down
Loading