diff --git a/src/pyrecest/_backend/pytorch/fft.py b/src/pyrecest/_backend/pytorch/fft.py index 394bac151..887ea16d6 100644 --- a/src/pyrecest/_backend/pytorch/fft.py +++ b/src/pyrecest/_backend/pytorch/fft.py @@ -18,8 +18,10 @@ def _with_dim_alias(kwargs, alias, func_name): kwargs = dict(kwargs) alias_value = kwargs.pop(alias) dim_value = kwargs.get("dim") - if dim_value is not None and alias_value is not None and dim_value != alias_value: - raise TypeError("conflicting FFT axis aliases") + if dim_value is not None: + if alias_value is not None and dim_value != alias_value: + raise TypeError("conflicting FFT axis aliases") + return kwargs kwargs["dim"] = alias_value return kwargs diff --git a/tests/backend_support/test_pytorch_fft_axis_contract.py b/tests/backend_support/test_pytorch_fft_axis_contract.py index 3300a58b9..94ae08fbd 100644 --- a/tests/backend_support/test_pytorch_fft_axis_contract.py +++ b/tests/backend_support/test_pytorch_fft_axis_contract.py @@ -37,6 +37,24 @@ def test_raw_pytorch_fft_helpers_accept_numpy_axis_aliases(): ) +@pytest.mark.backend_portable +def test_raw_pytorch_fft_none_axis_alias_preserves_explicit_dim(): + matrix = np.arange(6.0).reshape(2, 3) + + npt.assert_allclose( + pytorch_fft.fftn(matrix.tolist(), axes=None, dim=(0,)).numpy(), + np.fft.fftn(matrix, axes=(0,)), + ) + npt.assert_allclose( + pytorch_fft.ifftn( + pytorch_fft.fftn(matrix, dim=(0,)), + axes=None, + dim=(0,), + ).numpy(), + np.fft.ifftn(np.fft.fftn(matrix, axes=(0,)), axes=(0,)), + ) + + def test_raw_pytorch_fft_rejects_conflicting_axis_aliases(): with pytest.raises(TypeError): pytorch_fft.rfft(np.arange(4.0), axis=0, dim=1)