Decompose small kernel-depth 3D convs into 2D convs (#3625) #3785

Open
katlun-lgtm wants to merge 4 commits into ml-explore:main from katlun-lgtm:conv-3d-small-kd-decomp
@katlun-lgtm

Contributor

Closes #3625.

Problem

mx.conv_general with 5D inputs (3D convolution) is 2–5× slower than decomposing the
same op into per-frame 2D convolutions with a Python loop. This hits video-generation
workloads (VAE decoders with CausalConv3d).

The reason isn't the explicit-gemm fallback — for the common case (mod16 channels,
idil==1, groups==1) dispatch_conv_3D_gpu already routes to
implicit_gemm_conv_3D_gpu. The real cause is that the 2D dispatch has a Winograd
kernel for 3×3 stride-1 convs
(dispatch_conv_2D_gpu → winograd_conv_2D_gpu), and the
3D path has no Winograd or 3×3-specialized kernel — so a 3×3×3 conv misses Winograd
entirely. Decomposing a small kernel-depth 3D conv into KD 2D convs lets each 2D conv
hit the tuned 2D path.

Change

small_kd_conv_3D_gpu: for small kernel depth, for each depth tap kd we build a
zero-copy strided view of the input frames ([OD, H, W, C] at depth offset kd) and the
weight depth-slice ([O, KH, KW, C]), run conv_2D_gpu, and accumulate. The accumulator
buffer is repointed into out via copy_shared_buffer (the pad_and_slice pattern).
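
The depth-tap decomposition described above can be illustrated with a small NumPy sketch. This is not the Metal implementation: `conv2d_nhwc`, `conv3d_reference`, and `conv3d_by_depth_taps` are hypothetical helper names, and the sketch assumes N == 1, stride 1, no padding, no dilation, and a channels-last layout with weights shaped [O, KD, KH, KW, C]. It only shows that summing KD 2D convolutions over depth-shifted views reproduces the direct 3D result.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_nhwc(x, w):
    """Valid, stride-1 2D convolution (cross-correlation).

    x: [B, H, W, C], w: [O, KH, KW, C] -> [B, OH, OW, O]
    """
    kh, kw = w.shape[1], w.shape[2]
    # windows has shape [B, OH, OW, C, KH, KW]
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))
    return np.einsum("bhwcij,oijc->bhwo", windows, w)


def conv3d_reference(x, w):
    """Direct valid, stride-1 3D convolution for a single batch element.

    x: [D, H, W, C], w: [O, KD, KH, KW, C] -> [OD, OH, OW, O]
    """
    kd, kh, kw = w.shape[1:4]
    # windows has shape [OD, OH, OW, C, KD, KH, KW]
    windows = sliding_window_view(x, (kd, kh, kw), axis=(0, 1, 2))
    return np.einsum("dhwcijk,oijkc->dhwo", windows, w)


def conv3d_by_depth_taps(x, w):
    """Same result as conv3d_reference, computed as KD 2D convolutions."""
    kd_size = w.shape[1]
    od = x.shape[0] - kd_size + 1  # output depth with no depth padding
    out = None
    for kd in range(kd_size):
        frames = x[kd : kd + od]   # view of the input: [OD, H, W, C]
        w_slice = w[:, kd]         # view of the weights: [O, KH, KW, C]
        tap = conv2d_nhwc(frames, w_slice)  # frames act as the 2D batch
        if out is None:
            out = tap              # first tap owns the accumulator
        else:
            out += tap             # later taps accumulate in place
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 8, 8, 16))      # [D, H, W, C]
w = rng.standard_normal((32, 3, 3, 3, 16))  # [O, KD, KH, KW, C]

ref = conv3d_reference(x, w)
dec = conv3d_by_depth_taps(x, w)
print(dec.shape, np.allclose(ref, dec, atol=1e-9))
print(np.shares_memory(x, x[1:5]))  # the depth-offset slice is a view, not a copy
```

The loop performs KD 2D convolutions and KD − 1 in-place adds, which is the accumulate overhead that the KD guard is meant to bound.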

A guard in dispatch_conv_3D_gpu takes this path only when it is valid and faster:

  • idil == 1 (all dims), groups == 1, N == 1;
  • KD <= 7 (the KD-1 accumulate adds erode the win for large KD);
  • depth stride == 1 and kernel_dilation == 1, and pad[0] == 0;
  • mod16 channels (so the per-frame 2D convs hit the fast path).

Everything else falls through to the existing implicit gemm unchanged.
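
As a reading aid, the guard conditions listed above can be written as a single predicate. This is a hedged Python sketch, not the C++ in dispatch_conv_3D_gpu: `Conv3DParams` and `use_small_kd_path` are hypothetical names, "mod16 channels" is assumed to mean that both C and O are multiples of 16, and the kernel-dilation check is assumed to apply to the depth axis only.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Conv3DParams:
    """Hypothetical stand-in for the parameters the dispatcher inspects."""

    batch: int                             # N
    in_channels: int                       # C
    out_channels: int                      # O
    kernel_depth: int                      # KD
    strides: Tuple[int, int, int]          # (depth, height, width)
    padding: Tuple[int, int, int]          # (depth, height, width)
    kernel_dilation: Tuple[int, int, int]  # (depth, height, width)
    input_dilation: Tuple[int, int, int]   # (depth, height, width)
    groups: int = 1


MAX_KD = 7  # threshold quoted in the PR description


def use_small_kd_path(p: Conv3DParams) -> bool:
    """Return True when the depth-tap decomposition would be selected."""
    return (
        all(d == 1 for d in p.input_dilation)
        and p.groups == 1
        and p.batch == 1
        and p.kernel_depth <= MAX_KD
        and p.strides[0] == 1
        and p.kernel_dilation[0] == 1  # assumption: depth axis only
        and p.padding[0] == 0
        and p.in_channels % 16 == 0    # assumption: mod16 applies to C and O
        and p.out_channels % 16 == 0
    )


video = Conv3DParams(
    batch=1, in_channels=256, out_channels=256, kernel_depth=3,
    strides=(1, 1, 1), padding=(0, 1, 1),
    kernel_dilation=(1, 1, 1), input_dilation=(1, 1, 1),
)
print(use_small_kd_path(video))                              # fast path
print(use_small_kd_path(replace(video, strides=(2, 1, 1))))  # depth stride > 1
print(use_small_kd_path(replace(video, padding=(1, 1, 1))))  # depth padding
print(use_small_kd_path(replace(video, in_channels=24)))     # non-mod16 channels
```

The last three calls correspond to the fall-back cases that the tests exercise.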

(Skeleton follows @Ved235's sketch in the issue; the depth-tap decomposition idea is
theirs.)

Results (M3 Max, mlx built from this branch)

Correctness — the fast path vs the CPU reference (fp32, exact):

| KD | max\|gpu − cpu\| | max\|cpu\| |
|----|------------------|------------|
| 1  | 3e-5             | 61.0       |
| 2  | 1e-4             | 91.6       |
| 3  | 1e-4             | 106.8      |
| 5  | 2e-4             | 142.2      |

Speed (bf16, 3×3×3), native 3D vs the per-frame 2D decomposition:

| case (T×H×W, C→O)   | before  | after  | per-frame 2D |
|---------------------|---------|--------|--------------|
| 41×60×104, 256→256  | 71 ms   | 28 ms  | 27 ms        |
| 41×120×208, 512→512 | 1128 ms | 335 ms | 336 ms       |
| 17×64×64, 256→256   | 20 ms   | 9 ms   | 8 ms         |

That is, the slowdown relative to the per-frame 2D path drops from 2.4–3.4× to about
1.0× (parity), and the output is numerically identical to that path.
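
The ratios can be recomputed from the speed table with a few lines of Python. The timings below are copied from the table as rounded milliseconds, so the results may differ slightly from ratios computed on unrounded measurements.

```python
# (case, before_ms, after_ms, per_frame_2d_ms), copied from the speed table
rows = [
    ("41x60x104, 256->256", 71, 28, 27),
    ("41x120x208, 512->512", 1128, 335, 336),
    ("17x64x64, 256->256", 20, 9, 8),
]

# Slowdown relative to the per-frame 2D decomposition, before and after the change
ratios = [
    (case, before / per_frame, after / per_frame)
    for case, before, after, per_frame in rows
]

for case, before_ratio, after_ratio in ratios:
    print(f"{case}: before {before_ratio:.2f}x, after {after_ratio:.2f}x")
```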

Tests

python/tests/test_conv.py:

  • test_conv_3D_small_kd_decomposition — the fast path vs CPU across 5 shapes
    (Cout != Cin, KD ∈ {1,2,3,5}, 1×1 spatial), fp32 exact.
  • test_conv_3D_small_kd_fallback_cases — depth stride > 1, depth padding, and non-mod16
    channels stay correct (must take the fall-back path).

Full test_conv.py suite passes with no regressions.

Open questions for maintainers

  1. Accumulation — currently the first tap owns the buffer and later taps do an in-place
    binary_op_gpu_inplace(..., "Add"). Prefer that, or a beta-accumulate added to
    conv_2D_gpu?
  2. KD threshold — fixed at 7, or a cost heuristic vs the 3D implicit gemm?
  3. Scope — is the decomposition the accepted fix, or do you want a native 3D Winograd
    later? Also: should the fast path be extended to N > 1 and depth padding, or is the
    N == 1 / pad[0] == 0 guard fine for a first PR?

3x3x3 (and other small kernel-depth) 3D convolutions run the generic 3D
implicit-gemm kernel, which has no Winograd / 3x3-specialized path. As a
result they are 2-5x slower than decomposing the same op into per-frame 2D
convolutions (which do hit the tuned 2D dispatch). This shows up in video
VAE decoders (CausalConv3d).

Add small_kd_conv_3D_gpu: for small kernel depth, run KD 2D convs over
zero-copy strided views of the input frames and weight depth-slices, and
accumulate. Each 2D conv goes through dispatch_conv_2D_gpu, so 3x3 stride-1
taps get Winograd. A guard in dispatch_conv_3D_gpu takes this path only when
it is valid and faster: input dilation 1, groups 1, N == 1, KD <= 7, depth
stride and kernel-dilation 1, no depth padding, mod16 channels. Everything
else falls through to the existing implicit gemm unchanged.

Tests: the fast path is validated against the CPU reference (fp32, exact)
across several shapes, plus fall-back cases (depth stride > 1, depth
padding, non-mod16 channels).

Benchmarked on an M3 Max (bf16): 41x120x208, C=512, 3x3x3 goes 1128ms ->
335ms (3.4x -> 1.0x vs the per-frame 2D decomposition); fp32-exact vs CPU.
Comment thread mlx/backend/metal/conv.cpp Outdated

@zcbenz zcbenz left a comment

Collaborator

Looks good to me, thanks!

Development

Successfully merging this pull request may close these issues.

[Performance] conv_general 5D (3D convolution) is 2-5x slower than equivalent per-frame 2D conv

2 participants