diff --git a/src/pyrecest/filters/sequence_association.py b/src/pyrecest/filters/sequence_association.py index d769e8737..c2f75489a 100644 --- a/src/pyrecest/filters/sequence_association.py +++ b/src/pyrecest/filters/sequence_association.py @@ -365,7 +365,10 @@ def _validate_cost(value: object, name: str) -> float: def _validate_positive_integer(value: object, name: str) -> int: message = f"{name} must be a positive integer" - value_array = np.asarray(value) + try: + value_array = np.asarray(value) + except (TypeError, ValueError) as exc: + raise ValueError(message) from exc if value_array.ndim != 0 or value_array.dtype == np.bool_: raise ValueError(message) diff --git a/tests/filters/test_sequence_association_top_k_bad_input.py b/tests/filters/test_sequence_association_top_k_bad_input.py new file mode 100644 index 000000000..78dfdf799 --- /dev/null +++ b/tests/filters/test_sequence_association_top_k_bad_input.py @@ -0,0 +1,20 @@ +import pytest + +from pyrecest.filters import SequenceAssociationNode, solve_top_k_viterbi_sequence_associations + + +class UncoercibleScalar: + def __array__(self, dtype=None): + del dtype + raise TypeError("cannot convert") + + +def test_top_k_terminal_paths_reports_value_error_for_uncoercible_scalar(): + frames = [[SequenceAssociationNode(0, 0)]] + + with pytest.raises(ValueError, match="top_k_terminal_paths must be a positive integer"): + solve_top_k_viterbi_sequence_associations( + frames, + lambda _previous, _current, _context: 0.0, + top_k_terminal_paths=UncoercibleScalar(), + )