Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/pyrecest/evaluation/get_extract_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def _extract_mtt_mean(filter_state):

def get_extract_mean(manifold_name, mtt_scenario=False):
normalized_name = _normalize_registry_name(manifold_name)
is_mtt_scenario = bool(mtt_scenario) or "mtt" in normalized_name
registered_factory = _EXTRACT_MEAN_FACTORIES.get(normalized_name)
if registered_factory is not None:
return registered_factory(manifold_name, mtt_scenario)
return registered_factory(manifold_name, is_mtt_scenario)

if "circle" in normalized_name or "hypertorus" in normalized_name:

Expand Down Expand Up @@ -111,12 +112,12 @@ def extract_mean(filter_state):
def extract_mean(filter_state):
return filter_state.hybrid_mean()

elif "euclidean" in normalized_name and not mtt_scenario:
elif "euclidean" in normalized_name and not is_mtt_scenario:

def extract_mean(filter_state):
return _point_estimate_or_mean(filter_state)

elif "euclidean" in normalized_name and mtt_scenario:
elif "euclidean" in normalized_name and is_mtt_scenario:

def extract_mean(filter_state):
return _extract_mtt_mean(filter_state)
Expand Down
6 changes: 6 additions & 0 deletions tests/evaluation/test_evaluation_registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ def test_euclidean_mtt_extract_mean_uses_public_track_selection():
assert extracted == ["selected"]


def test_named_euclidean_mtt_extract_mean_uses_public_track_selection():
extracted = get_extract_mean("euclideanMTT")(_TrackManagerLikeState())

assert extracted == ["selected"]


def test_symmetric_hypersphere_extract_mean_requires_custom_extractor():
with pytest.raises(NotImplementedError, match="custom extractor"):
get_extract_mean("hypersphereSymmetric")
Expand Down
Loading