Add per-layer hybrid sliding/full attention (Gemma 3 / Gemma 4) to CoreML static LLM export #19251
Open
john-rocky wants to merge 3 commits into pytorch:main from
Conversation
Models trained with sliding-window attention (Mistral 7B, Gemma 3, Gemma 4, Llama 4 Scout, …) only need each layer to attend to the last `W` tokens, but `export_static_llm_coreml.py` was always sizing the per-layer KV cache to `max_context_len - input_len`. That made longer contexts proportionally more expensive in both KV cache memory and per-token attention compute, even though the model was trained to ignore everything outside the window.

Add a `--sliding_window` flag that caps the cache at the trained window. The downstream pieces — `StaticAttentionMask` invariants under cache eviction and the `StaticAttentionIOManager`'s per-layer `cache_lens` plumbing — already support this; the export script just needed to expose it. Per-layer mixed sliding/full attention (Gemma 3/4) is left for a follow-up; this PR uses one window for every layer.

The `cache_len` computation is factored into `_resolve_cache_len` so it is unit-testable, and the README's ANE Optimizations section documents the new option.

### Memory savings example

For a 32-layer / `n_kv_heads=8` / `head_dim=128` model exported with `max_context_len=8192` in fp16, dropping the cache from 8160 to 4096 cuts the per-method KV cache from ~1.07 GB to ~0.54 GB.
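A minimal sketch of what the factored-out helper plausibly looks like, given the behavior described above; `_resolve_cache_len` is named in the description, but its exact signature in `export_static_llm_coreml.py` is an assumption:

```python
from typing import Optional


def _resolve_cache_len(
    max_context_len: int, input_len: int, sliding_window: Optional[int]
) -> int:
    """Size the per-layer KV cache, capped at the trained sliding window."""
    full_cache_len = max_context_len - input_len
    if sliding_window is None:
        return full_cache_len  # old behavior: always the full context
    return min(sliding_window, full_cache_len)


# Worked check of the savings quoted above (fp16 = 2 bytes, K and V both cached):
# 32 layers * 8160 tokens * 8 heads * 128 dim * 2 B * 2 (K+V) ≈ 1.07 GB
# 32 layers * 4096 tokens * 8 heads * 128 dim * 2 B * 2 (K+V) ≈ 0.54 GB
```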
Builds on the prior `--sliding_window` flag. Gemma 3, Gemma 4, and the Llama 4 Scout family interleave sliding and full attention layers rather than using one global setting: Gemma 4 E2B is "4 sliding + 1 full" repeated 7 times across 35 layers; Gemma 3 is "5 sliding + 1 full" repeated. HuggingFace expresses this as a single integer `sliding_window_pattern`, which is what the new `--sliding_window_pattern` flag mirrors.

Implementation:

- `_resolve_per_layer_cache_lens(...)` produces a per-layer `cache_lens` list using the HF rule (layer i is full iff `(i + 1) % P == 0`); the IO manager and the model already accept per-layer `cache_lens`, so the attention mask dict and the per-layer KV cache shapes follow (see the sketch after this list).
- `_get_metadata` now reads each cache's `cache_len` from the example tensor's sequence dimension instead of receiving a single scalar, so the C++ runner metadata describes each layer correctly under hybrid attention.
- Both single-method and multifunction export paths use the per-layer resolver.

The previous PR's uniform-sliding behavior is preserved when `--sliding_window_pattern` is not set. Authored with Claude.
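The following is a minimal sketch of that resolver under the stated HF rule. The function name comes from this PR, but the exact signature and parameter names are assumptions:

```python
from typing import List, Optional


def _resolve_per_layer_cache_lens(
    n_layers: int,
    max_context_len: int,
    input_len: int,
    sliding_window: Optional[int],
    sliding_window_pattern: Optional[int],
) -> List[int]:
    """Per-layer cache lengths; layer i is full iff (i + 1) % P == 0."""
    full_len = max_context_len - input_len
    if sliding_window is None:
        return [full_len] * n_layers  # no flag: every layer keeps the full cache
    if sliding_window_pattern is None:
        # Uniform #19250 behavior: one window for every layer.
        return [min(sliding_window, full_len)] * n_layers
    return [
        full_len
        if (i + 1) % sliding_window_pattern == 0
        else min(sliding_window, full_len)
        for i in range(n_layers)
    ]
```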
Verifies a tiny static-attention transformer accepts the heterogeneous cache shapes produced by _resolve_per_layer_cache_lens and runs a forward pass without errors — the strongest signal that the model and IO Manager really do route the right mask per layer under hybrid sliding/full attention.
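Not the actual test, but a toy plain-PyTorch illustration of the property it checks: a forward pass can consume per-layer caches of two different lengths. All sizes here are made up for the sketch:

```python
import torch

input_len, n_kv_heads, head_dim = 8, 1, 16
cache_lens = [32, 32, 32, 32, 64]  # miniature "4 sliding + 1 full" group

x = torch.randn(1, input_len, head_dim)
for cache_len in cache_lens:
    # Heterogeneous per-layer cache and mask shapes, as under hybrid attention.
    k_cache = torch.zeros(1, n_kv_heads, cache_len, head_dim)
    v_cache = torch.zeros(1, n_kv_heads, cache_len, head_dim)
    mask = torch.zeros(1, 1, input_len, cache_len + input_len)

    k = torch.cat([k_cache, x.unsqueeze(1)], dim=2)  # append this chunk's keys
    v = torch.cat([v_cache, x.unsqueeze(1)], dim=2)
    scores = torch.einsum("bqd,bhkd->bhqk", x, k) / head_dim**0.5 + mask
    x = torch.einsum("bhqk,bhkd->bqd", torch.softmax(scores, dim=-1), v)

print("forward pass ran with heterogeneous cache shapes:", tuple(x.shape))
```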
### Summary
Stacked on top of #19250 — that PR caps every layer's KV cache at a single global window. Gemma 3, Gemma 4, and the Llama 4 Scout family instead interleave sliding and full attention layers:

- Gemma 4 E2B: `[sliding × 4, full × 1] × 7` = 35 layers (P=5)
- Gemma 3: `[sliding × 5, full × 1] × N` (P=6)

HuggingFace expresses this as a single integer `sliding_window_pattern`, which is what the new `--sliding_window_pattern` flag mirrors.

### What changed (this PR's commit only)
- `_resolve_per_layer_cache_lens(...)` produces a per-layer `cache_lens` list using the HF rule (layer `i` is full iff `(i + 1) % P == 0`).
- `StaticAttentionIOManager` already accepts per-layer `cache_lens`, so the attention mask dict (one mask per unique `cache_len`) and per-layer KV cache shapes fall out for free. The forward pass already keys `mask = masks[cache_len]` per layer, so it picks the right mask without any model code change (see the dispatch sketch after this list).
- `_get_metadata` now reads each cache's `cache_len` from the example tensor's sequence dimension instead of taking a single scalar, so the C++ runner metadata reports each layer's actual length under hybrid attention (also sketched below).
- Both single-method and multifunction export paths use the per-layer resolver.
The previous PR's uniform-sliding behavior is preserved when `--sliding_window_pattern` is not set.
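As promised above, two illustrative sketches: the mask dispatch and the shape-based metadata read. Neither is the actual `StaticAttentionIOManager` or `_get_metadata` code, and the tensor layouts are assumptions.

```python
import torch

input_len = 32                               # assumed chunk length
cache_lens = [4096, 4096, 4096, 4096, 8160]  # one Gemma-style P=5 group

# One mask per unique cache_len; every layer then keys into this dict.
masks = {c: torch.zeros(1, input_len, c + input_len) for c in set(cache_lens)}
for layer_idx, cache_len in enumerate(cache_lens):
    mask = masks[cache_len]  # mask = masks[cache_len], per the bullet above
```

For the metadata side, each cache's length is recovered from its example tensor, here assuming a `(batch, n_kv_heads, cache_len, head_dim)` layout (the real axis may differ):

```python
def cache_lens_from_examples(example_caches: list[torch.Tensor]) -> list[int]:
    return [t.shape[-2] for t in example_caches]  # sequence dimension

example_caches = [torch.zeros(1, 1, c, 256) for c in cache_lens]
assert cache_lens_from_examples(example_caches) == cache_lens
```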
### Why it matters

For Gemma 4 E2B with `max_context_len=8192` and `--sliding_window 4096 --sliding_window_pattern 5`:

- 28 sliding layers each get a 4096 × `n_kv_heads` × `head_dim` × 2 B cache
- 7 full layers each get an 8160 × `n_kv_heads` × `head_dim` × 2 B cache

vs. naively giving every layer the full 8160-token cache. For the E2B config (`n_kv_heads=1`, `head_dim=256`) that is 86 MB hybrid vs. 143 MB uniform-full; the savings grow proportionally for E4B.
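A quick recomputation of those figures; the small delta from the quoted 86/143 MB looks like a unit-rounding convention, not a different formula:

```python
n_layers, P = 35, 5
n_kv_heads, head_dim, dtype_bytes = 1, 256, 2  # E2B config, fp16
sliding, full = 4096, 8160                     # full = max_context_len - input_len

n_full = n_layers // P                         # 7 full-attention layers
n_sliding = n_layers - n_full                  # 28 sliding layers
per_token = n_kv_heads * head_dim * dtype_bytes
hybrid = (n_sliding * sliding + n_full * full) * per_token
uniform = n_layers * full * per_token
print(f"hybrid {hybrid / 1e6:.0f} MB vs uniform {uniform / 1e6:.0f} MB")
# -> hybrid 88 MB vs uniform 146 MB (~86 vs ~143 in the PR's rounding)
```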
### Review order

This PR contains two commits. The first ('Add `--sliding_window` flag…') is identical to #19250 — please merge that one first; the diff on this PR will then collapse to just the per-layer commit. I'm happy to rebase once #19250 lands.
### Test plan

Added 7 unit tests in `examples/apple/coreml/llama/test.py`:

- `test_per_layer_cache_lens_uniform_when_no_pattern` — back-compat with the `--sliding_window`-only path.
- `test_per_layer_cache_lens_uniform_full_when_no_window` — no flag at all leaves every layer at `max_context_len - input_len`.
- `test_per_layer_cache_lens_gemma4_e2b_pattern` — 35 layers, P=5 produces 28 sliding + 7 full in the right positions.
- `test_per_layer_cache_lens_gemma3_pattern` — P=6 produces the documented `[s, s, s, s, s, f, ...]` interleave.
- `test_per_layer_cache_lens_pattern_requires_sliding_window` — input validation.
- `test_per_layer_cache_lens_rejects_pattern_le_one` — input validation (P=1 would degenerate to all-full and is almost certainly a typo).
- `test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes` — full path: example inputs really do contain caches of both sizes and a mask per `cache_len`.

(All 6 tests from #19250 still pass alongside the 7 new ones.)
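For flavor, a sketch of what the E2B pattern test plausibly asserts, reusing the hypothetical `_resolve_per_layer_cache_lens` sketched earlier; the real test in `examples/apple/coreml/llama/test.py` may differ:

```python
def test_per_layer_cache_lens_gemma4_e2b_pattern():
    lens = _resolve_per_layer_cache_lens(
        n_layers=35,
        max_context_len=8192,
        input_len=32,  # assumed, so that the full cache is 8160 per the text
        sliding_window=4096,
        sliding_window_pattern=5,
    )
    full_len = 8192 - 32
    assert sum(l == full_len for l in lens) == 7   # layers 5, 10, ..., 35 (1-indexed)
    assert sum(l == 4096 for l in lens) == 28      # the remaining sliding layers
    assert all(lens[i] == full_len for i in range(4, 35, 5))  # right positions
```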
Authored with Claude.