Skip to content

Raise qmv batch limit for large matrices on M5-class GPUs#3791

Open
pierre427 wants to merge 1 commit into
ml-explore:mainfrom
pierre427:g17-qmv-batch-limit
Open

Raise qmv batch limit for large matrices on M5-class GPUs#3791
pierre427 wants to merge 1 commit into
ml-explore:mainfrom
pierre427:g17-qmv-batch-limit

Conversation

@pierre427

Copy link
Copy Markdown

Proposed changes

get_qmv_batch_limit currently sends gen-17 (M5-class) GPUs down the generic
non-'d' fallback, which caps the qmv/qmv_wide path at M<10 for large
matrices. Measured on an M5 Max (applegpu_g17s), the qmv_wide kernels stay
ahead of the qmm path well past that limit on large shapes. This PR raises
the large-matrix limit to 16 for arch_gen >= 17 non-'d' devices; small and
mid shapes keep the existing fallback values (no data to justify changing
them).

This mitigates the small-M quantized-matmul cost step discussed in #3553 for
the M=11–16 band, which matters for speculative decoding: verification of
multi-candidate / tree / long prompt-lookup proposals lands exactly in that
window. #2031 notes these limits are empirical and machine-dependent; gen-17
hardware postdates the tunings there.

Measurements

Apple M5 Max, 128GB (applegpu_g17s), macOS 26. Whole-model forward cost vs
row count M against a warm 400-token KV cache, min of 6 reps. Model: 70B dense
llama-arch (LLM360 K2-V2), affine 4-bit, group_size 64 (D=8192; O spans 1024 /
8192 / 28672 / 250112 across layers).

M main @ e9463bb (limit 10) this PR (limit 16)
1 76.4 ms 75.9 ms
3 85.0 84.8
5 108.2 112.7
8 167.9 170.8
12 294.6 (qmm) 241.5 (qmv_wide, −18%)
16 296.0 (qmv_wide)

Same model at affine 8-bit: M=12 326.7 → 304.1 (−7%).

For reference, the 0.31.2 release (no qmv_wide kernels) costs 298.5 ms at
M=12 on the same 4-bit model, and end-to-end speculative decoding on this
hardware improves from ~18 tok/s (release) to 24.5 tok/s (main + k=3 chain
with a 0.6B draft), with the M=11–16 window enabling wider verification
schemes this change dispatches to the faster path.

Repro

import time
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache

model, _ = load("<4-bit 70B model>")
caches = make_prompt_cache(model)
model(mx.array([[i % 1000 + 100 for i in range(400)]]), cache=caches)
mx.eval([c.state[0] for c in caches])
for n in (1, 3, 5, 8, 12, 16):
    ids = mx.array([[200 + i for i in range(n)]])
    ts = []
    for _ in range(6):
        t0 = time.perf_counter()
        mx.eval(model(ids, cache=caches))
        ts.append(time.perf_counter() - t0)
        for c in caches:
            c.trim(n)
    print(f"M={n}: {min(ts)*1e3:.1f} ms")

Happy to run additional shapes/configs on this hardware if useful for tuning
the small/mid buckets too.

On g17 hardware the qmv_wide kernels stay ahead of the qmm path well past
the generic non-'d' fallback limit of 10. Measured on M5 Max (g17s) with
affine 4-bit/8-bit weights (group_size 64) at D=8192,
O in {1024, 8192, 28672, 250112}: qmv_wide is ~18% faster at M=12
(241.5ms vs 294.6ms whole-model forward on a 70B) and still ahead at M=16.
Raises the large-matrix limit to 16 for gen>=17 non-'d' GPUs; small and
mid shapes keep the existing fallback values. Mitigates the small-M
speculative-decoding verification cost discussed in ml-explore#3553.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant