diff --git a/src/pyrecest/backend_tools.py b/src/pyrecest/backend_tools.py index edb09ba5b..0703b8ad0 100644 --- a/src/pyrecest/backend_tools.py +++ b/src/pyrecest/backend_tools.py @@ -34,7 +34,7 @@ def _normalize_expected_backend_names( not isinstance(name, str) or not name or name.strip() != name for name in names ): raise ValueError(message) - return names + return tuple(dict.fromkeys(names)) def assert_backend(expected: str | tuple[str, ...]) -> None: diff --git a/tests/test_backend_tools.py b/tests/test_backend_tools.py index b02416994..85dfa1f99 100644 --- a/tests/test_backend_tools.py +++ b/tests/test_backend_tools.py @@ -19,6 +19,16 @@ def test_assert_backend_rejects_unexpected_backend(): pyrecest.assert_backend(unexpected) +def test_assert_backend_deduplicates_expected_names_in_error_message(): + active = pyrecest.get_backend_name() + unexpected = "jax" if active != "jax" else "numpy" + + with pytest.raises(RuntimeError) as excinfo: + pyrecest.assert_backend((unexpected, unexpected)) + + assert str(excinfo.value).count(unexpected) == 1 + + @pytest.mark.parametrize( "expected", [(), ("",), " ", (" numpy",), ("numpy ",), ("numpy", 1), 1] )