From 85eb3ade354158cafd71740fc0d25875fbf38162 Mon Sep 17 00:00:00 2001 From: Oleksii Lubynets Date: Wed, 17 Jun 2026 17:18:48 +0200 Subject: [PATCH] check for numInputNodes == input.size() in getModelOutput() --- Tools/ML/MlResponse.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Tools/ML/MlResponse.h b/Tools/ML/MlResponse.h index 4b5d976cb77..da75eff66e1 100644 --- a/Tools/ML/MlResponse.h +++ b/Tools/ML/MlResponse.h @@ -151,14 +151,8 @@ class MlResponse void init(bool enableOptimizations = false, int threads = 0) { uint8_t counterModel{0}; - const int numCachedIndices = static_cast(mCachedIndices.size()); for (const auto& path : mPaths) { mModels[counterModel].initModel(path, enableOptimizations, threads); - const int numInputNodes = mModels[counterModel].getNumInputNodes(); - if (numInputNodes != numCachedIndices) { - LOG(fatal) << "Number of input nodes in the model " << path << " is different from the number of input features indices (" << numInputNodes << " vs " << numCachedIndices << ")"; - return; - } ++counterModel; } } @@ -188,6 +182,13 @@ class MlResponse LOG(fatal) << "Model index " << nModel << " is out of range! The number of initialised models is " << mModels.size() << ". Please check your configurables."; } + const int numInputNodes = mModels[nModel].getNumInputNodes(); + const int numInputFeatures = static_cast(input.size()); + + if (numInputNodes != numInputFeatures) { + LOG(fatal) << "Number of input nodes in the model " << mPaths[nModel] << " is different from the number of input features to be tested (" << numInputNodes << " vs " << numInputFeatures << ")"; + } + TypeOutputScore* outputPtr = mModels[nModel].template evalModel(input); return std::vector{outputPtr, outputPtr + mNClasses}; }