diff --git a/src/pyrecest/optional_dependencies.py b/src/pyrecest/optional_dependencies.py index 90bb0ad85..b4d311e07 100644 --- a/src/pyrecest/optional_dependencies.py +++ b/src/pyrecest/optional_dependencies.py @@ -15,6 +15,12 @@ def _validate_nonempty_string(value: Any, name: str) -> str: return value.strip() +def _validate_optional_feature(value: Any | None) -> str | None: + if value is None: + return None + return _validate_nonempty_string(value, "feature") + + def _is_missing_requested_package(exc: ModuleNotFoundError, package: str) -> bool: missing_name = exc.name if missing_name is None: @@ -38,6 +44,7 @@ def require_optional_dependency( """ package = _validate_nonempty_string(package, "package") extra = _validate_nonempty_string(extra, "extra") + feature = _validate_optional_feature(feature) try: return importlib.import_module(package) diff --git a/tests/test_optional_dependencies.py b/tests/test_optional_dependencies.py index 2effaa5e0..e540c88f5 100644 --- a/tests/test_optional_dependencies.py +++ b/tests/test_optional_dependencies.py @@ -33,6 +33,12 @@ def test_require_optional_dependency_rejects_invalid_names(package, extra): require_optional_dependency(package, extra) +@pytest.mark.parametrize("feature", ["", " "]) +def test_require_optional_dependency_rejects_blank_feature_names(feature): + with pytest.raises(ValueError, match="feature must be a non-empty string"): + require_optional_dependency("math", "plot", feature=feature) + + def test_require_optional_dependency_reports_missing_parent_for_submodule(): import_error = ModuleNotFoundError( "No module named 'missing_parent'",