From 9ba4dfef866b9827a0539aca42992cd66f7a7f0b Mon Sep 17 00:00:00 2001 From: Nikhil Dev Goyal Date: Fri, 19 Jun 2026 06:37:40 -0700 Subject: [PATCH] [Gemma] Fix weights loader assertion crash in kReadBF16 mode. PiperOrigin-RevId: 934904524 --- gemma/weights.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gemma/weights.cc b/gemma/weights.cc index 4903a483..29330e83 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -515,6 +515,7 @@ struct TensorToRead { // only for kReadBF16 bool keep_type = false; Type prev_type; + size_t prev_packed_bytes = 0; }; // Allocates multiple in parallel and binds to NUMA nodes. @@ -533,6 +534,7 @@ static void AllocateAndBindAll(std::vector& tensors, MatPtr& mat = *tensor.mat; tensor.prev_type = mat.GetType(); + tensor.prev_packed_bytes = mat.PackedBytes(); // We only care about MatMul inputs; skip F32 or small tensors. if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) { tensor.keep_type = true; @@ -596,7 +598,8 @@ static void ReadAllToBF16(const std::vector& tensors, // Validate blob size matches allocated buffer before any read. // MapAll (line ~557) and MakeBatches (line ~645) both assert this; // this path was the only one missing the check. - HWY_ASSERT_M(tensor.range.bytes == mat.PackedBytes(), mat.Name()); + HWY_ASSERT_M(tensor.range.bytes == tensor.prev_packed_bytes, + mat.Name()); if (tensor.keep_type) { HWY_ASSERT(reader.file().Read(