From 19354933a87b3f05f6ee6bc5f7a4821cacce4a54 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:08:50 +0200 Subject: [PATCH 1/2] Fix JAX scatter_add coordinate indexing --- src/pyrecest/__init__.py | 63 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..32ea9995b 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -229,9 +229,72 @@ def std( backend.std = std +def _patch_jax_scatter_add_facade() -> None: + """Make JAX ``scatter_add`` keep all non-scatter-axis coordinates.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "jax": + return + + try: + import jax.numpy as _jnp # pylint: disable=import-outside-toplevel + import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - backend import fails first in practice + return + + original_scatter_add = jax_backend.scatter_add + + def _coordinate_indices(input_shape, dim, index): + coordinates = [] + for axis, axis_size in enumerate(input_shape): + if axis == dim: + coordinates.append(index) + continue + if axis >= index.ndim: + raise ValueError( + "index must have one dimension per input axis for scatter_add" + ) + coordinate_shape = ( + (1,) * axis + (axis_size,) + (1,) * (index.ndim - axis - 1) + ) + coordinates.append( + _jnp.broadcast_to( + _jnp.arange(axis_size).reshape(coordinate_shape), index.shape + ) + ) + return tuple(coordinates) + + def scatter_add(input, dim, index, src): + input = _jnp.asarray(input) + index = _jnp.asarray(index) + src = _jnp.asarray(src, dtype=input.dtype) + dim = int(dim) + if dim < 0: + dim += input.ndim + if dim < 0 or dim >= input.ndim: + raise IndexError(f"dim {dim} is out of bounds for array of dimension {input.ndim}") + if index.ndim == 0: + return input.at[index].add(src) + if index.ndim != input.ndim: + if dim == 1 and input.ndim == 2 and index.ndim == 1: + row_indices = _jnp.arange(input.shape[0]) + return input.at[row_indices, index].add(src) + raise ValueError( + "index must have one dimension per input axis for scatter_add" + ) + return input.at[_coordinate_indices(input.shape, dim, index)].add(src) + + scatter_add.__name__ = getattr(original_scatter_add, "__name__", "scatter_add") + scatter_add.__doc__ = getattr(original_scatter_add, "__doc__", None) + backend.scatter_add = scatter_add + jax_backend.scatter_add = scatter_add + + _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() _patch_jax_std_out_facade() +_patch_jax_scatter_add_facade() from pyrecest.backend_support import ( # noqa: E402,F401 backend_support, From a2a5c07df1b5f77512ebcfc015f3f3604e379039 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:09:22 +0200 Subject: [PATCH 2/2] Add JAX scatter_add dim-zero regression test --- .../test_jax_scatter_add_dim0.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/backend_support/test_jax_scatter_add_dim0.py diff --git a/tests/backend_support/test_jax_scatter_add_dim0.py b/tests/backend_support/test_jax_scatter_add_dim0.py new file mode 100644 index 000000000..3ff2fdb1f --- /dev/null +++ b/tests/backend_support/test_jax_scatter_add_dim0.py @@ -0,0 +1,21 @@ +import pytest + +from tests.support.backend_runner import run_backend_code + + +def test_jax_scatter_add_dim_zero_uses_remaining_axis_coordinates(): + pytest.importorskip("jax") + code = """ +import pyrecest.backend as backend + +values = backend.zeros((2, 3)) +indices = backend.asarray([[0, 1, 0], [1, 0, 1]], dtype=backend.int64) +updates = backend.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + +result = backend.scatter_add(values, 0, indices, updates) + +assert backend.to_numpy(result).tolist() == [[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]] +""" + result = run_backend_code("jax", code) + + assert result.returncode == 0, result.stderr