From a9aea0b40463820ffc5c837bfd6f3359784d760e Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:31:25 +0200 Subject: [PATCH 1/2] Recognize MTT suffix in mean extraction --- src/pyrecest/evaluation/get_extract_mean.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pyrecest/evaluation/get_extract_mean.py b/src/pyrecest/evaluation/get_extract_mean.py index 0613bd422..a680ddf4a 100644 --- a/src/pyrecest/evaluation/get_extract_mean.py +++ b/src/pyrecest/evaluation/get_extract_mean.py @@ -71,9 +71,10 @@ def _extract_mtt_mean(filter_state): def get_extract_mean(manifold_name, mtt_scenario=False): normalized_name = _normalize_registry_name(manifold_name) + is_mtt_scenario = bool(mtt_scenario) or "mtt" in normalized_name registered_factory = _EXTRACT_MEAN_FACTORIES.get(normalized_name) if registered_factory is not None: - return registered_factory(manifold_name, mtt_scenario) + return registered_factory(manifold_name, is_mtt_scenario) if "circle" in normalized_name or "hypertorus" in normalized_name: @@ -111,12 +112,12 @@ def extract_mean(filter_state): def extract_mean(filter_state): return filter_state.hybrid_mean() - elif "euclidean" in normalized_name and not mtt_scenario: + elif "euclidean" in normalized_name and not is_mtt_scenario: def extract_mean(filter_state): return _point_estimate_or_mean(filter_state) - elif "euclidean" in normalized_name and mtt_scenario: + elif "euclidean" in normalized_name and is_mtt_scenario: def extract_mean(filter_state): return _extract_mtt_mean(filter_state) From 3902c10ae0675e1861fa56a055412b51a8bac479 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:32:01 +0200 Subject: [PATCH 2/2] Cover named Euclidean MTT mean extraction --- tests/evaluation/test_evaluation_registries.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/evaluation/test_evaluation_registries.py b/tests/evaluation/test_evaluation_registries.py index f8777fb82..85c83c17a 100644 --- a/tests/evaluation/test_evaluation_registries.py +++ b/tests/evaluation/test_evaluation_registries.py @@ -194,6 +194,12 @@ def test_euclidean_mtt_extract_mean_uses_public_track_selection(): assert extracted == ["selected"] +def test_named_euclidean_mtt_extract_mean_uses_public_track_selection(): + extracted = get_extract_mean("euclideanMTT")(_TrackManagerLikeState()) + + assert extracted == ["selected"] + + def test_symmetric_hypersphere_extract_mean_requires_custom_extractor(): with pytest.raises(NotImplementedError, match="custom extractor"): get_extract_mean("hypersphereSymmetric")