Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,72 @@ def std(
backend.std = std


def _patch_jax_scatter_add_facade() -> None:
"""Make JAX ``scatter_add`` keep all non-scatter-axis coordinates."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "jax":
return

try:
import jax.numpy as _jnp # pylint: disable=import-outside-toplevel
import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - backend import fails first in practice
return

original_scatter_add = jax_backend.scatter_add

def _coordinate_indices(input_shape, dim, index):
coordinates = []
for axis, axis_size in enumerate(input_shape):
if axis == dim:
coordinates.append(index)
continue
if axis >= index.ndim:
raise ValueError(
"index must have one dimension per input axis for scatter_add"
)
coordinate_shape = (
(1,) * axis + (axis_size,) + (1,) * (index.ndim - axis - 1)
)
coordinates.append(
_jnp.broadcast_to(
_jnp.arange(axis_size).reshape(coordinate_shape), index.shape
)
)
return tuple(coordinates)

def scatter_add(input, dim, index, src):
input = _jnp.asarray(input)
index = _jnp.asarray(index)
src = _jnp.asarray(src, dtype=input.dtype)
dim = int(dim)
if dim < 0:
dim += input.ndim
if dim < 0 or dim >= input.ndim:
raise IndexError(f"dim {dim} is out of bounds for array of dimension {input.ndim}")
if index.ndim == 0:
return input.at[index].add(src)
if index.ndim != input.ndim:
if dim == 1 and input.ndim == 2 and index.ndim == 1:
row_indices = _jnp.arange(input.shape[0])
return input.at[row_indices, index].add(src)
raise ValueError(
"index must have one dimension per input axis for scatter_add"
)
return input.at[_coordinate_indices(input.shape, dim, index)].add(src)

scatter_add.__name__ = getattr(original_scatter_add, "__name__", "scatter_add")
scatter_add.__doc__ = getattr(original_scatter_add, "__doc__", None)
backend.scatter_add = scatter_add
jax_backend.scatter_add = scatter_add


_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_jax_std_out_facade()
_patch_jax_scatter_add_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
backend_support,
Expand Down
21 changes: 21 additions & 0 deletions tests/backend_support/test_jax_scatter_add_dim0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from tests.support.backend_runner import run_backend_code


def test_jax_scatter_add_dim_zero_uses_remaining_axis_coordinates():
pytest.importorskip("jax")
code = """
import pyrecest.backend as backend

values = backend.zeros((2, 3))
indices = backend.asarray([[0, 1, 0], [1, 0, 1]], dtype=backend.int64)
updates = backend.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

result = backend.scatter_add(values, 0, indices, updates)

assert backend.to_numpy(result).tolist() == [[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]]
"""
result = run_backend_code("jax", code)

assert result.returncode == 0, result.stderr
Loading