Squeezing a 26B diffusion LLM onto a Mac

Optimizing DiffusionGemma (26B-A4B-IT-4Bit) on an Apple M5 Pro: what worked, and what didn't.

Authors
Sabesh
Research, MLX, Diffusion Models, Gemma

15th June 2026

TL;DR

I set out to make DiffusionGemma - a 26B-parameter, ~4B-active mixture-of-experts block-diffusion text model - fast enough to be a real interactive tool on a single Apple M5 Pro (48 GB, macOS 27 beta). The stock MLX inference path ran at ~18 tok/s on my original short repro, ~23–27 tok/s on a fair 512-token run.

I built an opt-in turbo engine that delivers:

  • ~86 tok/s average on GSM8K (13/20 correct, temperature 0, 320-token cap; 4 of the 7 misses were truncations that hit the cap, not wrong answers),
  • ~65–90 tok/s on sustained technical/long-form, bursting to ~140 on short Q&A and math,
  • at quality equal to the faithful reference sampler.

That's a ~3× speedup at reference-level quality (up to ~6× on short prompts). A faster forced-commit mode reaches 100–138 tok/s but degrades on long creative text, so it isn't the default. The stretch goal was 300+ tok/s. I did not reach it - and the most useful thing I produced is not the speedup but the reason I couldn't: a measured bandwidth ceiling. Single-stream 4-bit MoE on this machine is weight-streaming bound at ~80 GB/s, and that ceiling holds even for Apple's own tuned Metal-4 neural-accelerator kernel. 300+ tok/s single-stream isn't a tuning problem; it requires either beating Apple's kernel (open research) or dropping the experts below 4-bit. I'll show the numbers behind that claim.

Along the way I found two real MLX bugs and drafted upstream contributions for all of it.


What DiffusionGemma is, and why naive inference is slow

Autoregressive LLMs generate one token at a time. Block-diffusion models work differently: they hold a fixed-size canvas of token positions (256 here) and denoise the whole canvas over several forward passes, freezing positions as they become confident. Instead of 256 sequential decode steps you do ~8–16 full-canvas forwards. In principle that's a throughput win, because each forward commits many tokens at once.

The catch is what each forward costs. DiffusionGemma is a mixture of experts: 128 experts, top-8 routing, 30 layers, 4-bit quantized, 262,144-token vocabulary. I profiled one denoising step at the native 256-token canvas on the M5 Pro:

| component | per step | note |
| --- | --- | --- |
| MoE experts (30 layers) | ~200 ms | the wall - weight-streaming bound |
| LM head (256 pos × 262k vocab, 4-bit) | ~17 ms | compute-bound (~22 TFLOPS) |
| self-conditioning softmax(logits) @ embed_table | ~17 ms | full-vocab matmul, every step |
| softcap + entropy chain over full vocab | ~17 ms | several passes over a 262k-wide tensor |
| attention + dense MLP (30 layers) | ~45 ms | already fast (NAX-accelerated) |
| total | ~250–310 ms | × 12–16 steps per 256-token canvas |

The MoE dominates everything. Here's why, in one measured fact: at a 256-token canvas, the 2,048 token-expert pairs (256 × top-8) light up essentially all 128 experts in all 30 layers, every step. That's ~12.8 GB of 4-bit expert weights read per denoising step. The whole single-stream speed question collapses to: how fast can we stream 4-bit expert weights?
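
As a back-of-the-envelope check on why a full canvas touches nearly every expert, here is a small sketch. It assumes the router picks 8 distinct experts uniformly at random per token, which a trained router does not do; the statement above is a measurement, and this only shows that it is unsurprising. The per-expert byte figure is just the ~12.8 GB/step number divided evenly across layers and experts.

```python
# Toy model: each token independently picks top_k distinct experts uniformly
# at random. Uniform routing is an assumption made for illustration only.
canvas, top_k, n_experts, n_layers = 256, 8, 128, 30

pairs = canvas * top_k                       # token-expert pairs per layer
p_miss_one_token = 1 - top_k / n_experts     # a given expert is skipped by one token
p_expert_idle = p_miss_one_token ** canvas   # ...and by all 256 tokens in the canvas
expected_active = n_experts * (1 - p_expert_idle)

# Average bytes per expert per layer implied by the ~12.8 GB/step figure.
bytes_per_expert = 12.8e9 / (n_layers * n_experts)

print(pairs, f"{p_expert_idle:.1e}", f"{bytes_per_expert / 1e6:.2f} MB")
```

Under that assumption an expert sits idle for a whole step with vanishing probability, so the full expert set is streamed every step.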

Everything else in the table is "soft" - work that's necessary for the reference sampler but cheap to approximate or skip. The MoE is the hard floor. Keep that split in mind; the optimizations below attack the soft rows, and the ceiling analysis is entirely about the hard one.


The optimizations that worked

Every technique below is opt-in behind --gen-kwargs '{"diffusion_turbo": true, ...}'. With the flag off, the code path is byte-for-byte identical to stock mlx-vlm - the reference loop is never touched. Each technique is independently toggleable.

1. Hierarchical exact top-K sampler chain - ~35 ms → ~3 ms/step

A direct mx.topk over 262,144 logits costs ~40 ms on this GPU. But the sampler only ever needs the top ~64 candidates per position. So I do a two-stage chunk-max selection: reshape the logits into chunks, take the chunk maxima, pick the top-K chunks, then take the top-K within those chunks. The whole selection takes ~1.5 ms.

The nice part is that it's provably exact, not an approximation. If a value is in the global top-K, fewer than K chunk-maxima can exceed it, so its chunk is among the top-K chunks - the selected chunks contain every global top-K value. And because softcap (tanh) and temperature scaling are both monotonic, top-K membership computed on the raw logits is identical to membership after softcap. So softcap, softmax, entropy, and sampling all run on a [T, 64] tensor instead of [T, 262144]. The only thing truncated is the entropy tail mass beyond 64 candidates - negligible at the low temperatures where positions actually commit.
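
The exactness argument can be checked with a few lines of plain Python. This is a minimal sketch of the two-stage idea, not the engine's MLX code: the function names, the chunk size of 32, and the softcap value of 30.0 are placeholders chosen for the example.

```python
import heapq
import math
import random

def chunked_topk(logits, k, chunk_size):
    """Indices of the k largest logits, found via two-stage chunk-max selection."""
    n = len(logits)
    # Stage 1: one (max value, chunk start) pair per chunk.
    chunk_max = [(max(logits[s:s + chunk_size]), s) for s in range(0, n, chunk_size)]
    # Every global top-k value lives in one of the k chunks with the largest maxima.
    top_chunks = heapq.nlargest(k, chunk_max)
    # Stage 2: exact top-k over the surviving candidates (at most k * chunk_size).
    candidates = [i for _, s in top_chunks for i in range(s, min(s + chunk_size, n))]
    return heapq.nlargest(k, candidates, key=logits.__getitem__)

def brute_topk(values, k):
    """Reference: full sort over every value."""
    return sorted(range(len(values)), key=values.__getitem__, reverse=True)[:k]

random.seed(0)
logits = [random.gauss(0.0, 5.0) for _ in range(8192)]
fast = chunked_topk(logits, k=64, chunk_size=32)
assert fast == brute_topk(logits, 64)

# Softcap is monotonic, so top-k membership is the same before and after it.
# The cap value 30.0 is a placeholder, not taken from the model config.
softcapped = [30.0 * math.tanh(x / 30.0) for x in logits]
assert set(fast) == set(brute_topk(softcapped, 64))
```

With 8,192 logits and 32-wide chunks, stage 2 sorts 2,048 candidates instead of the full vector; the same ratio is what makes the selection cheap at 262,144 logits.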

2. Top-K self-conditioning - drops a 378-GFLOP matmul and a 1.5 GB table

DiffusionGemma feeds its own predictions back as conditioning: the reference computes softmax(logits) @ embedding_table. That's a ~378-GFLOP matmul and it requires keeping a dequantized ~1.5 GB copy of the embedding table resident in memory.

Instead, I gather the 64 embedding rows for the top-K candidates and take the probability-weighted sum. Near-identical conditioning signal at these temperatures, and it lets me drop the dequantized table entirely. Approximately free.
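
A toy version makes the approximation concrete. This sketch uses a 6-token vocabulary and plain Python lists; the function names are hypothetical, and renormalising the probabilities over the K candidates is an assumption on my part rather than something the engine is confirmed to do.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dense_conditioning(logits, embed_table):
    """Reference form: softmax(logits) @ embed_table over the full vocabulary."""
    probs = softmax(logits)
    dim = len(embed_table[0])
    return [sum(p * row[d] for p, row in zip(probs, embed_table)) for d in range(dim)]

def topk_conditioning(logits, embed_table, k):
    """Gather the k most likely rows and take their probability-weighted sum."""
    top = sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:k]
    probs = softmax([logits[i] for i in top])  # renormalised over the k candidates
    dim = len(embed_table[0])
    return [sum(p * embed_table[i][d] for p, i in zip(probs, top)) for d in range(dim)]

# Tiny vocabulary with 3-dimensional embeddings, peaked like a
# low-temperature distribution.
logits = [10.0, 9.0, 0.0, -1.0, -2.0, -3.0]
embed_table = [[0.5, -0.2, 0.1], [0.3, 0.4, -0.6], [-0.9, 0.8, 0.7],
               [0.2, -0.5, 0.9], [-0.1, 0.6, -0.3], [0.7, -0.8, 0.4]]
dense = dense_conditioning(logits, embed_table)
approx = topk_conditioning(logits, embed_table, k=2)
max_err = max(abs(a - b) for a, b in zip(dense, approx))
```

The error is bounded by the probability mass outside the top-K, which is why the shortcut is only safe when the distribution is already peaked.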

3. Compacted active-set decoder - forward only the live positions

With monotone commits, once a position clears a confidence threshold it freezes. So why keep forwarding it? The compacted decoder (turbo_compact) forwards only the live positions (plus the positions committed on the previous step, so their K/V reflect the committed token). Frozen positions are served from per-layer canvas K/V buffers, with RoPE for the scattered live positions applied from precomputed cos/sin tables that replicate mx.fast.rope.

The correctness anchor here is a bit-exactness invariant: with the forward set equal to all positions and an empty freeze state, one runner step matches the stock decoder forward bit-for-bit. Both RoPE variants (the sliding and the proportional partial-rotary full-attention paths) were validated numerically against the model's own modules. The shrinking working set also helps the MoE: below ~48 live tokens MLX drops onto its faster gemv-class dispatch.
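
The table-driven RoPE and the invariant can be illustrated in miniature. This is a sketch under stated assumptions, not a reimplementation of mx.fast.rope: it uses full rotary with split-half pairing and a base of 10000, whereas the model has two variants (including a partial-rotary one). The helper names are hypothetical.

```python
import math

def build_rope_tables(max_pos, head_dim, base=10000.0):
    """Precompute cos/sin for every canvas position (base is a placeholder)."""
    half = head_dim // 2
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_pos)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_pos)]
    return cos, sin

def rope_rows(rows, positions, cos, sin):
    """Rotate each row by the angles of its own canvas position."""
    out = []
    for x, p in zip(rows, positions):
        half = len(x) // 2
        a, b = x[:half], x[half:]
        c, s = cos[p], sin[p]
        out.append([a[i] * c[i] - b[i] * s[i] for i in range(half)]
                   + [a[i] * s[i] + b[i] * c[i] for i in range(half)])
    return out

cos, sin = build_rope_tables(max_pos=256, head_dim=8)
canvas = [[float((p * 7 + d) % 5) - 2.0 for d in range(8)] for p in range(256)]

live = [3, 17, 200]  # scattered live positions
compact = rope_rows([canvas[p] for p in live], live, cos, sin)
full = rope_rows(canvas, list(range(256)), cos, sin)

# Compacted rows must match the full-canvas result at the same positions.
assert compact == [full[p] for p in live]
```

Because each row is rotated by its original canvas position rather than its index in the compacted batch, the compacted result matches the full forward exactly, which is the property the bit-exactness check relies on.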

4. Active-set shape bucketing - fixing the "memory leak" that wasn't

This one was a trap. Over a long-lived process, throughput drifted from ~120 down to ~20 tok/s over ~20 minutes. It looked exactly like a memory leak.

It wasn't. Peak memory was flat the whole time. The real cause: the compacted decoder forwards a different number of live positions every step (256 → ~16 as tokens freeze), and each distinct shape makes MLX JIT-compile and cache a fresh graph. Over time the compiled-kernel set grew unbounded and throughput decayed. It was compiled-kernel-cache growth, not a buffer leak.

The fix is to round both the forward set and the active set up to ≤10 fixed buckets (16, 32, …, 256), padding with duplicate positions. Duplicates recompute identical values and scatter idempotently, so results are unchanged - only the distinct-shape count is bounded, from ~240 possible down to ≤10. (An mx.clear_cache() every N steps was my first guess and was actively harmful - it empties the buffer pool and causes realloc churn. Bucketing is the right answer.)
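
The padding trick is easy to sketch. The bucket list below is hypothetical (the exact sizes are not spelled out above), and fake_forward is a stand-in for the decoder whose only relevant property is that a position always produces the same value.

```python
import bisect

# Hypothetical bucket sizes; the engine's real list may differ.
BUCKETS = (16, 32, 64, 128, 256)

def pad_to_bucket(positions, buckets=BUCKETS):
    """Pad a list of live positions up to the next bucket size with duplicates."""
    n = len(positions)
    if n == 0:
        return []
    if n > buckets[-1]:
        raise ValueError(f"{n} positions exceed the largest bucket {buckets[-1]}")
    size = buckets[bisect.bisect_left(buckets, n)]
    return list(positions) + [positions[-1]] * (size - n)

def scatter(canvas, positions, values):
    """Write values into a copy of canvas; repeated positions rewrite the same slot."""
    out = list(canvas)
    for p, v in zip(positions, values):
        out[p] = v
    return out

def fake_forward(positions):
    """Stand-in for the decoder: the output depends only on the position."""
    return [p * p for p in positions]

live = [5, 9, 40, 41, 77, 130, 131, 132, 200, 201, 202, 203, 204, 220, 221, 222, 250]
padded = pad_to_bucket(live)  # 17 live positions are padded to the 32 bucket
canvas = [0] * 256
assert scatter(canvas, padded, fake_forward(padded)) == scatter(canvas, live, fake_forward(live))
```

The duplicates cost a little redundant compute, but every step now presents one of a handful of shapes, so the set of compiled graphs stops growing.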

5. Two-stage confidence schedule + repair + EOS tail-drop

  • Two-stage schedule. A picky threshold (e.g. 0.95) while the canvas structure is still forming, then a relaxed one (e.g. 0.80) for stragglers. Crucially, no forced final flush. A forced flush argmax-commits every still-uncertain position on the last step; those marginal tokens get re-encoded into the next canvas as context, compounding across blocks into single-token repetition ('the the the…' collapse). The two-stage schedule avoids it by only ever committing high-confidence positions.
  • Repair pass (turbo_repair). One extra full-canvas forward after the canvas drains, re-argmaxing every position. This recovers the reference sampler's "final canvas = argmax of last forward" semantics and gives early-frozen tokens a chance to be revised. With it on, bounded-step denoising matches faithful accuracy.
  • EOS tail-drop. Once a confident EOS freezes, positions after it stop being denoised. Big win on real chat answers, where the stock loop denoises all 256 positions regardless of how short the answer is.
  • Repeat-guard. Stops generation if a token repeats ≥16 times in a row - natural text never does this, so it's an unambiguous degeneration signal and prevents emitting a wall of garbage up to max_tokens.
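
The commit rule and the repeat-guard from the list above can be sketched as follows. This is an illustration under assumptions, not the engine's code: the function names are hypothetical, and reading the third entry of a schedule like [0.95, 0.80, 8] as the step at which the threshold relaxes is my interpretation.

```python
def threshold_for_step(step, schedule=(0.95, 0.80, 8)):
    """Strict threshold for the first `switch` steps, relaxed afterwards.

    Treating the third schedule entry as the switch step is an assumption.
    """
    strict, relaxed, switch = schedule
    return strict if step < switch else relaxed

def commit(confidences, frozen, step):
    """Freeze every live position whose confidence clears the current threshold.

    Positions below the threshold stay live; nothing is force-committed.
    """
    t = threshold_for_step(step)
    return [f or c >= t for c, f in zip(confidences, frozen)]

def hit_repeat_guard(tokens, limit=16):
    """True once the last `limit` tokens are all the same token."""
    return len(tokens) >= limit and len(set(tokens[-limit:])) == 1

conf = [0.99, 0.85, 0.60, 0.97]
frozen = [False] * 4
early = commit(conf, frozen, step=2)  # only confidences >= 0.95 commit
late = commit(conf, early, step=9)    # confidences >= 0.80 now commit as well
```

Note that the 0.60 position is never committed by either stage; in this sketch it simply stays live, which is the behaviour that avoids the forced-flush collapse.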

The scoreboard

512-token generation, fresh process, baseline = the stock engine's own 512-token throughput:

| prompt | baseline (tok/s) | turbo (tok/s; default, two-stage no-flush) |
| --- | --- | --- |
| "explain speculative decoding" | 27.1 | ~70 |
| "what is MLX" | ~25 | ~72 |
| "write a short story" (creative) | 23.3 | ~45 |
| GSM8K[:20] average | | 86 (short items burst to ~144) |

A faster forced-flush mode hits 100–138 tok/s but commits marginal tokens that collapse long creative text, so the default trades ~30–40% speed for robustness.

GSM8K[:20] at temperature 0 averaged ~86 tok/s (math spends more steps per token). The recommended setting, --gen-kwargs '{"diffusion_turbo": true, "turbo_compact": true, "turbo_accept": "confidence", "turbo_threshold": [0.95, 0.80, 8], "turbo_repair": true}', scores 11/12 on the 12-item set, matching the faithful reference.


The honest failures

Five custom Metal kernels that didn't beat MLX

The whole game is streaming 4-bit expert weights faster. I wrote five custom Metal kernels trying to do it:

  1. Register-resident multi-row qmv - 27–37 GB/s. Holding M×8 activation registers spills. Dead.
  2. simdgroup-matrix (simdgroup_half8x8) MoE - 50–55 GB/s and numerically buggy. Dead.
  3. Capacity-padded NAX gather - 74–84 GB/s. Equal to MLX, no win.
  4. One-simdgroup-per-row, weights shared via L2 - 15 GB/s. The hypothesis was that 16 simdgroups reading the same weights would let L2 dedup the DRAM reads. Disproven: the GPU doesn't keep simdgroups in lockstep, so each one re-reads from DRAM. This is why weight reuse fundamentally needs explicit threadgroup staging.
  5. Metal-4 mpp::tensor_ops::matmul2d from a JIT kernel - this one is a genuinely useful unlock: #include <MetalPerformancePrimitives/...> works from a JIT mx.fast.metal_kernel on macOS 27, and the Metal 4.0 language version is honored. But the tensor-op operands want real tensor/tensor_handle device views; the pointer-backed tensor_inline you can build from JIT's raw device T* args isn't honored, and the output comes out scrambled. A correct NAX MoE kernel has to be a C++ primitive inside MLX core, reusing the steel NAXFrag machinery - and since MLX's existing NAX kernel already hits the ceiling, a faster one is open kernel research, not a tile tweak.

Best of the five: ~84 GB/s. MLX's own kernel: ~80. I could match the wall. I could not beat it. These kernels are kept in the tree as research artifacts only - explicitly not wired into the engine and not proposed as upstream performance wins. They document the ceiling; they don't move it.


The physics ceiling: why 300+ tok/s is blocked

This is the part I'm most confident about, because it's the part I measured most carefully.

The single-stream MoE forward is dominated by reading ~12.8 GB of 4-bit weights per step. So I benchmarked exactly one thing: 4-bit quantized matmul throughput as a function of M, the number of activation rows that reuse each weight. Same weights, varying only M:

| M (rows reused per weight) | GB/s | kernel | regime |
| --- | --- | --- | --- |
| 1 | 220 | qmv | no reuse - direct stream |
| 2 | 190 | qmv | |
| 4 | 107 | (transition) | reuse begins |
| 8 | 56 | qmm | staging begins |
| 16 | 80 | qmm (incl. NAX matmul2d) | staged - this is the MoE's M |
| 32 | 84 | qmm | staged |
| 64 | 80 | qmm | staged |

For calibration on the same machine: elementwise streaming hits ~263 GB/s, and MLX's own qmv gemv kernel hits ~220 GB/s. The DRAM and the dequant front-end are both clearly capable of >200 GB/s.

Now read the curve. At M ≤ 2 there's no weight reuse - each weight is consumed once, it's a straight gemv, and it streams at the full ~220 GB/s. But the moment M grows past ~4, reusing a weight row across multiple activation rows requires staging it through threadgroup memory, and every staged path collapses to ~80 GB/s. From M=16 through M=64 it's a flat plateau.

And the MoE needs exactly the M that sits in the bad regime:

M_per_expert ≈ (canvas 256 × top-8) / 128 experts ≈ 16 rows

So the model gives up ~2.7× of available bandwidth precisely where it spends ~80% of its time.

The crucial detail - and the reason this isn't just "MLX's kernel is slow" - is that the ~80 GB/s plateau includes Apple's own Metal-4 NAX kernel. The gather_qmm NAX path uses mpp::tensor_ops::matmul2d (the NAXFrag machinery in steel/gemm/nax.h). At the MoE's shapes, that tensor-op kernel is the 80 GB/s number. The 220 GB/s only exists at M ≤ 2, where there's no reuse to stage. This is a property of staged 4-bit weight-reuse on this hardware generation, confirmed against the best kernel Apple ships.

The implication is clean. Against 12.8 GB/step at ~80 GB/s, the MoE alone is ~160 ms/step; at the 8–14 steps coherent text needs, single-stream throughput is capped well below 300 tok/s before any other factor matters. 300+ tok/s single-stream is blocked by a measured bandwidth ceiling, not by software I can easily improve.
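
The arithmetic behind that cap is short enough to write out. Every input below is a figure quoted in this post; the only addition is the simplifying assumption that a step costs nothing except the MoE weight read and that the full 256-token canvas is emitted.

```python
# All inputs are figures quoted in the post; only the arithmetic is new.
canvas, top_k, n_experts = 256, 8, 128
weights_gb_per_step = 12.8   # 4-bit expert weights read per denoising step
staged_gb_per_s = 80.0       # measured plateau for staged qmm at M = 16..64
gemv_gb_per_s = 220.0        # measured qmv rate at M = 1

rows_per_expert = canvas * top_k / n_experts      # the M each expert sees
moe_s_per_step = weights_gb_per_step / staged_gb_per_s
lost_bandwidth = gemv_gb_per_s / staged_gb_per_s  # the "~2.7x" given up

def moe_only_cap(steps):
    """tok/s if a full 256-token canvas cost nothing but the MoE weight read."""
    return canvas / (steps * moe_s_per_step)

for steps in (8, 14):
    print(steps, round(moe_only_cap(steps)))
```

Even in this idealised form the bound lands at roughly 200 tok/s for 8 steps and roughly 114 tok/s for 14, before the LM head, attention, or sampling cost anything.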

Two levers could break it, and I measured both:

  • Batch multiple canvases (denoise B at once; the weight read is flat in token count). Measured 1.0 / 1.26 / 1.47× at B = 1/2/4 - sublinear, because attention, the dense MLP, the LM head, and the MoE compute all scale with tokens; only the weight read is flat. This is multi-stream throughput (~190 tok/s aggregate at best), not single-stream latency.
  • 2–3 bit experts - directly cut the bytes the wall is made of (~1.4–2×, plausibly enough for silver). Untried: it needs the bf16 checkpoint (~52 GB, not local) to requantize the experts cleanly without compounding 4-bit error, and it carries real quality risk. This is the single most promising next experiment.

Neither is a kernel tweak. The wall is real.


Real bugs found, and the upstream contributions

Chasing this turned up two genuine MLX bugs and produced a stack of drafted contributions.

1. gather_qmm NAX kernel-name mismatch (a real, one-line bug). On a NAX-capable GPU, gather_qmm_nax() builds its Metal kernel name with the tile parameter bk = 32, but the NAX gather kernels are only ever instantiated with bk = 64. So on M5 every batched expert matmul that takes the NAX fast path fails at dispatch:

Unable to load kernel affine_gather_qmm_t_nax_bfloat16_t_gs_64_b_4_bm64_bn64_bk32_wm2_wn2_alN_true

The fix is bk = 32 → bk = 64 - bk only feeds the kernel-name string in that dispatch, not the geometry, so it's a pure "ask for the kernel that exists" change. The sibling qmm_nax() already uses 64, so it reads as a copy-paste slip. This is the same fix as the open, already-approved (now-merged) ml-explore/mlx#3632; my contribution is an independent on-device confirmation and repro on M5 Pro / macOS 27 with the exact failing kernel name, positioned as a "+1, reproduced, please merge" rather than a competing PR. I also drafted a small separable MLX_DISABLE_QMM_NAX env hatch for A/B-ing the NAX path against the steel kernels (on M5 the steel path is ~2× worse, so NAX is correctly the default).

2. int64 scatter crashes the Metal JIT (a bug report). On macOS 27 beta, put_along_axis / scatter_add_axis with a 64-bit element dtype don't fail cleanly - they detonate the Metal library build. The fallback-atomic union computes packing_size<long> = sizeof(uint)/sizeof(long) = 4/8 = 0, which declares a zero-length array (T val[0]) and divides by zero. The plain Scatter primitive already guards 64-bit dtypes with a clean ValueError; the ScatterAxis path is missing that guard. The fix is to mirror the existing guard (or properly support 64-bit element types). The turbo engine sidesteps it by using int32 throughout, which is also why the engine is int32-disciplined by design.

The full set of drafted contributions:

  • mlx-vlm PR - the opt-in diffusion_turbo engine (~1,200 lines, mostly one new file; reference loop untouched), with an honest performance section and a suggested commit split.
  • mlx issue - the int64-scatter Metal JIT crash, with a minimal repro and dtype matrix. A fix had already been proposed at the time of writing.
  • mlx discussion - the measured 4-bit MoE bandwidth ceiling, framed as a question to the kernel authors: is small-M tile tuning of gather_qmm_rhs_nax (which carries a literal // TODO: Tune the block sizes) tractable, or is ~80 GB/s the understood staged ceiling on this generation? The honest answer either way redirects future effort.

What's next

In rough value order:

  1. 2–3 bit experts. Pull the bf16 checkpoint, requantize only the SwitchLinear experts to 3-bit (keep attention/router/embeddings at ≥4-bit), and validate GSM8K + free-form. This attacks the wall directly - fewer bytes to stream - and is the single most promising path to silver. Untried because it needs the ~52 GB bf16 checkpoint and carries quality risk.
  2. A C++ NAX MoE kernel in MLX core. A small-M, dense-per-expert variant reusing the steel NAXFrag machinery, profiling the QuantizedBlockLoader dequant stage as the prime suspect for the 80-vs-220 gap. Upstreamable, but uncertain it beats 80 - that's the open research question the discussion poses.
  3. A batched turbo engine. Lift the batch-1 restriction for ~1.5× server throughput. Sublinear but real, and the compacted runner already has the masking machinery.
  4. Land the upstream stack.

The takeaway

I took DiffusionGemma from 18 tok/s, aimed for 300+, and landed at ~70–90 sustained (86 GSM8K average, ~140 on short prompts) at reference-level quality - a ~3× speedup (up to ~6× on short content) that turns a research curiosity into a usable interactive tool. The optimizations that got there were the boring-but-correct kind: exact top-K instead of full-vocab, gathered self-conditioning, a compacted decoder with a bit-exactness invariant, shape bucketing to bound the kernel cache, and a confidence schedule that doesn't collapse.

But the result I'm proudest of is the one that says no. Five custom kernels and a careful bandwidth sweep proved that 300+ tok/s single-stream isn't a software gap I failed to close - it's a hardware/precision ceiling: 4-bit weight-reuse GEMM tops out at ~80 GB/s on this machine, including Apple's own tuned kernel, exactly at the tile shape the MoE needs. Knowing precisely why you can't go faster, and being able to prove it, is worth as much as the speedup itself. The two levers that remain - lower-bit experts and multi-canvas batching - are now the obvious next moves, and I know exactly why.