Decompose small kernel-depth 3D convs into 2D convs (#3625) #3785
Open
katlun-lgtm wants to merge 4 commits into
3x3x3 (and other small kernel-depth) 3D convolutions run the generic 3D implicit-gemm kernel, which has no Winograd / 3x3-specialized path. As a result they are 2-5x slower than decomposing the same op into per-frame 2D convolutions (which do hit the tuned 2D dispatch). This shows up in video VAE decoders (`CausalConv3d`).

Add `small_kd_conv_3D_gpu`: for small kernel depth, run `KD` 2D convs over zero-copy strided views of the input frames and weight depth-slices, and accumulate. Each 2D conv goes through `dispatch_conv_2D_gpu`, so 3x3 stride-1 taps get Winograd.

A guard in `dispatch_conv_3D_gpu` takes this path only when it is valid and faster: input dilation 1, groups 1, `N == 1`, `KD <= 7`, depth stride and kernel-dilation 1, no depth padding, mod16 channels. Everything else falls through to the existing implicit gemm unchanged.

Tests: the fast path is validated against the CPU reference (fp32, exact) across several shapes, plus fall-back cases (depth stride > 1, depth padding, non-mod16 channels).

Benchmarked on an M3 Max (bf16): 41x120x208, C=512, 3x3x3 goes 1128ms -> 335ms (3.4x -> 1.0x vs the per-frame 2D decomposition); fp32-exact vs CPU.
zcbenz reviewed Jul 1, 2026
zcbenz approved these changes Jul 3, 2026
zcbenz (Collaborator) left a comment:
Looks good to me, thanks!
Closes #3625.
Problem
`mx.conv_general` with 5D inputs (3D convolution) is 2–5× slower than decomposing the same op into per-frame 2D convolutions with a Python loop. This hits video-generation workloads (VAE decoders with `CausalConv3d`).

The reason isn't the explicit-gemm fallback: for the common case (mod16 channels, `idil==1`, `groups==1`) `dispatch_conv_3D_gpu` already routes to `implicit_gemm_conv_3D_gpu`. The real cause is that the 2D dispatch has a Winograd kernel for 3×3 stride-1 convs (`dispatch_conv_2D_gpu` → `winograd_conv_2D_gpu`), and the 3D path has no Winograd or 3×3-specialized kernel, so a 3×3×3 conv misses Winograd entirely. Decomposing a small kernel-depth 3D conv into `KD` 2D convs lets each 2D conv hit the tuned 2D path.
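The decomposition itself can be illustrated with a small NumPy sketch. This is not the Metal implementation from this PR: the function names (`conv2d_nhwc`, `conv3d_reference`, `conv3d_small_kd`) are hypothetical, and the sketch assumes `N == 1` (batch dimension squeezed away), stride 1, no padding, no dilation, an input layout of `[D, H, W, C]`, and weights laid out as `[O, KD, KH, KW, C]`. It only shows that summing `KD` per-tap 2D convolutions over depth-shifted frame views reproduces the direct 3D result.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_nhwc(x, w):
    """Valid, stride-1 2D correlation. x: [N, H, W, C], w: [O, KH, KW, C]."""
    kh, kw = w.shape[1], w.shape[2]
    # windows: [N, OH, OW, C, KH, KW]
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))
    return np.einsum("nhwcij,oijc->nhwo", windows, w)


def conv3d_reference(x, w):
    """Direct valid, stride-1 3D correlation. x: [D, H, W, C], w: [O, KD, KH, KW, C]."""
    kd, kh, kw = w.shape[1:4]
    # windows: [OD, OH, OW, C, KD, KH, KW]
    windows = sliding_window_view(x, (kd, kh, kw), axis=(0, 1, 2))
    return np.einsum("dhwcaij,oaijc->dhwo", windows, w)


def conv3d_small_kd(x, w):
    """Same result as conv3d_reference, computed as KD accumulated 2D convs."""
    kd = w.shape[1]
    od = x.shape[0] - kd + 1
    out = None
    for tap in range(kd):
        frames = x[tap:tap + od]   # [OD, H, W, C] view at depth offset `tap` (no copy)
        w_slice = w[:, tap]        # [O, KH, KW, C] view of one weight depth-slice
        # The OD frames play the role of the batch dimension of the 2D conv.
        partial = conv2d_nhwc(frames, w_slice)
        out = partial if out is None else out + partial
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 5, 7, 4))      # D=6, H=5, W=7, C=4
w = rng.standard_normal((3, 3, 3, 3, 4))   # O=3, KD=KH=KW=3, C=4
assert np.allclose(conv3d_reference(x, w), conv3d_small_kd(x, w))
```

The loop performs `KD` 2D convolutions and `KD - 1` additions, which is the accumulation overhead the rest of this description refers to.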
Change
`small_kd_conv_3D_gpu`: for small kernel depth, for each depth tap `kd` we build a zero-copy strided view of the input frames (`[OD, H, W, C]` at depth offset `kd`) and the weight depth-slice (`[O, KH, KW, C]`), run `conv_2D_gpu`, and accumulate. The accumulator buffer is repointed into `out` via `copy_shared_buffer` (the `pad_and_slice` pattern).

A guard in `dispatch_conv_3D_gpu` takes this path only when it is valid and faster:
- `idil == 1` (all dims), `groups == 1`, `N == 1`;
- `KD <= 7` (the `KD-1` accumulate adds erode the win for large `KD`);
- `stride == 1` and `kernel_dilation == 1`, and `pad[0] == 0`;
- `mod16` channels (so the per-frame 2D convs hit the fast path).

Everything else falls through to the existing implicit gemm unchanged.
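The guard conditions above can be restated as a single predicate. The following Python sketch is only an illustration of the listed conditions, not the C++ code in `dispatch_conv_3D_gpu`: `Conv3DParams`, `MAX_KD`, and `takes_small_kd_path` are hypothetical names, and it assumes that "mod16 channels" applies to both input and output channels and that the stride and kernel-dilation conditions apply to the depth axis (index 0), as the commit message states.

```python
from dataclasses import dataclass
from typing import Tuple

MAX_KD = 7  # threshold quoted in this PR; the name is hypothetical


@dataclass(frozen=True)
class Conv3DParams:
    """Hypothetical parameter bundle; axes are ordered (depth, height, width)."""
    batch: int
    in_channels: int
    out_channels: int
    kernel_depth: int
    stride: Tuple[int, int, int] = (1, 1, 1)
    padding: Tuple[int, int, int] = (0, 0, 0)
    kernel_dilation: Tuple[int, int, int] = (1, 1, 1)
    input_dilation: Tuple[int, int, int] = (1, 1, 1)
    groups: int = 1


def takes_small_kd_path(p: Conv3DParams) -> bool:
    """Return True when the small-KD decomposition would be chosen."""
    return (
        all(d == 1 for d in p.input_dilation)   # idil == 1 on all dims
        and p.groups == 1
        and p.batch == 1                        # N == 1
        and p.kernel_depth <= MAX_KD            # KD <= 7
        and p.stride[0] == 1                    # depth stride 1
        and p.kernel_dilation[0] == 1           # depth kernel dilation 1
        and p.padding[0] == 0                   # no depth padding
        and p.in_channels % 16 == 0             # assumption: mod16 on both
        and p.out_channels % 16 == 0            # input and output channels
    )


base = Conv3DParams(batch=1, in_channels=512, out_channels=512, kernel_depth=3)
assert takes_small_kd_path(base)
```

Any configuration that fails one of these checks would stay on the existing implicit-gemm path, which is what the fall-back tests exercise.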
(Skeleton follows @Ved235's sketch in the issue; the depth-tap decomposition idea is
theirs.)
Results (M3 Max, mlx built from this branch)
Correctness — the fast path vs the CPU reference (fp32, exact):
Speed (bf16, 3×3×3), native 3D vs the per-frame 2D decomposition:
i.e. 2.4–3.4× → ~1.0× (parity) with the per-frame 2D path, and numerically identical
to it.
Tests
`python/tests/test_conv.py`:
- `test_conv_3D_small_kd_decomposition`: the fast path vs CPU across 5 shapes (`Cout != Cin`, `KD ∈ {1,2,3,5}`, 1×1 spatial), fp32 exact.
- `test_conv_3D_small_kd_fallback_cases`: depth stride > 1, depth padding, and non-mod16 channels stay correct (must take the fall-back path).

Full `test_conv.py` suite passes with no regressions.

Open questions for maintainers
- `binary_op_gpu_inplace(..., "Add")`. Prefer that, or a `beta`-accumulate added to `conv_2D_gpu`?
- `KD` threshold: fixed at 7, or a cost heuristic vs the 3D implicit gemm?
- later? Also: should the fast path be extended to `N > 1` and depth padding, or is the `N == 1` / `pad[0] == 0` guard fine for a first PR?