diff --git a/public/assets/img/vit-samples/golden-retriever.jpg b/public/assets/img/vit-samples/golden-retriever.jpg
new file mode 100644
index 00000000..ccade526
Binary files /dev/null and b/public/assets/img/vit-samples/golden-retriever.jpg differ
diff --git a/public/assets/img/vit-samples/macaw.jpg b/public/assets/img/vit-samples/macaw.jpg
new file mode 100644
index 00000000..7cabcd9d
Binary files /dev/null and b/public/assets/img/vit-samples/macaw.jpg differ
diff --git a/public/assets/img/vit-samples/sunflower.jpg b/public/assets/img/vit-samples/sunflower.jpg
new file mode 100644
index 00000000..f092bac8
Binary files /dev/null and b/public/assets/img/vit-samples/sunflower.jpg differ
diff --git a/public/assets/img/vit-samples/tabby-cat.jpg b/public/assets/img/vit-samples/tabby-cat.jpg
new file mode 100644
index 00000000..d762d00b
Binary files /dev/null and b/public/assets/img/vit-samples/tabby-cat.jpg differ
diff --git a/public/assets/models/deit-tiny-int8.bin b/public/assets/models/deit-tiny-int8.bin
new file mode 100644
index 00000000..cdd11400
Binary files /dev/null and b/public/assets/models/deit-tiny-int8.bin differ
diff --git a/public/assets/models/imagenet-labels.json b/public/assets/models/imagenet-labels.json
new file mode 100644
index 00000000..c0b134d8
--- /dev/null
+++ b/public/assets/models/imagenet-labels.json
@@ -0,0 +1 @@
+["tench", "goldfish", "great white shark", "tiger shark", "hammerhead", "electric ray", "stingray", "cock", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "robin", "bulbul", "jay", "magpie", "chickadee", "water ouzel", "kite", "bald eagle", "vulture", "great grey owl", "European fire salamander", "common newt", "eft", "spotted salamander", "axolotl", "bullfrog", "tree frog", "tailed frog", "loggerhead", "leatherback turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "common iguana", "American chameleon", "whiptail", "agama", "frilled lizard", "alligator lizard", "Gila monster", "green lizard", "African chameleon", "Komodo dragon", "African crocodile", "American alligator", "triceratops", "thunder snake", "ringneck snake", "hognose snake", "green snake", "king snake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "rock python", "Indian cobra", "green mamba", "sea snake", "horned viper", "diamondback", "sidewinder", "trilobite", "harvestman", "scorpion", "black and gold garden spider", "barn spider", "garden spider", "black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie chicken", "peacock", "quail", "partridge", "African grey", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "drake", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "American egret", "bittern", "crane", "limpkin", "European gallinule", "American coot", "bustard", "ruddy 
turnstone", "red-backed sandpiper", "redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese spaniel", "Maltese dog", "Pekinese", "Shih-Tzu", "Blenheim spaniel", "papillon", "toy terrier", "Rhodesian ridgeback", "Afghan hound", "basset", "beagle", "bloodhound", "bluetick", "black-and-tan coonhound", "Walker hound", "English foxhound", "redbone", "borzoi", "Irish wolfhound", "Italian greyhound", "whippet", "Ibizan hound", "Norwegian elkhound", "otterhound", "Saluki", "Scottish deerhound", "Weimaraner", "Staffordshire bullterrier", "American Staffordshire terrier", "Bedlington terrier", "Border terrier", "Kerry blue terrier", "Irish terrier", "Norfolk terrier", "Norwich terrier", "Yorkshire terrier", "wire-haired fox terrier", "Lakeland terrier", "Sealyham terrier", "Airedale", "cairn", "Australian terrier", "Dandie Dinmont", "Boston bull", "miniature schnauzer", "giant schnauzer", "standard schnauzer", "Scotch terrier", "Tibetan terrier", "silky terrier", "soft-coated wheaten terrier", "West Highland white terrier", "Lhasa", "flat-coated retriever", "curly-coated retriever", "golden retriever", "Labrador retriever", "Chesapeake Bay retriever", "German short-haired pointer", "vizsla", "English setter", "Irish setter", "Gordon setter", "Brittany spaniel", "clumber", "English springer", "Welsh springer spaniel", "cocker spaniel", "Sussex spaniel", "Irish water spaniel", "kuvasz", "schipperke", "groenendael", "malinois", "briard", "kelpie", "komondor", "Old English sheepdog", "Shetland sheepdog", "collie", "Border collie", "Bouvier des Flandres", "Rottweiler", "German shepherd", "Doberman", "miniature pinscher", "Greater Swiss Mountain dog", "Bernese mountain dog", "Appenzeller", "EntleBucher", "boxer", "bull mastiff", "Tibetan mastiff", "French bulldog", "Great Dane", "Saint Bernard", "Eskimo dog", "malamute", "Siberian husky", "dalmatian", "affenpinscher", "basenji", 
"pug", "Leonberg", "Newfoundland", "Great Pyrenees", "Samoyed", "Pomeranian", "chow", "keeshond", "Brabancon griffon", "Pembroke", "Cardigan", "toy poodle", "miniature poodle", "standard poodle", "Mexican hairless", "timber wolf", "white wolf", "red wolf", "coyote", "dingo", "dhole", "African hunting dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby", "tiger cat", "Persian cat", "Siamese cat", "Egyptian cat", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "ice bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "long-horned beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket", "walking stick", "cockroach", "mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "admiral", "ringlet", "monarch", "cabbage butterfly", "sulphur butterfly", "lycaenid", "starfish", "sea urchin", "sea cucumber", "wood rabbit", "hare", "Angora", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "sorrel", "zebra", "hog", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram", "bighorn", "ibex", "hartebeest", "impala", "gazelle", "Arabian camel", "llama", "weasel", "mink", "polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas", "baboon", "macaque", "langur", "colobus", "proboscis monkey", "marmoset", "capuchin", "howler monkey", "titi", "spider monkey", "squirrel monkey", "Madagascar cat", "indri", "Indian elephant", "African elephant", "lesser panda", "giant panda", "barracouta", "eel", "coho", "rock beauty", "anemone fish", "sturgeon", "gar", "lionfish", "puffer", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibian", "analog clock", 
"apiary", "apron", "ashcan", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint", "Band Aid", "banjo", "bannister", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "barrow", "baseball", "basketball", "bassinet", "bassoon", "bathing cap", "bath towel", "bathtub", "beach wagon", "beacon", "beaker", "bearskin", "beer bottle", "beer glass", "bell cote", "bib", "bicycle-built-for-two", "bikini", "binder", "binoculars", "birdhouse", "boathouse", "bobsled", "bolo tie", "bonnet", "bookcase", "bookshop", "bottlecap", "bow", "bow tie", "brass", "brassiere", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "bullet train", "butcher shop", "cab", "caldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "carpenter's kit", "carton", "car wheel", "cash machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "cellular telephone", "chain", "chainlink fence", "chain mail", "chain saw", "chest", "chiffonier", "chime", "china cabinet", "Christmas stocking", "church", "cinema", "cleaver", "cliff dwelling", "cloak", "clog", "cocktail shaker", "coffee mug", "coffeepot", "coil", "combination lock", "computer keyboard", "confectionery", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "crane", "crash helmet", "crate", "crib", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishrag", "dishwasher", "disk brake", "dock", "dogsled", "dome", "doormat", "drilling platform", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso maker", "face powder", "feather boa", "file", "fireboat", "fire engine", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", 
"fountain pen", "four-poster", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gasmask", "gas pump", "goblet", "go-kart", "golf ball", "golfcart", "gondola", "gong", "gown", "grand piano", "greenhouse", "grille", "grocery store", "guillotine", "hair slide", "hair spray", "half track", "hammer", "hamper", "hand blower", "hand-held computer", "handkerchief", "hard disc", "harmonica", "harp", "harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoopskirt", "horizontal bar", "horse cart", "hourglass", "iPod", "iron", "jack-o'-lantern", "jean", "jeep", "jersey", "jigsaw puzzle", "jinrikisha", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "liner", "lipstick", "Loafer", "lotion", "loudspeaker", "loupe", "lumbermill", "magnetic compass", "mailbag", "mailbox", "maillot", "maillot", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine chest", "megalith", "microphone", "microwave", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "Model T", "modem", "monastery", "monitor", "moped", "mortar", "mortarboard", "mosque", "mosquito net", "motor scooter", "mountain bike", "mountain tent", "mouse", "mousetrap", "moving van", "muzzle", "nail", "neck brace", "necklace", "nipple", "notebook", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "organ", "oscilloscope", "overskirt", "oxcart", "oxygen mask", "packet", "paddle", "paddlewheel", "padlock", "paintbrush", "pajama", "palace", "panpipe", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "passenger car", "patio", "pay-phone", "pedestal", "pencil box", "pencil sharpener", "perfume", "Petri dish", "photocopier", "pick", "pickelhaube", "picket fence", "pickup", "pier", "piggy bank", "pill bottle", "pillow", 
"ping-pong ball", "pinwheel", "pirate", "pitcher", "plane", "planetarium", "plastic bag", "plate rack", "plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "pop bottle", "pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "puck", "punching bag", "purse", "quill", "quilt", "racer", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "rubber eraser", "rugby ball", "rule", "running shoe", "safe", "safety pin", "saltshaker", "sandal", "sarong", "sax", "scabbard", "scale", "school bus", "schooner", "scoreboard", "screen", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe shop", "shoji", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "ski mask", "sleeping bag", "slide rule", "sliding door", "slot", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar dish", "sombrero", "soup bowl", "space bar", "space heater", "space shuttle", "spatula", "speedboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "steel arch bridge", "steel drum", "stethoscope", "stole", "stone wall", "stopwatch", "stove", "strainer", "streetcar", "stretcher", "studio couch", "stupa", "submarine", "suit", "sundial", "sunglass", "sunglasses", "sunscreen", "suspension bridge", "swab", "sweatshirt", "swimming trunks", "swing", "switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy", "television", "tennis ball", "thatch", "theater curtain", "thimble", "thresher", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toyshop", "tractor", "trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "tub", "turnstile", 
"typewriter keyboard", "umbrella", "unicycle", "upright", "vacuum", "vase", "vault", "velvet", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "warplane", "washbasin", "washer", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "wig", "window screen", "window shade", "Windsor tie", "wine bottle", "wing", "wok", "wooden spoon", "wool", "worm fence", "wreck", "yawl", "yurt", "web site", "comic book", "crossword puzzle", "street sign", "traffic light", "book jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "ice lolly", "French loaf", "bagel", "pretzel", "cheeseburger", "hotdog", "mashed potato", "head cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "custard apple", "pomegranate", "hay", "carbonara", "chocolate sauce", "dough", "meat loaf", "pizza", "potpie", "burrito", "red wine", "espresso", "cup", "eggnog", "alp", "bubble", "cliff", "coral reef", "geyser", "lakeside", "promontory", "sandbar", "seashore", "valley", "volcano", "ballplayer", "groom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "hip", "buckeye", "coral fungus", "agaric", "gyromitra", "stinkhorn", "earthstar", "hen-of-the-woods", "bolete", "ear", "toilet tissue"]
\ No newline at end of file
diff --git a/sample/visionTransformer/README.md b/sample/visionTransformer/README.md
new file mode 100644
index 00000000..6752522a
--- /dev/null
+++ b/sample/visionTransformer/README.md
@@ -0,0 +1,50 @@
+# Vision Transformer (ViT) — WebGPU Compute
+
+Runs a DeiT-Tiny vision transformer (5.7M params) entirely in WebGPU compute shaders to classify images, and visualizes per-head attention maps as interactive heatmap overlays.
+
+**[Live demo](https://lyonsno.github.io/webgpu-samples/)**
+
+## What it does
+
+1. **Patch embedding**: splits a 224x224 image into a 14x14 grid of 16x16-pixel patches, projects each to a 192-dim token
+2. **12 transformer blocks**: each applies layer normalization, multi-head self-attention (3 heads), and an MLP with GELU activation, connected by residual additions
+3. **Classification**: the CLS token is projected to 1000 ImageNet class logits
+4. **Attention visualization**: select any layer and head to see which image patches the model attends to, rendered as a viridis heatmap overlay
+
+## Running locally
+
+```bash
+npm install
+npm run serve
+# Open http://localhost:8080/?sample=visionTransformer
+```
+
+## Model weights
+
+The int8-quantized weight file (`public/assets/models/deit-tiny-int8.bin`, 5.5MB) is derived from Meta's [DeiT-Tiny](https://huggingface.co/facebook/deit-tiny-patch16-224) (Apache-2.0). To regenerate:
+
+```bash
+pip install torch timm
+python tools/convert_deit_weights.py
+```
+
+## Files
+
+| File | Description |
+|------|-------------|
+| `main.ts` | UI, image loading, WebGPU setup |
+| `inference.ts` | Forward pass orchestration (patch embed → 12 transformer blocks → classify) |
+| `weights.ts` | Binary weight loader with int8 dequantization |
+| `visualize.ts` | Attention map extraction and heatmap rendering |
+| `shaders/patchEmbed.wgsl` | Image patches → token embeddings |
+| `shaders/layerNorm.wgsl` | Layer normalization |
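+| `shaders/mlp.wgsl` | Linear projection kernels: plain `linear` and GELU-fused `linearGelu` (used for Q/K/V, attention output projection, MLP, and class head) |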
+| `shaders/attnScores.wgsl` | Q·K^T scaled dot-product scores |
+| `shaders/attnSoftmax.wgsl` | Softmax normalization + attention weight storage |
+| `shaders/attnApply.wgsl` | Attention-weighted value summation |
+| `shaders/visualize.wgsl` | Viridis heatmap overlay render pass |
+
+## License
+
+Sample code: BSD-3-Clause (matching webgpu-samples).
+Model weights: Apache-2.0 (Meta DeiT). See [THIRD_PARTY_NOTICES.md](THIRD_PARTY_NOTICES.md).
diff --git a/sample/visionTransformer/THIRD_PARTY_NOTICES.md b/sample/visionTransformer/THIRD_PARTY_NOTICES.md
new file mode 100644
index 00000000..9d1c57be
--- /dev/null
+++ b/sample/visionTransformer/THIRD_PARTY_NOTICES.md
@@ -0,0 +1,32 @@
+# Third-Party Notices
+
+## Model Weights
+
+The model weights (`public/assets/models/deit-tiny-int8.bin`) are derived from
+**DeiT-Tiny** (`facebook/deit-tiny-patch16-224`) by Meta Research.
+
+- Paper: "Training data-efficient image transformers & distillation through attention" (Touvron et al., 2021)
+- Source: https://huggingface.co/facebook/deit-tiny-patch16-224
+- License: Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0)
+
+The weights have been quantized to int8 with per-tensor scale factors.
+
+## ImageNet Class Labels
+
+The class labels (`public/assets/models/imagenet-labels.json`) are human-readable
+English descriptions derived from WordNet synsets via the `timm` library.
+
+- WordNet License: https://wordnet.princeton.edu/license-and-commercial-use
+- timm Library: Apache License 2.0
+
+## Sample Images
+
+The sample images in `public/assets/img/vit-samples/` are sourced from
+[Unsplash](https://unsplash.com/) and used under the
+[Unsplash License](https://unsplash.com/license), which permits free use
+for commercial and non-commercial purposes without attribution.
+
+- `golden-retriever.jpg` — Photo by [Unsplash](https://unsplash.com/photos/552053831-71594a27632d)
+- `tabby-cat.jpg` — Photo by [Unsplash](https://unsplash.com/photos/514888286974-6c03e2ca1dba)
+- `macaw.jpg` — Photo by [Unsplash](https://unsplash.com/photos/552728089-57bdde30beb3)
+- `sunflower.jpg` — Photo by [Unsplash](https://unsplash.com/photos/597848212624-a19eb35e2651)
diff --git a/sample/visionTransformer/index.html b/sample/visionTransformer/index.html
new file mode 100644
index 00000000..c6a37991
--- /dev/null
+++ b/sample/visionTransformer/index.html
@@ -0,0 +1,165 @@
+
+
+
+
+
+ webgpu-samples: visionTransformer
+
+
+
+
+
+
+
+
+
+
+
+
+
Drop image to classify
+
+
+
+
+
+
+
+
+
+
Loading model weights...
+
+
+
diff --git a/sample/visionTransformer/inference.ts b/sample/visionTransformer/inference.ts
new file mode 100644
index 00000000..3d19ad66
--- /dev/null
+++ b/sample/visionTransformer/inference.ts
@@ -0,0 +1,803 @@
+// DeiT-Tiny inference engine
+// All compute dispatches use at most 7 storage/uniform bindings per bind group.
+
+import { ModelWeights, createTensorBuffer, DEIT_CONFIG } from './weights';
+import patchEmbedWGSL from './shaders/patchEmbed.wgsl';
+import layerNormWGSL from './shaders/layerNorm.wgsl';
+import mlpWGSL from './shaders/mlp.wgsl';
+import attnScoresWGSL from './shaders/attnScores.wgsl';
+import attnSoftmaxWGSL from './shaders/attnSoftmax.wgsl';
+import attnApplyWGSL from './shaders/attnApply.wgsl';
+
+const C = DEIT_CONFIG;
+
+// Divide `a` by `b` and round the quotient up to the next integer.
+// Used to turn an element count into a workgroup count.
+function ceilDiv(a: number, b: number): number {
+  const quotient = a / b;
+  return Math.ceil(quotient);
+}
+
+// Reusable pool of uniform buffers to avoid per-dispatch GPU memory allocation.
+// Buffers are allocated on first use and retained for subsequent inference runs.
+//
+// Slots are handed out strictly in call order: the n-th get() after a reset()
+// always returns the n-th pooled buffer. A slot is reallocated only when the
+// payload requested for it is larger than the buffer currently stored there.
+class UniformBufferPool {
+  private device: GPUDevice;
+  private buffers: GPUBuffer[] = [];
+  private index = 0;
+
+  constructor(device: GPUDevice) {
+    this.device = device;
+  }
+
+  // Reset the pool index at the start of each run. No buffers are freed.
+  reset() {
+    this.index = 0;
+  }
+
+  // Return the next pooled uniform buffer with `data` written at offset 0.
+  // The buffer is at least max(data.byteLength, 16) bytes long.
+  get(data: ArrayBuffer): GPUBuffer {
+    const size = Math.max(data.byteLength, 16);
+    const slot = this.index++;
+    let buf = this.buffers[slot];
+
+    // A slot keeps the size of whichever payload first used it. If a later
+    // call sequence asks this slot to hold more bytes, writing into the old
+    // buffer would run past its end, so replace it with a large enough one.
+    if (buf === undefined || buf.size < size) {
+      if (buf !== undefined) {
+        // The stale buffer belongs to an earlier pass over the pool; new
+        // work will only ever see the replacement stored in this slot.
+        buf.destroy();
+      }
+      buf = this.device.createBuffer({
+        size,
+        usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+      });
+      this.buffers[slot] = buf;
+    }
+
+    this.device.queue.writeBuffer(buf, 0, new Uint8Array(data));
+    return buf;
+  }
+}
+
+/**
+ * VitInference runs a full DeiT-Tiny vision transformer forward pass in
+ * WebGPU compute shaders.
+ *
+ * The forward pass follows the standard ViT architecture:
+ * 1. Patch embedding: split image into 14x14 grid of 16x16 patches,
+ * project each to a 192-dim token, prepend a CLS token, add position embeddings.
+ * 2. 12 transformer blocks, each containing:
+ * - Layer norm → multi-head self-attention (3 heads, 64 dims each) → residual add
+ * - Layer norm → MLP (192 → 768 with GELU → 768 → 192) → residual add
+ * 3. Final layer norm → classify using the CLS token → 1000-class logits
+ *
+ * Attention weights from each layer/head are stored for visualization,
+ * allowing interactive exploration of what image regions the model attends to.
+ */
+export class VitInference {
+ private device: GPUDevice;
+ private uniformPool!: UniformBufferPool;
+
+ // Compute pipelines — each shader uses at most 7 bindings per bind group
+ private patchEmbedPipeline!: GPUComputePipeline;
+ private layerNormPipeline!: GPUComputePipeline;
+ private linearPipeline!: GPUComputePipeline; // plain linear projection
+ private linearGeluPipeline!: GPUComputePipeline; // linear + GELU
+ private attnScoresPipeline!: GPUComputePipeline;
+ private attnSoftmaxPipeline!: GPUComputePipeline;
+ private attnApplyPipeline!: GPUComputePipeline;
+ private residualAddPipeline!: GPUComputePipeline;
+
+ // Buffers
+ private imageBuffer!: GPUBuffer;
+ private tokenBuffer!: GPUBuffer; // (197, 192) main token state
+ private normBuffer!: GPUBuffer; // (197, 192) output of layer norm
+ private qBuffer!: GPUBuffer; // (197, 192)
+ private kBuffer!: GPUBuffer; // (197, 192)
+ private vBuffer!: GPUBuffer; // (197, 192)
+ private attnOutBuffer!: GPUBuffer; // (197, 192) attention-weighted output
+ private projOutBuffer!: GPUBuffer; // (197, 192) after output projection
+ private scoreBuffer!: GPUBuffer; // (3, 197, 197)
+ private hiddenBuffer!: GPUBuffer; // (197, 768) MLP hidden
+ private mlpOutBuffer!: GPUBuffer; // (197, 192) MLP output
+ private classLogitsBuffer!: GPUBuffer; // (1000)
+ private attnWeightsBuffer!: GPUBuffer; // (12, 3, 197, 197) all attention weights
+ private logitsReadbackBuffer!: GPUBuffer;
+ private attnReadbackBuffer!: GPUBuffer;
+
+ // Weight buffers
+ private layerWeights: Map[] = [];
+ private patchEmbedWeights!: {
+ projWeight: GPUBuffer;
+ projBias: GPUBuffer;
+ clsToken: GPUBuffer;
+ posEmbed: GPUBuffer;
+ };
+ private classHeadWeights!: { weight: GPUBuffer; bias: GPUBuffer };
+ private finalNormWeights!: { gamma: GPUBuffer; beta: GPUBuffer };
+
+ // Only records the GPU device; no pipelines, buffers, or weights are
+ // created here.
+ constructor(device: GPUDevice) {
+ this.device = device;
+ }
+
+ // One-time setup: creates the uniform pool, compiles the compute pipelines,
+ // allocates the working/readback buffers, and uploads the model weights.
+ async initialize(weights: ModelWeights) {
+   const { device } = this;
+   this.uniformPool = new UniformBufferPool(device);
+
+   this.createPipelines();
+   this.createBuffers();
+   this.uploadWeights(weights);
+ }
+
+ // Builds every compute pipeline used by the forward pass. All pipelines use
+ // layout: 'auto', so their bind group layouts are derived from the shaders.
+ private createPipelines() {
+ const device = this.device;
+
+ // Helper for shaders that expose a single entry point: compiles `code`
+ // into its own shader module and wraps it in a compute pipeline.
+ const makePipeline = (label: string, code: string, entryPoint: string) =>
+ device.createComputePipeline({
+ label,
+ layout: 'auto',
+ compute: {
+ module: device.createShaderModule({ label, code }),
+ entryPoint,
+ },
+ });
+
+ // Patch embedding: image pixels → token embeddings
+ this.patchEmbedPipeline = makePipeline(
+ 'patchEmbed',
+ patchEmbedWGSL,
+ 'main'
+ );
+ // Layer normalization
+ this.layerNormPipeline = makePipeline('layerNorm', layerNormWGSL, 'main');
+
+ // Linear projections (used for Q/K/V, output projection, MLP, and class head).
+ // The MLP shader has two entry points with identical bindings: 'linear' (plain)
+ // and 'linearGelu' (with GELU activation fused into the output).
+ // The module is compiled once and shared by both pipelines.
+ const mlpModule = device.createShaderModule({
+ label: 'mlp',
+ code: mlpWGSL,
+ });
+ this.linearGeluPipeline = device.createComputePipeline({
+ label: 'linearGelu',
+ layout: 'auto',
+ compute: { module: mlpModule, entryPoint: 'linearGelu' },
+ });
+ this.linearPipeline = device.createComputePipeline({
+ label: 'linear',
+ layout: 'auto',
+ compute: { module: mlpModule, entryPoint: 'linear' },
+ });
+
+ // Attention is split into three stages to keep bindings per shader low:
+ // 1. Scores: Q·K^T scaled by 1/sqrt(head_dim) — measures query-key similarity
+ // 2. Softmax: normalizes scores to attention probabilities
+ // 3. Apply: weighted sum of values using attention probabilities
+ this.attnScoresPipeline = makePipeline(
+ 'attnScores',
+ attnScoresWGSL,
+ 'computeScores'
+ );
+ this.attnSoftmaxPipeline = makePipeline(
+ 'attnSoftmax',
+ attnSoftmaxWGSL,
+ 'main'
+ );
+ this.attnApplyPipeline = makePipeline('attnApply', attnApplyWGSL, 'main');
+
+ // Element-wise residual addition, defined inline: each invocation with
+ // gid.x < count performs dst[gid.x] += src[gid.x].
+ // NOTE(review): as shown here the three `var` declarations carry no address
+ // space and the arrays no element type; WGSL needs both (presumably
+ // var<storage, read_write> / var<storage, read> / var<uniform> with
+ // array<f32>) — confirm against the committed file before relying on this.
+ const residualModule = device.createShaderModule({
+ label: 'residualAdd',
+ code: `
+ @group(0) @binding(0) var dst: array;
+ @group(0) @binding(1) var src: array;
+ @group(0) @binding(2) var count: u32;
+
+ @compute @workgroup_size(256)
+ fn main(@builtin(global_invocation_id) gid: vec3u) {
+ if (gid.x >= count) { return; }
+ dst[gid.x] = dst[gid.x] + src[gid.x];
+ }
+ `,
+ });
+ this.residualAddPipeline = device.createComputePipeline({
+ label: 'residualAdd',
+ layout: 'auto',
+ compute: { module: residualModule, entryPoint: 'main' },
+ });
+ }
+
+ // Allocates every activation, scratch, and readback buffer used by a
+ // forward pass. All sizes are in bytes and assume 4-byte elements.
+ private createBuffers() {
+   const { device } = this;
+   const BYTES_PER_ELEMENT = 4;
+
+   // Byte sizes shared by several buffers.
+   const tokenBytes = C.numTokens * C.dim * BYTES_PER_ELEMENT;
+   const logitsBytes = C.numClasses * BYTES_PER_ELEMENT;
+   const allAttnBytes =
+     C.numLayers * C.numHeads * C.numTokens * C.numTokens * BYTES_PER_ELEMENT;
+
+   // GPU-side working buffer: always STORAGE | COPY_SRC, plus optional flags.
+   const makeStorage = (label: string, size: number, extra = 0) => {
+     const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | extra;
+     return device.createBuffer({ label, size, usage });
+   };
+
+   // CPU-mappable staging buffer that receives a copy of a storage buffer.
+   const makeReadback = (label: string, size: number) => {
+     const usage = GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST;
+     return device.createBuffer({ label, size, usage });
+   };
+
+   // Input image; COPY_DST so uploadImage() can write into it.
+   const imageBytes = C.imgSize * C.imgSize * C.channels * BYTES_PER_ELEMENT;
+   this.imageBuffer = makeStorage('image', imageBytes, GPUBufferUsage.COPY_DST);
+
+   // Main token state, (numTokens, dim).
+   this.tokenBuffer = makeStorage('tokens', tokenBytes, GPUBufferUsage.COPY_DST);
+   // Layer norm output.
+   this.normBuffer = makeStorage('norm', tokenBytes);
+
+   // Attention projections.
+   this.qBuffer = makeStorage('Q', tokenBytes);
+   this.kBuffer = makeStorage('K', tokenBytes);
+   this.vBuffer = makeStorage('V', tokenBytes);
+   // Attention-weighted values, before and after the output projection.
+   this.attnOutBuffer = makeStorage('attnOut', tokenBytes);
+   this.projOutBuffer = makeStorage('projOut', tokenBytes);
+
+   // Per-layer attention scores, (numHeads, numTokens, numTokens).
+   const scoreBytes =
+     C.numHeads * C.numTokens * C.numTokens * BYTES_PER_ELEMENT;
+   this.scoreBuffer = makeStorage('attnScores', scoreBytes);
+
+   // MLP hidden activations (numTokens, mlpHiddenDim) and MLP output.
+   const hiddenBytes = C.numTokens * C.mlpHiddenDim * BYTES_PER_ELEMENT;
+   this.hiddenBuffer = makeStorage('mlpHidden', hiddenBytes);
+   this.mlpOutBuffer = makeStorage('mlpOut', tokenBytes);
+
+   // One logit per class.
+   this.classLogitsBuffer = makeStorage('classLogits', logitsBytes);
+   // Attention weights of every layer and head, kept for readback.
+   this.attnWeightsBuffer = makeStorage('attnWeights', allAttnBytes);
+
+   // Staging buffers mapped by run() to read results back on the CPU.
+   this.logitsReadbackBuffer = makeReadback('logitsReadback', logitsBytes);
+   this.attnReadbackBuffer = makeReadback('attnReadback', allAttnBytes);
+ }
+
+ // Creates one GPU buffer per model tensor via createTensorBuffer().
+ // Patch-embed, final-norm, and class-head tensors are looked up with a
+ // non-null assertion; per-layer tensors that are absent from the weight
+ // file are skipped rather than reported.
+ private uploadWeights(weights: ModelWeights) {
+   const { device } = this;
+   const tensors = weights.tensors;
+   const upload = (key: string) =>
+     createTensorBuffer(device, tensors.get(key)!);
+
+   this.patchEmbedWeights = {
+     projWeight: upload('patch_embed.proj.weight'),
+     projBias: upload('patch_embed.proj.bias'),
+     clsToken: upload('cls_token'),
+     posEmbed: upload('pos_embed'),
+   };
+
+   // Tensor names inside each `blocks.<layer>` group.
+   const perLayerNames = [
+     'attn.qkv.weight',
+     'attn.qkv.bias',
+     'attn.proj.weight',
+     'attn.proj.bias',
+     'norm1.weight',
+     'norm1.bias',
+     'norm2.weight',
+     'norm2.bias',
+     'mlp.fc1.weight',
+     'mlp.fc1.bias',
+     'mlp.fc2.weight',
+     'mlp.fc2.bias',
+   ];
+
+   for (let layer = 0; layer < C.numLayers; layer++) {
+     const layerBuffers = new Map<string, GPUBuffer>();
+     perLayerNames.forEach((name) => {
+       const tensor = tensors.get(`blocks.${layer}.${name}`);
+       if (!tensor) return;
+       layerBuffers.set(name, createTensorBuffer(device, tensor));
+     });
+     this.layerWeights.push(layerBuffers);
+   }
+
+   this.finalNormWeights = {
+     gamma: upload('norm.weight'),
+     beta: upload('norm.bias'),
+   };
+   this.classHeadWeights = {
+     weight: upload('head.weight'),
+     bias: upload('head.bias'),
+   };
+ }
+
+ // Copies `imageData` to the start of the GPU image buffer; the next run()
+ // reads its pixels from there.
+ uploadImage(imageData: Float32Array) {
+   const { queue } = this.device;
+   queue.writeBuffer(this.imageBuffer, 0, imageData);
+ }
+
+ /**
+  * Execute one forward pass and return the class logits together with the
+  * attention weights of every layer and head.
+  *
+  * Every stage is recorded on one command encoder and handed to the queue in
+  * a single submit. The two readback buffers are then mapped concurrently
+  * and their contents copied out before being unmapped. `elapsedMs` covers
+  * encoding, submission, and both readbacks.
+  */
+ async run(): Promise<{
+   logits: Float32Array;
+   attnWeights: Float32Array;
+   elapsedMs: number;
+ }> {
+   const { device } = this;
+   const startTime = performance.now();
+   this.uniformPool.reset();
+
+   const logitsBytes = C.numClasses * 4;
+   const attnBytes = C.numLayers * C.numHeads * C.numTokens * C.numTokens * 4;
+   const encoder = device.createCommandEncoder();
+
+   // Stage 1: image → token embeddings.
+   this.encodePatchEmbed(encoder);
+
+   // Stage 2: run the tokens through each transformer block in order.
+   for (let layer = 0; layer < C.numLayers; layer++) {
+     this.encodeTransformerBlock(encoder, layer);
+   }
+
+   // Stage 3: final layer norm into the norm buffer, then the class head.
+   const { gamma, beta } = this.finalNormWeights;
+   this.encodeLayerNorm(
+     encoder,
+     this.tokenBuffer,
+     this.normBuffer,
+     gamma,
+     beta
+   );
+   this.encodeClassHead(encoder);
+
+   // Stage both results for CPU readback as part of the same submission.
+   encoder.copyBufferToBuffer(
+     this.classLogitsBuffer,
+     0,
+     this.logitsReadbackBuffer,
+     0,
+     logitsBytes
+   );
+   encoder.copyBufferToBuffer(
+     this.attnWeightsBuffer,
+     0,
+     this.attnReadbackBuffer,
+     0,
+     attnBytes
+   );
+   device.queue.submit([encoder.finish()]);
+
+   // Map a staging buffer, copy its floats out, and unmap it again so the
+   // buffer can be used as a copy destination on the next run.
+   const readFloats = async (staging: GPUBuffer, byteLength: number) => {
+     await staging.mapAsync(GPUMapMode.READ);
+     const mapped = staging.getMappedRange(0, byteLength);
+     const copy = new Float32Array(mapped.slice(0));
+     staging.unmap();
+     return copy;
+   };
+
+   const [logits, attnWeights] = await Promise.all([
+     readFloats(this.logitsReadbackBuffer, logitsBytes),
+     readFloats(this.attnReadbackBuffer, attnBytes),
+   ]);
+
+   const elapsedMs = performance.now() - startTime;
+   return { logits, attnWeights, elapsedMs };
+ }
+
+ // --- Dispatch helpers (each uses <= 7 bindings) ---
+
+ /**
+  * Records the patch-embedding pass: reads imageBuffer and the patch-embed
+  * weights, writes the token matrix into tokenBuffer.
+  */
+ private encodePatchEmbed(encoder: GPUCommandEncoder) {
+   const pipeline = this.patchEmbedPipeline;
+   const weights = this.patchEmbedWeights;
+
+   // Uniform block: five u32 values in the order the shader's Params expects.
+   const shape = new Uint32Array([
+     C.imgSize,
+     C.patchSize,
+     C.numPatches,
+     C.channels,
+     C.dim,
+   ]);
+   const uniforms = this.uniformPool.get(shape.buffer);
+
+   // Array position doubles as the binding index (0 = uniforms ... 6 = output).
+   const bound = [
+     uniforms,
+     this.imageBuffer,
+     weights.projWeight,
+     weights.projBias,
+     weights.clsToken,
+     weights.posEmbed,
+     this.tokenBuffer,
+   ];
+   const bindGroup = this.device.createBindGroup({
+     layout: pipeline.getBindGroupLayout(0),
+     entries: bound.map((buffer, binding) => ({
+       binding,
+       resource: { buffer },
+     })),
+   });
+
+   // One invocation per output element, 256 invocations per workgroup.
+   const pass = encoder.beginComputePass();
+   pass.setPipeline(pipeline);
+   pass.setBindGroup(0, bindGroup);
+   pass.dispatchWorkgroups(ceilDiv(C.numTokens * C.dim, 256));
+   pass.end();
+ }
+
+ /**
+  * Records a layer-norm pass over all C.numTokens rows of `input`, writing
+  * the normalized rows to `output` using the given scale/shift buffers.
+  */
+ private encodeLayerNorm(
+   encoder: GPUCommandEncoder,
+   input: GPUBuffer,
+   output: GPUBuffer,
+   gamma: GPUBuffer,
+   beta: GPUBuffer
+ ) {
+   const pipeline = this.layerNormPipeline;
+
+   // 16-byte uniform block: N (u32), D (u32), eps (f32), 4 bytes unused.
+   const littleEndian = true;
+   const uniformBytes = new ArrayBuffer(16);
+   const writer = new DataView(uniformBytes);
+   writer.setUint32(0, C.numTokens, littleEndian);
+   writer.setUint32(4, C.dim, littleEndian);
+   writer.setFloat32(8, 1e-6, littleEndian);
+   const uniforms = this.uniformPool.get(uniformBytes);
+
+   const bindGroup = this.device.createBindGroup({
+     layout: pipeline.getBindGroupLayout(0),
+     entries: [uniforms, input, gamma, beta, output].map(
+       (buffer, binding) => ({ binding, resource: { buffer } })
+     ),
+   });
+
+   // The shader handles one row per workgroup, so dispatch one per token.
+   const pass = encoder.beginComputePass();
+   pass.setPipeline(pipeline);
+   pass.setBindGroup(0, bindGroup);
+   pass.dispatchWorkgroups(C.numTokens);
+   pass.end();
+ }
+
+ /**
+  * Records one dense-layer pass (`input` x `weight` + `bias` -> `output`)
+  * with the supplied pipeline, binding each buffer in full.
+  *
+  * numRows/inDim/outDim are forwarded to the shader as three u32 uniforms.
+  */
+ private encodeLinear(
+   encoder: GPUCommandEncoder,
+   pipeline: GPUComputePipeline,
+   input: GPUBuffer,
+   weight: GPUBuffer,
+   bias: GPUBuffer,
+   output: GPUBuffer,
+   numRows: number,
+   inDim: number,
+   outDim: number
+ ) {
+   const dims = new Uint32Array([numRows, inDim, outDim]);
+   const uniforms = this.uniformPool.get(dims.buffer);
+
+   // Array position doubles as the binding index.
+   const bindGroup = this.device.createBindGroup({
+     layout: pipeline.getBindGroupLayout(0),
+     entries: [uniforms, input, weight, bias, output].map(
+       (buffer, binding) => ({ binding, resource: { buffer } })
+     ),
+   });
+
+   // One invocation per output element, 256 invocations per workgroup.
+   const workgroups = ceilDiv(numRows * outDim, 256);
+   const pass = encoder.beginComputePass();
+   pass.setPipeline(pipeline);
+   pass.setBindGroup(0, bindGroup);
+   pass.dispatchWorkgroups(workgroups);
+   pass.end();
+ }
+
+ /**
+  * Records the residual-add pass over C.numTokens * C.dim elements, with
+  * `dst` at binding 0, `src` at binding 1 and the element count at binding 2.
+  */
+ private encodeResidualAdd(
+   encoder: GPUCommandEncoder,
+   dst: GPUBuffer,
+   src: GPUBuffer
+ ) {
+   const pipeline = this.residualAddPipeline;
+   const elementCount = C.numTokens * C.dim;
+   const countUniform = this.uniformPool.get(
+     new Uint32Array([elementCount]).buffer
+   );
+
+   // Unlike the other passes, the uniform comes last in this layout.
+   const bindGroup = this.device.createBindGroup({
+     layout: pipeline.getBindGroupLayout(0),
+     entries: [dst, src, countUniform].map((buffer, binding) => ({
+       binding,
+       resource: { buffer },
+     })),
+   });
+
+   const pass = encoder.beginComputePass();
+   pass.setPipeline(pipeline);
+   pass.setBindGroup(0, bindGroup);
+   pass.dispatchWorkgroups(ceilDiv(elementCount, 256));
+   pass.end();
+ }
+
+ /**
+  * Records every pass of one pre-norm transformer block:
+  *   norm1 -> attention -> residual add -> norm2 -> MLP -> residual add
+  *
+  * tokenBuffer carries the running token state and is updated in place by
+  * the two residual adds; normBuffer, projOutBuffer, hiddenBuffer and
+  * mlpOutBuffer are scratch space reused by every layer.
+  */
+ private encodeTransformerBlock(encoder: GPUCommandEncoder, layerIdx: number) {
+   const lw = this.layerWeights[layerIdx];
+   // Weight lookup by name; entries are expected to exist for every layer.
+   const param = (name: string) => lw.get(name)!;
+   const tokens = this.tokenBuffer;
+   const normed = this.normBuffer;
+
+   // --- Attention half ---
+   // tokens -> normed
+   this.encodeLayerNorm(
+     encoder,
+     tokens,
+     normed,
+     param('norm1.weight'),
+     param('norm1.bias')
+   );
+   // normed -> projOutBuffer
+   this.encodeAttention(encoder, layerIdx, lw);
+   // tokens += projOutBuffer
+   this.encodeResidualAdd(encoder, tokens, this.projOutBuffer);
+
+   // --- MLP half ---
+   // tokens -> normed
+   this.encodeLayerNorm(
+     encoder,
+     tokens,
+     normed,
+     param('norm2.weight'),
+     param('norm2.bias')
+   );
+   // normed -> hiddenBuffer (linear + GELU, dim -> mlpHiddenDim)
+   this.encodeLinear(
+     encoder,
+     this.linearGeluPipeline,
+     normed,
+     param('mlp.fc1.weight'),
+     param('mlp.fc1.bias'),
+     this.hiddenBuffer,
+     C.numTokens,
+     C.dim,
+     C.mlpHiddenDim
+   );
+   // hiddenBuffer -> mlpOutBuffer (plain linear, mlpHiddenDim -> dim)
+   this.encodeLinear(
+     encoder,
+     this.linearPipeline,
+     this.hiddenBuffer,
+     param('mlp.fc2.weight'),
+     param('mlp.fc2.bias'),
+     this.mlpOutBuffer,
+     C.numTokens,
+     C.mlpHiddenDim,
+     C.dim
+   );
+   // tokens += mlpOutBuffer
+   this.encodeResidualAdd(encoder, tokens, this.mlpOutBuffer);
+ }
+
+ /**
+  * Records multi-head self-attention for one layer: normBuffer in,
+  * projOutBuffer out.
+  *
+  * Passes, in order:
+  *   1. Q, K and V projections of normBuffer, each reading its own
+  *      sub-range of the fused qkv weight/bias buffers.
+  *   2. Scaled dot-product scores for every (head, query, key) triple.
+  *   3. Row-wise softmax, which also stores the probabilities for this
+  *      layer into attnWeightsBuffer for later visualization.
+  *   4. Probability-weighted sum of the value vectors.
+  *   5. Output projection into projOutBuffer.
+  */
+ private encodeAttention(
+   encoder: GPUCommandEncoder,
+   layerIdx: number,
+   lw: Map
+ ) {
+   const device = this.device;
+   const LE = true; // uniform blocks are written little-endian
+
+   // Records one compute pass whose bind group is [uniforms, ...storage],
+   // with the array position used as the binding index.
+   const recordPass = (
+     pipeline: GPUComputePipeline,
+     uniforms: GPUBuffer,
+     storage: GPUBuffer[],
+     workgroups: number
+   ) => {
+     const bindGroup = device.createBindGroup({
+       layout: pipeline.getBindGroupLayout(0),
+       entries: [uniforms, ...storage].map((buffer, binding) => ({
+         binding,
+         resource: { buffer },
+       })),
+     });
+     const pass = encoder.beginComputePass();
+     pass.setPipeline(pipeline);
+     pass.setBindGroup(0, bindGroup);
+     pass.dispatchWorkgroups(workgroups);
+     pass.end();
+   };
+
+   // 1. Q / K / V projections. Slice i of the fused qkv parameters starts at
+   // i * (bytes of one dim x dim f32 matrix) and i * (bytes of one bias).
+   const qkvWeight = lw.get('attn.qkv.weight')!;
+   const qkvBias = lw.get('attn.qkv.bias')!;
+   const weightBytes = C.dim * C.dim * 4;
+   const biasBytes = C.dim * 4;
+   const qkvTargets = [this.qBuffer, this.kBuffer, this.vBuffer];
+   qkvTargets.forEach((target, slot) => {
+     this.encodeLinearWithOffsets(
+       encoder,
+       this.linearPipeline,
+       this.normBuffer,
+       qkvWeight,
+       slot * weightBytes,
+       weightBytes,
+       qkvBias,
+       slot * biasBytes,
+       biasBytes,
+       target,
+       C.numTokens,
+       C.dim,
+       C.dim
+     );
+   });
+
+   // 2. Scores: Q, K -> scoreBuffer, one invocation per score.
+   const scoreBytes = new ArrayBuffer(24);
+   const scoreView = new DataView(scoreBytes);
+   scoreView.setUint32(0, C.numTokens, LE);
+   scoreView.setUint32(4, C.dim, LE);
+   scoreView.setUint32(8, C.numHeads, LE);
+   scoreView.setUint32(12, C.headDim, LE);
+   scoreView.setFloat32(16, C.scale, LE);
+   scoreView.setUint32(20, layerIdx, LE);
+   recordPass(
+     this.attnScoresPipeline,
+     this.uniformPool.get(scoreBytes),
+     [this.qBuffer, this.kBuffer, this.scoreBuffer],
+     ceilDiv(C.numHeads * C.numTokens * C.numTokens, 256)
+   );
+
+   // 3. Softmax in place on scoreBuffer, one invocation per (head, query)
+   // row; layerIdx selects this layer's slot inside attnWeightsBuffer.
+   const softmaxBytes = new ArrayBuffer(16);
+   const softmaxView = new DataView(softmaxBytes);
+   softmaxView.setUint32(0, C.numTokens, LE);
+   softmaxView.setUint32(4, C.numHeads, LE);
+   softmaxView.setUint32(8, layerIdx, LE);
+   recordPass(
+     this.attnSoftmaxPipeline,
+     this.uniformPool.get(softmaxBytes),
+     [this.scoreBuffer, this.attnWeightsBuffer],
+     ceilDiv(C.numHeads * C.numTokens, 256)
+   );
+
+   // 4. Apply: probabilities, V -> attnOutBuffer, one invocation per element.
+   const applyBytes = new ArrayBuffer(16);
+   const applyView = new DataView(applyBytes);
+   applyView.setUint32(0, C.numTokens, LE);
+   applyView.setUint32(4, C.dim, LE);
+   applyView.setUint32(8, C.numHeads, LE);
+   applyView.setUint32(12, C.headDim, LE);
+   recordPass(
+     this.attnApplyPipeline,
+     this.uniformPool.get(applyBytes),
+     [this.scoreBuffer, this.vBuffer, this.attnOutBuffer],
+     ceilDiv(C.numTokens * C.dim, 256)
+   );
+
+   // 5. Output projection: attnOutBuffer -> projOutBuffer.
+   this.encodeLinear(
+     encoder,
+     this.linearPipeline,
+     this.attnOutBuffer,
+     lw.get('attn.proj.weight')!,
+     lw.get('attn.proj.bias')!,
+     this.projOutBuffer,
+     C.numTokens,
+     C.dim,
+     C.dim
+   );
+ }
+
+ /**
+  * Same dense-layer pass as encodeLinear, except that `weight` and `bias`
+  * are bound as byte sub-ranges (offset + size) of larger buffers.
+  *
+  * Offsets and sizes are in bytes and are passed to the bind group as-is.
+  */
+ private encodeLinearWithOffsets(
+   encoder: GPUCommandEncoder,
+   pipeline: GPUComputePipeline,
+   input: GPUBuffer,
+   weight: GPUBuffer,
+   weightOffset: number,
+   weightSize: number,
+   bias: GPUBuffer,
+   biasOffset: number,
+   biasSize: number,
+   output: GPUBuffer,
+   numRows: number,
+   inDim: number,
+   outDim: number
+ ) {
+   const dims = new Uint32Array([numRows, inDim, outDim]);
+   const uniforms = this.uniformPool.get(dims.buffer);
+
+   const weightSlice = {
+     buffer: weight,
+     offset: weightOffset,
+     size: weightSize,
+   };
+   const biasSlice = { buffer: bias, offset: biasOffset, size: biasSize };
+
+   const bindGroup = this.device.createBindGroup({
+     layout: pipeline.getBindGroupLayout(0),
+     entries: [
+       { binding: 0, resource: { buffer: uniforms } },
+       { binding: 1, resource: { buffer: input } },
+       { binding: 2, resource: weightSlice },
+       { binding: 3, resource: biasSlice },
+       { binding: 4, resource: { buffer: output } },
+     ],
+   });
+
+   // One invocation per output element, 256 invocations per workgroup.
+   const workgroups = ceilDiv(numRows * outDim, 256);
+   const pass = encoder.beginComputePass();
+   pass.setPipeline(pipeline);
+   pass.setBindGroup(0, bindGroup);
+   pass.dispatchWorkgroups(workgroups);
+   pass.end();
+ }
+
+ /**
+  * Records the classification head: a plain linear layer applied to a
+  * single row of normBuffer, producing C.numClasses logits.
+  */
+ private encodeClassHead(encoder: GPUCommandEncoder) {
+   const head = this.classHeadWeights;
+   // Only row 0 (the CLS token after the final layer norm) is classified,
+   // so the linear pass is told there is exactly one input row.
+   const rowsToClassify = 1;
+   this.encodeLinear(
+     encoder,
+     this.linearPipeline,
+     this.normBuffer,
+     head.weight,
+     head.bias,
+     this.classLogitsBuffer,
+     rowsToClassify,
+     C.dim,
+     C.numClasses
+   );
+ }
+}
diff --git a/sample/visionTransformer/main.ts b/sample/visionTransformer/main.ts
new file mode 100644
index 00000000..78652705
--- /dev/null
+++ b/sample/visionTransformer/main.ts
@@ -0,0 +1,277 @@
+// Vision Transformer (DeiT-Tiny) — WebGPU Compute Inference + Attention Visualization
+//
+// This sample demonstrates how to run a real neural network (a vision
+// transformer) entirely in WebGPU compute shaders. It takes an input image,
+// splits it into patches, processes them through 12 transformer layers of
+// self-attention and feed-forward networks, and outputs both a classification
+// (what's in the image?) and attention maps (what did the model look at?).
+//
+// The attention visualization is the key output: by selecting different layers
+// and heads, you can see how the model's focus shifts from local texture
+// features in early layers to global semantic understanding in later layers.
+
+import {
+ quitIfAdapterNotAvailable,
+ quitIfWebGPUNotAvailableOrMissingFeatures,
+ quitIfLimitLessThan,
+} from '../util';
+import { loadWeights, DEIT_CONFIG } from './weights';
+import { VitInference } from './inference';
+import { AttentionVisualizer, topK, loadImageNetLabels } from './visualize';
+
+// Short alias for the model configuration used throughout this file.
+const C = DEIT_CONFIG;
+
+// --- Page elements ---
+// These lookups are cast without null checks; the sample's HTML is expected
+// to provide every one of these elements.
+const canvas = document.querySelector('canvas') as HTMLCanvasElement;
+const dropZone = document.querySelector('#dropZone') as HTMLDivElement;
+const layerSlider = document.querySelector('#layerSlider') as HTMLInputElement;
+const headSlider = document.querySelector('#headSlider') as HTMLInputElement;
+const alphaSlider = document.querySelector('#alphaSlider') as HTMLInputElement;
+const layerValue = document.querySelector('#layerValue') as HTMLSpanElement;
+const headValue = document.querySelector('#headValue') as HTMLSpanElement;
+const alphaValue = document.querySelector('#alphaValue') as HTMLSpanElement;
+const resultsDiv = document.querySelector('#results') as HTMLDivElement;
+const statusDiv = document.querySelector('#status') as HTMLDivElement;
+
+// --- Adapter / device ---
+// `navigator.gpu?.` leaves `adapter` undefined when WebGPU is missing; the
+// quitIf* helpers from '../util' are presumably what stops the sample in
+// that case (their implementation is not shown here).
+const adapter = await navigator.gpu?.requestAdapter({
+ featureLevel: 'compatibility',
+});
+quitIfAdapterNotAvailable(adapter);
+
+// The compute shaders are declared with @workgroup_size(256), so both of
+// these limits must be at least 256. The collected values are passed to
+// requestDevice as requiredLimits.
+const limits: Record = {};
+quitIfLimitLessThan(adapter, 'maxComputeWorkgroupSizeX', 256, limits);
+quitIfLimitLessThan(adapter, 'maxComputeInvocationsPerWorkgroup', 256, limits);
+const device = await adapter.requestDevice({ requiredLimits: limits });
+quitIfWebGPUNotAvailableOrMissingFeatures(adapter, device);
+
+// Backing store is twice the model input size in each dimension.
+canvas.width = C.imgSize * 2; // 448 for retina
+canvas.height = C.imgSize * 2;
+
+const context = canvas.getContext('webgpu')!;
+const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
+context.configure({ device, format: presentationFormat });
+
+// --- Initialize components ---
+const inference = new VitInference(device);
+const visualizer = new AttentionVisualizer(device, context, presentationFormat);
+visualizer.initialize();
+
+// --- Load weights and labels ---
+statusDiv.textContent = 'Loading model weights...';
+
+// Shared state read by the handlers below:
+//   labels             - class names, filled once loading succeeds
+//   currentAttnWeights - attention weights from the most recent inference
+//   modelReady         - stays false if loading fails, which disables
+//                        classifyImage and the initial sample load
+let labels: string[] = [];
+let currentAttnWeights: Float32Array | null = null;
+let modelReady = false;
+
+try {
+ // Weights and labels are fetched concurrently; either failure lands in
+ // the catch below and is shown in the status line.
+ const [weights, loadedLabels] = await Promise.all([
+ loadWeights('../../assets/models/deit-tiny-int8.bin'),
+ loadImageNetLabels('../../assets/models/imagenet-labels.json'),
+ ]);
+ labels = loadedLabels;
+ await inference.initialize(weights);
+ modelReady = true;
+ statusDiv.textContent = 'Ready. Drag and drop an image to classify.';
+} catch (e) {
+ statusDiv.textContent = `Failed to load: ${
+ e instanceof Error ? e.message : String(e)
+ }`;
+ console.error(e);
+}
+
+// --- Image preprocessing ---
+/**
+ * Turns an image into model input.
+ *
+ * The image is scaled so it fully covers a C.imgSize square (shorter side
+ * equals C.imgSize), centered, and drawn onto a scratch canvas; whatever
+ * falls outside the square is cropped.
+ *
+ * Returns:
+ *   rgba       - the cropped pixels as read from the canvas (4 bytes/pixel)
+ *   normalized - the same pixels as interleaved RGB floats, each channel
+ *                mapped through (value / 255 - mean) / std
+ */
+function preprocessImage(img: HTMLImageElement | HTMLCanvasElement): {
+  normalized: Float32Array;
+  rgba: Uint8ClampedArray;
+} {
+  const side = C.imgSize;
+
+  const scratch = document.createElement('canvas');
+  scratch.width = side;
+  scratch.height = side;
+  const ctx2d = scratch.getContext('2d')!;
+
+  // "Cover" scaling: pick the larger ratio so no edge of the square is empty.
+  const cover = Math.max(side / img.width, side / img.height);
+  const drawnW = img.width * cover;
+  const drawnH = img.height * cover;
+  // Negative (or zero) offsets center the oversized image on the canvas.
+  ctx2d.drawImage(img, (side - drawnW) / 2, (side - drawnH) / 2, drawnW, drawnH);
+
+  const rgba = ctx2d.getImageData(0, 0, side, side).data;
+
+  // Per-channel ImageNet statistics (R, G, B).
+  const mean = [0.485, 0.456, 0.406];
+  const std = [0.229, 0.224, 0.225];
+
+  const pixelCount = side * side;
+  const normalized = new Float32Array(pixelCount * C.channels);
+  for (let px = 0; px < pixelCount; px++) {
+    const src = px * 4; // RGBA stride in the canvas data
+    const dst = px * 3; // RGB stride in the model input (alpha dropped)
+    for (let ch = 0; ch < 3; ch++) {
+      normalized[dst + ch] = (rgba[src + ch] / 255.0 - mean[ch]) / std[ch];
+    }
+  }
+
+  return { normalized, rgba };
+}
+
+// --- Run inference ---
+// True while a forward pass is in flight; further requests are dropped.
+let isRunning = false;
+/**
+ * Runs the full pipeline for one image: preprocess, upload, inference,
+ * top-5 display, and attention overlay for the currently selected
+ * layer/head.
+ *
+ * Does nothing if the model is not ready or a run is already in progress.
+ * Failures are reported in the status line and logged rather than thrown,
+ * because every caller invokes this without awaiting it.
+ */
+async function classifyImage(img: HTMLImageElement | HTMLCanvasElement) {
+  if (isRunning || !modelReady) return;
+  isRunning = true;
+  try {
+    statusDiv.textContent = 'Running inference...';
+
+    // Preprocessing happens synchronously, before the first await, so the
+    // caller may release the image source right after this call returns.
+    const { normalized, rgba } = preprocessImage(img);
+
+    // Upload image for visualization
+    visualizer.uploadImage(rgba, C.imgSize, C.imgSize);
+
+    // Upload normalized image for inference
+    inference.uploadImage(normalized);
+
+    // Run forward pass
+    const { logits, attnWeights, elapsedMs } = await inference.run();
+    currentAttnWeights = attnWeights;
+
+    // Display results
+    const predictions = topK(logits, labels, 5);
+    resultsDiv.innerHTML = predictions
+      .map(
+        (p) =>
+          `${p.label}${(
+            p.probability * 100
+          ).toFixed(1)}%
`
+      )
+      .join('');
+
+    statusDiv.textContent = `Inference: ${elapsedMs.toFixed(0)}ms`;
+
+    // Update attention visualization
+    const layer = parseInt(layerSlider.value, 10);
+    const head = parseInt(headSlider.value, 10);
+    visualizer.updateAttentionMap(attnWeights, layer, head);
+    visualizer.render();
+  } catch (e) {
+    // Without this the status would stay at "Running inference..." forever
+    // and the rejection would go unhandled.
+    statusDiv.textContent = `Inference failed: ${
+      e instanceof Error ? e.message : String(e)
+    }`;
+    console.error(e);
+  } finally {
+    isRunning = false;
+  }
+}
+
+// --- UI event handlers ---
+// Redraws the overlay for whatever layer/head the sliders currently select.
+// No-op until an inference has produced attention weights.
+function redrawSelectedAttention() {
+  if (!currentAttnWeights) return;
+  visualizer.updateAttentionMap(
+    currentAttnWeights,
+    parseInt(layerSlider.value),
+    parseInt(headSlider.value)
+  );
+  visualizer.render();
+}
+
+// Slider values are 0-based; the labels next to them show 1-based numbers.
+layerSlider.addEventListener('input', () => {
+  layerValue.textContent = String(parseInt(layerSlider.value) + 1);
+  redrawSelectedAttention();
+});
+
+headSlider.addEventListener('input', () => {
+  headValue.textContent = String(parseInt(headSlider.value) + 1);
+  redrawSelectedAttention();
+});
+
+// The alpha slider is a percentage; the visualizer takes a 0..1 fraction.
+// Only the blend changes, so the attention map itself is not recomputed.
+alphaSlider.addEventListener('input', () => {
+  const percent = parseInt(alphaSlider.value);
+  alphaValue.textContent = `${percent}%`;
+  visualizer.setOverlayAlpha(percent / 100);
+  if (!currentAttnWeights) return;
+  visualizer.render();
+});
+
+// --- Drag and drop ---
+/**
+ * Decodes a user-supplied file and classifies it.
+ *
+ * Files whose MIME type is not image/* are ignored. The temporary object
+ * URL is released on both the success and the failure path, and a decode
+ * failure is reported in the status line.
+ */
+function handleFile(file: File) {
+  if (!file.type.startsWith('image/')) return;
+
+  const objectUrl = URL.createObjectURL(file);
+  const img = new Image();
+  img.onload = () => {
+    classifyImage(img);
+    URL.revokeObjectURL(objectUrl);
+  };
+  img.onerror = () => {
+    // A file can claim an image MIME type and still fail to decode; without
+    // this handler the object URL would never be revoked.
+    URL.revokeObjectURL(objectUrl);
+    statusDiv.textContent = `Failed to load ${file.name}: image could not be decoded`;
+  };
+  img.src = objectUrl;
+}
+
+// preventDefault on dragover marks the zone as a valid drop target; the
+// 'dragover' class is toggled while a drag hovers over it.
+dropZone.addEventListener('dragover', (event) => {
+  event.preventDefault();
+  dropZone.classList.add('dragover');
+});
+
+dropZone.addEventListener('dragleave', () => {
+  dropZone.classList.remove('dragover');
+});
+
+// Only the first dropped file is used.
+dropZone.addEventListener('drop', (event) => {
+  event.preventDefault();
+  dropZone.classList.remove('dragover');
+  const dropped = event.dataTransfer?.files;
+  if (dropped?.length) {
+    handleFile(dropped[0]);
+  }
+});
+
+// Clicking the canvas opens a file picker via a throwaway <input> element.
+canvas.addEventListener('click', () => {
+  const picker = document.createElement('input');
+  picker.type = 'file';
+  picker.accept = 'image/*';
+  picker.onchange = () => {
+    const chosen = picker.files;
+    if (chosen?.length) {
+      handleFile(chosen[0]);
+    }
+  };
+  picker.click();
+});
+
+// --- Sample image buttons ---
+const sampleImagesDiv = document.querySelector(
+ '#sampleImages'
+) as HTMLDivElement;
+
+/**
+ * Fetches one bundled sample image by file name and classifies it.
+ *
+ * Marks the matching sample button as active, then downloads and decodes
+ * the image. Both network failures and decode failures are reported in the
+ * status line, and the temporary object URL is always released.
+ */
+async function loadSampleImage(filename: string) {
+  statusDiv.textContent = `Loading ${filename}...`;
+
+  // Highlight active button
+  sampleImagesDiv.querySelectorAll('button').forEach((btn) => {
+    btn.classList.toggle('active', btn.getAttribute('data-img') === filename);
+  });
+
+  try {
+    const response = await fetch(`../../assets/img/vit-samples/${filename}`);
+    if (!response.ok) throw new Error(`${response.status}`);
+    const blob = await response.blob();
+    const objectUrl = URL.createObjectURL(blob);
+    const img = new Image();
+    img.onload = () => {
+      classifyImage(img);
+      URL.revokeObjectURL(objectUrl);
+    };
+    img.onerror = () => {
+      // Decoding fails asynchronously, outside the try/catch, so it needs
+      // its own handler to release the URL and update the status line.
+      URL.revokeObjectURL(objectUrl);
+      statusDiv.textContent = `Failed to load ${filename}: image could not be decoded`;
+    };
+    img.src = objectUrl;
+  } catch (e) {
+    statusDiv.textContent = `Failed to load ${filename}: ${e}`;
+  }
+}
+
+// One delegated listener serves every sample button: the nearest enclosing
+// <button> of the click target supplies the file name via data-img.
+sampleImagesDiv.addEventListener('click', (event) => {
+  const clicked = event.target as HTMLElement;
+  const imgFile = clicked.closest('button')?.getAttribute('data-img');
+  if (imgFile) loadSampleImage(imgFile);
+});
+
+// Show a first result right away, but only when the model actually loaded.
+if (modelReady) {
+  loadSampleImage('golden-retriever.jpg');
+}
diff --git a/sample/visionTransformer/meta.ts b/sample/visionTransformer/meta.ts
new file mode 100644
index 00000000..772cc545
--- /dev/null
+++ b/sample/visionTransformer/meta.ts
@@ -0,0 +1,21 @@
+// Sample metadata: display name, description, and the list of source files
+// associated with this sample.
+// NOTE(review): __DIRNAME__ is not defined in this file; presumably it is
+// substituted by the build. Confirm against the build configuration.
+export default {
+ name: 'Vision Transformer (ViT)',
+ // The trailing backslashes continue the string literal across lines
+ // without inserting newlines into the description.
+ description:
+ 'Runs a DeiT-Tiny vision transformer entirely in WebGPU compute shaders to classify images, \
+and visualizes per-head attention maps showing which image patches the model focuses on. \
+Drag and drop an image to classify it.',
+ filename: __DIRNAME__,
+ // NOTE(review): inference.ts uses a residual-add pipeline, but no
+ // residual-add shader file is listed here. Confirm whether that shader
+ // is defined inline or a path is missing.
+ sources: [
+ { path: 'main.ts' },
+ { path: 'inference.ts' },
+ { path: 'weights.ts' },
+ { path: 'visualize.ts' },
+ { path: 'shaders/patchEmbed.wgsl' },
+ { path: 'shaders/layerNorm.wgsl' },
+ { path: 'shaders/mlp.wgsl' },
+ { path: 'shaders/attnScores.wgsl' },
+ { path: 'shaders/attnSoftmax.wgsl' },
+ { path: 'shaders/attnApply.wgsl' },
+ { path: 'shaders/visualize.wgsl' },
+ ],
+};
diff --git a/sample/visionTransformer/shaders/attnApply.wgsl b/sample/visionTransformer/shaders/attnApply.wgsl
new file mode 100644
index 00000000..19845d40
--- /dev/null
+++ b/sample/visionTransformer/shaders/attnApply.wgsl
@@ -0,0 +1,44 @@
+// Apply attention weights to value vectors.
+//
+// For each token, computes a weighted sum of all value vectors using the
+// attention probabilities from softmax. Tokens that received high attention
+// scores contribute more to the output. This is how the model "reads from"
+// the patches it decided to attend to.
+//
+// Each thread computes one element of the (N, D) output, where the column
+// index maps to a specific head and position within that head.
+
+struct Params {
+ N: u32, // number of tokens (197)
+ D: u32, // model dimension (192)
+ numHeads: u32, // number of attention heads (3)
+ headDim: u32, // dimension per head (64)
+}
+
+@group(0) @binding(0) var params: Params;
+@group(0) @binding(1) var scoreBuf: array; // (numHeads, N, N)
+@group(0) @binding(2) var vBuf: array; // (N, D)
+@group(0) @binding(3) var output: array; // (N, D)
+
+// One invocation per element of the (N, D) output: for query token `query`
+// and output column `channel`, accumulate the probability-weighted sum of
+// that column over all value rows, using the probabilities of the head
+// that owns the column.
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) gid: vec3u) {
+  let outIdx = gid.x;
+  let tokenCount = params.N;
+  let modelDim = params.D;
+  let headDim = params.headDim;
+
+  // The dispatch is rounded up to whole workgroups; surplus threads exit.
+  if (outIdx >= tokenCount * modelDim) { return; }
+
+  let query = outIdx / modelDim;
+  let channel = outIdx % modelDim;
+  // Columns are grouped by head: [head0 | head1 | ...], headDim wide each.
+  let head = channel / headDim;
+
+  // Start of this query's probability row inside the (numHeads, N, N) tensor.
+  let probRow = (head * tokenCount + query) * tokenCount;
+
+  var acc = 0.0;
+  for (var key = 0u; key < tokenCount; key++) {
+    acc += scoreBuf[probRow + key] * vBuf[key * modelDim + channel];
+  }
+  output[outIdx] = acc;
+}
diff --git a/sample/visionTransformer/shaders/attnScores.wgsl b/sample/visionTransformer/shaders/attnScores.wgsl
new file mode 100644
index 00000000..e9ab03e8
--- /dev/null
+++ b/sample/visionTransformer/shaders/attnScores.wgsl
@@ -0,0 +1,50 @@
+// Attention score computation: how much should query token qi attend to key token ki?
+//
+// For each attention head, computes the scaled dot product between the query
+// vector of token qi and the key vector of token ki:
+// score = (Q[qi] . K[ki]) / sqrt(head_dim)
+//
+// The scaling by 1/sqrt(head_dim) prevents dot products from growing too large
+// with increasing dimension, which would push softmax into saturation.
+// Each thread computes one element of the (numHeads, N, N) score tensor.
+
+struct Params {
+ N: u32, // number of tokens (197)
+ D: u32, // model dimension (192)
+ numHeads: u32, // number of attention heads (3)
+ headDim: u32, // dimension per head (64)
+ scale: f32, // 1/sqrt(headDim)
+ layerIdx: u32, // which layer for attnWeights storage offset
+}
+
+@group(0) @binding(0) var params: Params;
+@group(0) @binding(1) var qBuf: array; // (N, D)
+@group(0) @binding(2) var kBuf: array; // (N, D)
+@group(0) @binding(3) var scoreBuf: array; // (numHeads, N, N)
+
+// One invocation per (head, query, key) triple: writes the scaled dot
+// product of the query's and key's headDim-wide slices for that head.
+@compute @workgroup_size(256)
+fn computeScores(@builtin(global_invocation_id) gid: vec3u) {
+  let flat = gid.x;
+  let tokenCount = params.N;
+  let modelDim = params.D;
+  let headDim = params.headDim;
+  let pairsPerHead = tokenCount * tokenCount;
+
+  // The dispatch is rounded up to whole workgroups; surplus threads exit.
+  if (flat >= params.numHeads * pairsPerHead) { return; }
+
+  // Decompose the flat index as (head, query, key), key varying fastest.
+  let head = flat / pairsPerHead;
+  let pair = flat % pairsPerHead;
+  let query = pair / tokenCount;
+  let key = pair % tokenCount;
+
+  // First element of this head's slice within each token's row.
+  let sliceStart = head * headDim;
+  let qStart = query * modelDim + sliceStart;
+  let kStart = key * modelDim + sliceStart;
+
+  var acc = 0.0;
+  for (var c = 0u; c < headDim; c++) {
+    acc += qBuf[qStart + c] * kBuf[kStart + c];
+  }
+
+  scoreBuf[flat] = acc * params.scale;
+}
diff --git a/sample/visionTransformer/shaders/attnSoftmax.wgsl b/sample/visionTransformer/shaders/attnSoftmax.wgsl
new file mode 100644
index 00000000..efe18b82
--- /dev/null
+++ b/sample/visionTransformer/shaders/attnSoftmax.wgsl
@@ -0,0 +1,53 @@
+// Softmax normalization of attention scores.
+//
+// Converts raw attention scores into a probability distribution: for each
+// query token, the attention weights over all key tokens sum to 1.0.
+// Uses the numerically stable form: subtract the row maximum before
+// exponentiating to prevent overflow.
+//
+// The resulting attention weights are also stored to a readback buffer
+// for visualization — they show which image patches the model attends to.
+
+struct Params {
+ N: u32,
+ numHeads: u32,
+ layerIdx: u32,
+}
+
+@group(0) @binding(0) var params: Params;
+@group(0) @binding(1) var scoreBuf: array; // (numHeads, N, N)
+@group(0) @binding(2) var attnWeights: array; // (numLayers, numHeads, N, N)
+
+// One invocation per (head, query) row of the score tensor. The row is
+// converted to probabilities in place, and the same values are written to
+// this layer's slot of attnWeights.
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) gid: vec3u) {
+  let rowIdx = gid.x;
+  let rowLen = params.N;
+  let rowsPerLayer = params.numHeads * rowLen;
+
+  // The dispatch is rounded up to whole workgroups; surplus threads exit.
+  if (rowIdx >= rowsPerLayer) { return; }
+
+  let rowStart = rowIdx * rowLen;
+
+  // Pass 1: row maximum, subtracted below so exp() never sees a positive
+  // argument.
+  var rowMax = -1e30;
+  for (var k = 0u; k < rowLen; k++) {
+    rowMax = max(rowMax, scoreBuf[rowStart + k]);
+  }
+
+  // Pass 2: exponentiate in place and accumulate the normalizer.
+  var total = 0.0;
+  for (var k = 0u; k < rowLen; k++) {
+    let expd = exp(scoreBuf[rowStart + k] - rowMax);
+    scoreBuf[rowStart + k] = expd;
+    total += expd;
+  }
+
+  // Pass 3: normalize; keep one copy for the apply pass and one for readback.
+  let layerStart = params.layerIdx * rowsPerLayer * rowLen;
+  for (var k = 0u; k < rowLen; k++) {
+    let prob = scoreBuf[rowStart + k] / total;
+    scoreBuf[rowStart + k] = prob;
+    attnWeights[layerStart + rowStart + k] = prob;
+  }
+}
diff --git a/sample/visionTransformer/shaders/layerNorm.wgsl b/sample/visionTransformer/shaders/layerNorm.wgsl
new file mode 100644
index 00000000..a928b1dd
--- /dev/null
+++ b/sample/visionTransformer/shaders/layerNorm.wgsl
@@ -0,0 +1,61 @@
+// Layer normalization: stabilizes training and inference by normalizing
+// each token's activations to zero mean and unit variance, then applying
+// learned scale (gamma) and shift (beta) parameters.
+//
+// Without normalization, activations can drift to extreme values as they
+// pass through many layers, causing numerical instability. Layer norm is
+// applied before attention and before the MLP in each transformer block.
+//
+// Each workgroup processes one token (row). Thread 0 computes mean/variance
+// serially (D=192 is small enough), then all threads normalize in parallel.
+
+struct Params {
+ N: u32, // number of rows (tokens)
+ D: u32, // dimension per row
+ eps: f32,
+}
+
+@group(0) @binding(0) var params: Params;
+@group(0) @binding(1) var input: array;
+@group(0) @binding(2) var gamma: array;
+@group(0) @binding(3) var beta: array;
+@group(0) @binding(4) var output: array;
+
+var shared_mean: f32;
+var shared_inv_std: f32;
+
+// One workgroup per row. Invocation 0 computes the row statistics and
+// publishes them through shared_mean / shared_inv_std; after the barrier
+// all invocations normalize the row, each taking every 256th element.
+@compute @workgroup_size(256)
+fn main(
+  @builtin(workgroup_id) wg_id: vec3u,
+  @builtin(local_invocation_id) local_id: vec3u,
+) {
+  let row = wg_id.x;
+  let tid = local_id.x;
+  let D = params.D;
+  let base = row * D;
+
+  // Thread 0 computes mean and variance
+  if (tid == 0u) {
+    var sum = 0.0;
+    var sq_sum = 0.0;
+    for (var i = 0u; i < D; i++) {
+      let val = input[base + i];
+      sum += val;
+      sq_sum += val * val;
+    }
+    let mean = sum / f32(D);
+    // E[x^2] - mean^2 is mathematically >= 0, but in f32 it can round to a
+    // small negative number when the row is nearly constant relative to its
+    // mean. Clamp so sqrt() below never receives a negative argument.
+    let variance = max(sq_sum / f32(D) - mean * mean, 0.0);
+    shared_mean = mean;
+    shared_inv_std = 1.0 / sqrt(variance + params.eps);
+  }
+  // Makes thread 0's writes visible to the rest of the workgroup.
+  workgroupBarrier();
+
+  let mean = shared_mean;
+  let inv_std = shared_inv_std;
+
+  // All threads normalize and apply affine transform in parallel
+  for (var i = tid; i < D; i += 256u) {
+    let val = input[base + i];
+    output[base + i] = (val - mean) * inv_std * gamma[i] + beta[i];
+  }
+}
diff --git a/sample/visionTransformer/shaders/mlp.wgsl b/sample/visionTransformer/shaders/mlp.wgsl
new file mode 100644
index 00000000..2776fbb4
--- /dev/null
+++ b/sample/visionTransformer/shaders/mlp.wgsl
@@ -0,0 +1,71 @@
+// MLP (feed-forward network): the "thinking" part of each transformer block.
+//
+// After attention gathers information from other tokens, the MLP processes
+// each token independently through two linear layers with a GELU activation:
+// hidden = GELU(input @ W1 + b1) — expand from 192 to 768 dims
+// output = hidden @ W2 + b2 — project back from 768 to 192 dims
+//
+// The 4x expansion lets the network learn richer intermediate representations.
+// Two entry points share the same binding layout: 'linearGelu' (fused first
+// layer + activation) and 'linear' (plain second layer).
+
+struct Params {
+ numRows: u32, // number of tokens
+ inDim: u32, // input dimension
+ outDim: u32, // output dimension
+}
+
+@group(0) @binding(0) var params: Params;
+@group(0) @binding(1) var input: array;
+@group(0) @binding(2) var weight: array;
+@group(0) @binding(3) var bias: array;
+@group(0) @binding(4) var output: array;
+
+// Tanh-based GELU approximation (Hendrycks & Gimpel, 2016), used because
+// WGSL has no erf() builtin:
+//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+//
+// The tanh argument grows with x^3, and a tanh evaluated through exp() can
+// overflow to inf/inf = NaN for large arguments. The argument is therefore
+// clamped to [-10, 10]; tanh(+-10) already rounds to +-1 in f32, so the
+// clamp does not change any correctly computed result.
+fn gelu(x: f32) -> f32 {
+  let c = 0.7978845608; // sqrt(2/pi)
+  let inner = clamp(c * (x + 0.044715 * x * x * x), -10.0, 10.0);
+  return 0.5 * x * (1.0 + tanh(inner));
+}
+
+// Fused dense layer + activation: output = GELU(input @ weight + bias).
+// One invocation per output element; `weight` is indexed as (inDim, outDim).
+@compute @workgroup_size(256)
+fn linearGelu(@builtin(global_invocation_id) gid: vec3u) {
+  let outIdx = gid.x;
+  let inDim = params.inDim;
+  let outDim = params.outDim;
+
+  // The dispatch is rounded up to whole workgroups; surplus threads exit.
+  if (outIdx >= params.numRows * outDim) { return; }
+
+  let row = outIdx / outDim;
+  let col = outIdx % outDim;
+  let inStart = row * inDim;
+
+  // Start from the bias, then add the dot product of the input row with
+  // column `col` of the weight matrix.
+  var acc = bias[col];
+  for (var k = 0u; k < inDim; k++) {
+    acc += input[inStart + k] * weight[k * outDim + col];
+  }
+  output[outIdx] = gelu(acc);
+}
+
+// Dense layer without activation: output = input @ weight + bias.
+// One invocation per output element; `weight` is indexed as (inDim, outDim).
+@compute @workgroup_size(256)
+fn linear(@builtin(global_invocation_id) gid: vec3u) {
+  let outIdx = gid.x;
+  let inDim = params.inDim;
+  let outDim = params.outDim;
+
+  // The dispatch is rounded up to whole workgroups; surplus threads exit.
+  if (outIdx >= params.numRows * outDim) { return; }
+
+  let row = outIdx / outDim;
+  let col = outIdx % outDim;
+  let inStart = row * inDim;
+
+  // Start from the bias, then add the dot product of the input row with
+  // column `col` of the weight matrix.
+  var acc = bias[col];
+  for (var k = 0u; k < inDim; k++) {
+    acc += input[inStart + k] * weight[k * outDim + col];
+  }
+  output[outIdx] = acc;
+}
diff --git a/sample/visionTransformer/shaders/patchEmbed.wgsl b/sample/visionTransformer/shaders/patchEmbed.wgsl
new file mode 100644
index 00000000..7809f50a
--- /dev/null
+++ b/sample/visionTransformer/shaders/patchEmbed.wgsl
@@ -0,0 +1,76 @@
+// Patch embedding compute shader.
+// Takes a 224x224x3 image and produces (197, 192) token embeddings:
+// 196 patches (14x14 grid of 16x16 patches) + 1 CLS token.
+// Each patch is flattened (16*16*3 = 768) then linearly projected to dim 192.
+// The inner loop is a naive dot product (768 iterations per thread) for clarity;
+// a production implementation would use tiled shared-memory matmul.
+
+struct Params {
+  imgSize: u32,    // 224
+  patchSize: u32,  // 16
+  numPatches: u32, // 196
+  channels: u32,   // 3
+  dim: u32,        // 192
+}
+
+@group(0) @binding(0) var<uniform> params: Params;
+@group(0) @binding(1) var<storage, read> image: array<f32>;      // (224*224*3), channel index varies fastest
+@group(0) @binding(2) var<storage, read> projWeight: array<f32>; // (768, 192) = (patchSize*patchSize*channels, dim)
+@group(0) @binding(3) var<storage, read> projBias: array<f32>;   // (192)
+@group(0) @binding(4) var<storage, read> clsToken: array<f32>;   // (192)
+@group(0) @binding(5) var<storage, read> posEmbed: array<f32>;   // (197, 192)
+@group(0) @binding(6) var<storage, read_write> output: array<f32>; // (197, 192)
+
+// Builds the token sequence: row 0 is the CLS token, rows 1..numPatches are
+// linearly projected image patches; a position embedding is added to each.
+@compute @workgroup_size(256)
+fn main(
+  @builtin(global_invocation_id) gid: vec3u,
+) {
+  let flat = gid.x;
+  let width = params.dim;
+  let tokenCount = params.numPatches + 1u; // patches + CLS
+
+  // One invocation per output element; surplus invocations do nothing.
+  if (flat >= tokenCount * width) { return; }
+
+  let tokenIdx = flat / width;
+  let feature = flat % width;
+
+  var acc = 0.0;
+
+  if (tokenIdx == 0u) {
+    // Row 0 carries the CLS token.
+    acc = clsToken[feature];
+  } else {
+    let side = params.patchSize;
+    let stride = params.imgSize;
+    let depth = params.channels;
+    let perRow = stride / side; // patches per image row
+
+    // Top-left pixel of this patch in the source image.
+    let patch = tokenIdx - 1u;
+    let originY = (patch / perRow) * side;
+    let originX = (patch % perRow) * side;
+
+    // Walk the patch row by row with the channel varying fastest, so the
+    // running index w visits projWeight rows 0, 1, 2, ... in the same order
+    // as a single loop over the flattened patch.
+    acc = projBias[feature];
+    var w = 0u;
+    for (var py = 0u; py < side; py++) {
+      for (var px = 0u; px < side; px++) {
+        let texel = ((originY + py) * stride + (originX + px)) * depth;
+        for (var c = 0u; c < depth; c++) {
+          acc += image[texel + c] * projWeight[w * width + feature];
+          w++;
+        }
+      }
+    }
+  }
+
+  // Every token, CLS included, gets its position embedding.
+  acc += posEmbed[flat];
+
+  output[flat] = acc;
+}
diff --git a/sample/visionTransformer/shaders/visualize.wgsl b/sample/visionTransformer/shaders/visualize.wgsl
new file mode 100644
index 00000000..9bd4b3b2
--- /dev/null
+++ b/sample/visionTransformer/shaders/visualize.wgsl
@@ -0,0 +1,80 @@
+// Attention map visualization render shader.
+//
+// Blends the original input image with a viridis-colored heatmap showing
+// attention intensities. The attention map is a 14x14 grid (one value per
+// image patch) that gets bilinearly upsampled to the full image resolution
+// by the texture sampler.
+//
+// The viridis colormap maps low attention (dark purple/blue) to high
+// attention (yellow/green), making it easy to see where the model "looks."
+
+// Data handed from the vertex stage to the fragment stage.
+struct VertexOutput {
+  @builtin(position) position: vec4f,
+  @location(0) uv: vec2f,
+}
+
+// Emits one oversized triangle that covers the whole viewport.
+@vertex
+fn vs(@builtin(vertex_index) vertexIndex: u32) -> VertexOutput {
+  var corners = array(
+    vec2f(-1.0, -1.0),
+    vec2f( 3.0, -1.0),
+    vec2f(-1.0, 3.0),
+  );
+  let clip = corners[vertexIndex];
+  // Map clip space [-1, 1] to [0, 1] and flip Y for image coordinates.
+  let u = (clip.x + 1.0) * 0.5;
+  let v = 1.0 - (clip.y + 1.0) * 0.5;
+  return VertexOutput(vec4f(clip, 0.0, 1.0), vec2f(u, v));
+}
+
+@group(0) @binding(0) var imageTex: texture_2d<f32>;
+@group(0) @binding(1) var attnTex: texture_2d<f32>;
+@group(0) @binding(2) var linearSampler: sampler;
+
+struct VisParams {
+  overlayAlpha: f32, // heatmap blend weight passed to mix() in fs
+  showOverlay: u32,  // 0 = draw the image only, otherwise blend the heatmap
+}
+@group(0) @binding(3) var<uniform> visParams: VisParams;
+
+// Piecewise-linear viridis-like colormap: 11 control colors spaced evenly
+// over t in [0, 1]; inputs outside that range are clamped.
+fn viridis(t: f32) -> vec3f {
+  var stops = array(
+    vec3f(0.267, 0.004, 0.329),
+    vec3f(0.282, 0.141, 0.458),
+    vec3f(0.254, 0.265, 0.530),
+    vec3f(0.207, 0.372, 0.553),
+    vec3f(0.164, 0.471, 0.558),
+    vec3f(0.128, 0.567, 0.551),
+    vec3f(0.135, 0.659, 0.518),
+    vec3f(0.267, 0.749, 0.441),
+    vec3f(0.478, 0.821, 0.318),
+    vec3f(0.741, 0.873, 0.150),
+    vec3f(0.993, 0.906, 0.144),
+  );
+
+  // Scale the clamped input to stop units; the integer part selects the
+  // segment and the fractional part is the blend weight within it.
+  let scaled = clamp(t, 0.0, 1.0) * 10.0;
+  let below = min(u32(floor(scaled)), 10u);
+  let above = min(below + 1u, 10u);
+  return mix(stops[below], stops[above], fract(scaled));
+}
+
+// Draws the source image, optionally blended with the attention heatmap.
+@fragment
+fn fs(input: VertexOutput) -> @location(0) vec4f {
+  var color = textureSample(imageTex, linearSampler, input.uv).rgb;
+
+  // Overlay enabled: look up the heatmap color and blend it over the image.
+  if (visParams.showOverlay != 0u) {
+    // The attention value is stored in the red channel.
+    let attention = textureSample(attnTex, linearSampler, input.uv).r;
+    color = mix(color, viridis(attention), visParams.overlayAlpha);
+  }
+
+  return vec4f(color, 1.0);
+}
diff --git a/sample/visionTransformer/visualize.ts b/sample/visionTransformer/visualize.ts
new file mode 100644
index 00000000..4e00a164
--- /dev/null
+++ b/sample/visionTransformer/visualize.ts
@@ -0,0 +1,220 @@
+// Attention map visualization
+// Extracts CLS token attention over spatial patches, renders as heatmap overlay
+
+import { DEIT_CONFIG } from './weights';
+import visualizeWGSL from './shaders/visualize.wgsl';
+
+const C = DEIT_CONFIG;
+
+// Renders the input image with a CLS-attention heatmap blended on top.
+// GPU resources, including the bind group, are created once in initialize().
+export class AttentionVisualizer {
+  private device: GPUDevice;
+  private context: GPUCanvasContext;
+  private pipeline!: GPURenderPipeline;
+  private sampler!: GPUSampler;
+  private imageTexture!: GPUTexture;
+  private attnTexture!: GPUTexture;
+  private visParamsBuffer!: GPUBuffer;
+  private bindGroup!: GPUBindGroup;
+  private presentationFormat: GPUTextureFormat;
+
+  private currentLayer = 0;
+  private currentHead = 0;
+  private overlayAlpha = 0.6;
+
+  constructor(
+    device: GPUDevice,
+    context: GPUCanvasContext,
+    presentationFormat: GPUTextureFormat
+  ) {
+    this.device = device;
+    this.context = context;
+    this.presentationFormat = presentationFormat;
+  }
+
+  initialize() {
+    const device = this.device;
+
+    const module = device.createShaderModule({ code: visualizeWGSL });
+
+    this.pipeline = device.createRenderPipeline({
+      layout: 'auto',
+      vertex: { module, entryPoint: 'vs' },
+      fragment: {
+        module,
+        entryPoint: 'fs',
+        targets: [{ format: this.presentationFormat }],
+      },
+      primitive: { topology: 'triangle-list' },
+    });
+
+    this.sampler = device.createSampler({
+      magFilter: 'linear',
+      minFilter: 'linear',
+    });
+
+    this.imageTexture = device.createTexture({
+      size: [C.imgSize, C.imgSize],
+      format: 'rgba8unorm',
+      usage:
+        GPUTextureUsage.TEXTURE_BINDING |
+        GPUTextureUsage.COPY_DST |
+        GPUTextureUsage.RENDER_ATTACHMENT,
+    });
+
+    // Attention map texture: 14x14 (one texel per patch), upsampled by the linear sampler
+    const gridSize = C.imgSize / C.patchSize; // 14
+    this.attnTexture = device.createTexture({
+      size: [gridSize, gridSize],
+      format: 'rgba8unorm',
+      usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST,
+    });
+
+    this.visParamsBuffer = device.createBuffer({
+      size: 8,
+      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
+    });
+
+    // None of the bound resources is ever replaced, so build the bind group once.
+    this.bindGroup = device.createBindGroup({
+      layout: this.pipeline.getBindGroupLayout(0),
+      entries: [
+        { binding: 0, resource: this.imageTexture.createView() },
+        { binding: 1, resource: this.attnTexture.createView() },
+        { binding: 2, resource: this.sampler },
+        { binding: 3, resource: { buffer: this.visParamsBuffer } },
+      ],
+    });
+
+    this.updateVisParams();
+  }
+
+  // Upload the source image as RGBA8
+  uploadImage(rgbaData: Uint8ClampedArray, width: number, height: number) {
+    this.device.queue.writeTexture(
+      { texture: this.imageTexture },
+      rgbaData,
+      { bytesPerRow: width * 4, rowsPerImage: height },
+      { width, height }
+    );
+  }
+
+  // Update attention map from inference results; throws RangeError if out of range
+  updateAttentionMap(attnWeights: Float32Array, layer: number, head: number) {
+    const gridSize = C.imgSize / C.patchSize; // 14
+    const N = C.numTokens; // 197
+
+    // CLS token (row 0) attention over spatial tokens (columns 1..196).
+    // Reject out-of-range layer/head instead of uploading a blank heatmap.
+    const offset = (layer * C.numHeads + head) * N * N;
+    if (!(offset >= 0 && offset + gridSize * gridSize < attnWeights.length)) {
+      throw new RangeError(`No attention weights for layer ${layer}, head ${head}`);
+    }
+    this.currentLayer = layer;
+    this.currentHead = head;
+
+    // Normalize to [0, 1] for visualization
+    const clsAttnRaw = new Float32Array(gridSize * gridSize);
+    let maxVal = 0;
+    for (let i = 0; i < gridSize * gridSize; i++) {
+      clsAttnRaw[i] = attnWeights[offset + i + 1]; // skip CLS-to-CLS attention
+      maxVal = Math.max(maxVal, clsAttnRaw[i]);
+    }
+    if (maxVal > 0) {
+      for (let i = 0; i < clsAttnRaw.length; i++) {
+        clsAttnRaw[i] /= maxVal;
+      }
+    }
+
+    // Pack into RGBA8: attention in R, opaque alpha; G and B stay zero-initialized
+    const rgba = new Uint8Array(gridSize * gridSize * 4);
+    for (let i = 0; i < gridSize * gridSize; i++) {
+      rgba[i * 4] = Math.round(clsAttnRaw[i] * 255);
+      rgba[i * 4 + 3] = 255;
+    }
+
+    this.device.queue.writeTexture(
+      { texture: this.attnTexture },
+      rgba,
+      { bytesPerRow: gridSize * 4, rowsPerImage: gridSize },
+      { width: gridSize, height: gridSize }
+    );
+  }
+
+  setOverlayAlpha(alpha: number) {
+    this.overlayAlpha = alpha;
+    this.updateVisParams();
+  }
+
+  private updateVisParams() {
+    const data = new ArrayBuffer(8);
+    const view = new DataView(data);
+    view.setFloat32(0, this.overlayAlpha, true);
+    view.setUint32(4, 1, true); // showOverlay always on
+    this.device.queue.writeBuffer(this.visParamsBuffer, 0, data);
+  }
+
+  render() {
+    const encoder = this.device.createCommandEncoder();
+    const textureView = this.context.getCurrentTexture().createView();
+
+    const pass = encoder.beginRenderPass({
+      colorAttachments: [
+        {
+          view: textureView,
+          loadOp: 'clear',
+          storeOp: 'store',
+          clearValue: { r: 0, g: 0, b: 0, a: 1 },
+        },
+      ],
+    });
+
+    pass.setPipeline(this.pipeline);
+    pass.setBindGroup(0, this.bindGroup);
+    pass.draw(3); // full-screen triangle
+    pass.end();
+
+    this.device.queue.submit([encoder.finish()]);
+  }
+}
+
+export async function loadImageNetLabels(url: string): Promise<string[]> {
+  const response = await fetch(url); // fetch the label file
+  if (!response.ok) {
+    throw new Error(`Failed to load labels: ${response.status}`);
+  }
+  return response.json(); // parsed JSON body; not validated here
+}
+
+// Get the top-K predictions from raw logits.
+// Probabilities are a softmax over all logits; results are ordered from most
+// to least probable.
+export function topK(
+  logits: Float32Array,
+  labels: string[],
+  k: number
+): Array<{ label: string; probability: number; index: number }> {
+  // Numerically stable softmax: shift by the largest logit before exp().
+  const peak = logits.reduce((best, v) => Math.max(best, v), -Infinity);
+  const probs = Float32Array.from(logits, (v) => Math.exp(v - peak));
+  const total = probs.reduce((acc, p) => acc + p, 0);
+  for (let c = 0; c < probs.length; c++) {
+    probs[c] /= total;
+  }
+
+  // Order every class index by descending probability.
+  const ranked: number[] = [];
+  for (let c = 0; c < probs.length; c++) {
+    ranked.push(c);
+  }
+  ranked.sort((a, b) => probs[b] - probs[a]);
+
+  // Keep the first k; classes without a label get a synthetic name.
+  const best = ranked.slice(0, k);
+  return best.map((classIdx) => ({
+    label: labels[classIdx] || `class_${classIdx}`,
+    probability: probs[classIdx],
+    index: classIdx,
+  }));
+}
diff --git a/sample/visionTransformer/weights.ts b/sample/visionTransformer/weights.ts
new file mode 100644
index 00000000..00550bac
--- /dev/null
+++ b/sample/visionTransformer/weights.ts
@@ -0,0 +1,143 @@
+// Weight loading for DeiT-Tiny
+// Binary format (v2):
+// magic (4 bytes): "DEIT"
+// version (u32): 2
+// numTensors (u32)
+// For each tensor:
+// nameLen (u32), name (UTF-8)
+// dtype (u32): 0=fp32, 1=int8
+// ndims (u32), shape (ndims * u32)
+// [if int8: scale (f32)]
+// dataLen (u32), [align to 4], data
+//
+// Int8 tensors are dequantized to fp32 during loading:
+// float_value = int8_value * scale
+
+export interface TensorData {
+  name: string; // tensor name as stored in the weight file
+  shape: number[]; // dimension sizes as stored in the weight file
+  data: Float32Array; // element values; int8 tensors are already dequantized
+}
+
+export interface ModelWeights {
+  tensors: Map<string, TensorData>; // keyed by tensor name
+}
+
+// Fetch and parse a DeiT weight file (format v2); int8 tensors are dequantized
+// to fp32. Throws an Error if the download fails or the file is malformed.
+export async function loadWeights(url: string): Promise<ModelWeights> {
+  const response = await fetch(url);
+  if (!response.ok) {
+    throw new Error(`Failed to load weights: ${response.status}`);
+  }
+  const buffer = await response.arrayBuffer();
+  const view = new DataView(buffer);
+  const decoder = new TextDecoder(); // reused for every tensor name
+  let offset = 0;
+
+  // Reads a little-endian u32 at the cursor and advances past it.
+  const readU32 = (): number => {
+    const value = view.getUint32(offset, true);
+    offset += 4;
+    return value;
+  };
+
+  const magicBytes = [0, 1, 2, 3].map((i) => view.getUint8(i));
+  const magic = String.fromCharCode(...magicBytes);
+  offset += 4;
+  if (magic !== 'DEIT') {
+    throw new Error(`Invalid weight file magic: ${magic}`);
+  }
+
+  const version = readU32();
+  if (version !== 2) {
+    throw new Error(`Unsupported weight file version: ${version}`);
+  }
+
+  const numTensors = readU32();
+  const tensors = new Map<string, TensorData>();
+
+  for (let t = 0; t < numTensors; t++) {
+    const nameLen = readU32();
+    const name = decoder.decode(new Uint8Array(buffer, offset, nameLen));
+    offset += nameLen;
+
+    const dtype = readU32(); // 0=fp32, 1=int8; anything else is rejected
+    if (dtype !== 0 && dtype !== 1) {
+      throw new Error(`Unsupported dtype ${dtype} for tensor "${name}"`);
+    }
+
+    const ndims = readU32();
+    const shape: number[] = [];
+    let numElements = 1;
+    for (let d = 0; d < ndims; d++) {
+      shape.push(readU32());
+      numElements *= shape[d];
+    }
+
+    let scale = 0; // only int8 tensors carry a dequantization scale
+    if (dtype === 1) {
+      scale = view.getFloat32(offset, true);
+      offset += 4;
+    }
+
+    const dataLen = readU32();
+    const alignedOffset = (offset + 3) & ~3; // payload is 4-byte aligned
+
+    // The payload must cover the declared shape and fit inside the file.
+    const expectedLen = numElements * (dtype === 1 ? 1 : 4);
+    if (dataLen < expectedLen || alignedOffset + dataLen > buffer.byteLength) {
+      throw new Error(`Truncated data for tensor "${name}"`);
+    }
+
+    let data: Float32Array;
+    if (dtype === 1) {
+      // Int8: dequantize to fp32
+      const int8Data = new Int8Array(buffer, alignedOffset, numElements);
+      data = Float32Array.from(int8Data, (q) => q * scale);
+    } else {
+      // FP32: copy the values out of the file buffer
+      const end = alignedOffset + expectedLen;
+      data = new Float32Array(buffer.slice(alignedOffset, end));
+    }
+    offset = alignedOffset + dataLen;
+
+    tensors.set(name, { name, shape, data });
+  }
+
+  return { tensors };
+}
+
+// Upload a tensor into a new GPU buffer (COPY_DST is always added to `usage`).
+export function createTensorBuffer(
+  device: GPUDevice,
+  tensor: TensorData,
+  usage: GPUBufferUsageFlags = GPUBufferUsage.STORAGE
+): GPUBuffer {
+  const { data } = tensor;
+  const flags = usage | GPUBufferUsage.COPY_DST;
+  // Mapped at creation so the values can be written directly before unmapping.
+  const gpuBuffer = device.createBuffer({ size: data.byteLength, usage: flags, mappedAtCreation: true });
+  new Float32Array(gpuBuffer.getMappedRange()).set(data);
+  gpuBuffer.unmap();
+  return gpuBuffer;
+}
+
+// DeiT-Tiny model configuration.
+// These values define the network architecture. Changing them requires
+// a matching weight file — the shapes of all weight tensors depend on
+// these dimensions.
+export const DEIT_CONFIG = {
+ imgSize: 224, // input image width and height, in pixels
+ patchSize: 16, // side length of one square patch, in pixels
+ numPatches: 196, // (224/16)^2
+ numTokens: 197, // numPatches + 1 (CLS token)
+ channels: 3, // color channels per pixel
+ dim: 192, // token embedding dimension
+ numHeads: 3, // attention heads per layer
+ headDim: 64, // dim / numHeads
+ mlpHiddenDim: 768, // dim * 4
+ numLayers: 12, // number of transformer blocks
+ numClasses: 1000, // output classes (one per ImageNet label)
+ scale: 1.0 / Math.sqrt(64), // 1/sqrt(headDim)
+} as const;
diff --git a/src/samples.ts b/src/samples.ts
index 1621b444..5e04c887 100644
--- a/src/samples.ts
+++ b/src/samples.ts
@@ -43,6 +43,7 @@ import timestampQuery from '../sample/timestampQuery/meta';
import transparentCanvas from '../sample/transparentCanvas/meta';
import twoCubes from '../sample/twoCubes/meta';
import videoUploading from '../sample/videoUploading/meta';
+import visionTransformer from '../sample/visionTransformer/meta';
import volumeRenderingTexture3D from '../sample/volumeRenderingTexture3D/meta';
import wireframe from '../sample/wireframe/meta';
import worker from '../sample/worker/meta';
@@ -115,6 +116,7 @@ export const pageCategories: PageCategory[] = [
computeBoids,
gameOfLife,
bitonicSort,
+ visionTransformer,
},
},
diff --git a/tools/convert_deit_weights.py b/tools/convert_deit_weights.py
new file mode 100644
index 00000000..0f62630c
--- /dev/null
+++ b/tools/convert_deit_weights.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+"""
+Convert DeiT-Tiny weights to flat binary format for WebGPU inference.
+
+Usage:
+ python tools/convert_deit_weights.py [--output public/assets/models/deit-tiny-int8.bin]
+
+Downloads facebook/deit-tiny-patch16-224 (Apache 2.0) from HuggingFace,
+quantizes weight matrices to int8 with per-tensor scale factors, and packs
+into a flat binary format consumable by the WebGPU sample.
+
+Format (v2):
+ magic: b"DEIT" (4 bytes)
+ version: u32 (2)
+ numTensors: u32
+ For each tensor:
+ nameLen: u32
+ name: bytes (UTF-8)
+ dtype: u32 (0=fp32, 1=int8)
+ ndims: u32
+ shape: ndims * u32
+ [if int8: scale: f32]
+ dataLen: u32
+ [padding to 4-byte alignment]
+ data: fp32[] or int8[]
+
+ Int8 tensors are dequantized during loading: value = int8 * scale.
+ Weight matrices (2D) are stored as int8; biases, norms, and embeddings
+ are stored as fp32.
+"""
+
+import argparse
+import struct
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+import timm
+
+
+def convert_qkv_weights(state_dict: dict, layer_idx: int) -> dict:
+    """Repack one block's fused QKV projection for the WGSL linear shader.
+
+    PyTorch stores ``qkv.weight`` as (3*dim, dim) in (out, in) layout, with the
+    Q, K and V projections stacked along the first axis. The shader slices
+    Q, K and V out of a single buffer by offset and expects each slice in
+    (inDim, outDim) layout, so each (dim, dim) portion is transposed on its
+    own and the three results are stored back to back:
+
+        [Q.T (dim, dim) | K.T (dim, dim) | V.T (dim, dim)]
+
+    Concatenating along the column axis instead (giving one (dim, 3*dim)
+    matrix) would interleave Q, K and V within every row, and a plain buffer
+    offset could no longer select a single projection.
+
+    Args:
+        state_dict: Model state dict containing
+            ``blocks.{layer_idx}.attn.qkv.weight`` and ``.bias``.
+        layer_idx: Index of the transformer block to convert.
+
+    Returns:
+        A dict with the same two keys: the repacked weight, shaped
+        (3*dim, dim), and the bias, unchanged.
+    """
+    prefix = f"blocks.{layer_idx}.attn.qkv"
+    qkv_weight = state_dict[f"{prefix}.weight"]  # (3*dim, dim)
+    qkv_bias = state_dict[f"{prefix}.bias"]  # (3*dim,)
+
+    dim = qkv_weight.shape[1]
+
+    # Transpose each projection separately, then stack the three contiguous
+    # (dim, dim) blocks along the first axis.
+    parts = [qkv_weight[i * dim:(i + 1) * dim, :].T for i in range(3)]
+    repacked = torch.cat(parts, dim=0).contiguous()  # (3*dim, dim)
+
+    return {
+        f"{prefix}.weight": repacked,
+        f"{prefix}.bias": qkv_bias,
+    }
+
+
+def reshape_patch_embed(state_dict: dict) -> dict:
+    """Turn the patch-embedding Conv2d kernel into a plain projection matrix.
+
+    ``patch_embed.proj`` is a Conv2d whose kernel size equals its stride, so
+    it acts as a linear map applied to each flattened patch. Its weight has
+    shape (out, C, H, W); the WGSL shader needs (H*W*C, out).
+
+    The shader walks a patch with the channel index varying fastest:
+        c = i % channels; pixelInPatch = i / channels;
+        py = pixelInPatch / patchSize; px = pixelInPatch % patchSize;
+    so the flattened input axis must be ordered H, W, C rather than
+    PyTorch's C, H, W.
+
+    Returns the reshaped weight and the unchanged bias under their original keys.
+    """
+    kernel = state_dict["patch_embed.proj.weight"]  # (out, C, H, W)
+    out_dim = kernel.shape[0]
+
+    # Move `out` to the last axis and order the rest as H, W, C; flattening
+    # the leading three axes then yields the (H*W*C, out) matrix directly.
+    flat_weight = kernel.permute(2, 3, 1, 0).reshape(-1, out_dim).contiguous()
+
+    return {
+        "patch_embed.proj.weight": flat_weight,
+        "patch_embed.proj.bias": state_dict["patch_embed.proj.bias"],
+    }
+
+
+def export_imagenet_labels(output_path: Path):
+    """Write ``imagenet-labels.json`` next to *output_path*.
+
+    Each of the 1000 entries is the first comma-separated name of the class
+    description reported by timm for that class index.
+    """
+    from timm.data import ImageNetInfo  # timm ships the ImageNet label mapping
+    info = ImageNetInfo()
+    labels = [
+        info.index_to_description(idx).split(",")[0].strip() for idx in range(1000)
+    ]
+    label_path = output_path.parent / "imagenet-labels.json"
+    with open(label_path, "w") as f:
+        json.dump(labels, f)
+    print(f"Wrote {len(labels)} labels to {label_path}")
+
+
+def write_tensor(f, name: str, tensor: torch.Tensor, quantize: bool = True):
+ """Write a single tensor in our flat binary format.
+
+ If quantize=True and tensor has >1 element, stores as int8 with a per-tensor
+ scale factor. Biases and small tensors (e.g., CLS token) are stored as fp32.
+ The JS loader dequantizes: float_value = int8_value * scale.
+ """
+ data = tensor.float().contiguous().numpy()
+ name_bytes = name.encode("utf-8")
+
+ # Quantize weight matrices (2D+) to int8; keep biases, norms, embeddings as fp32
+ use_int8 = quantize and data.ndim >= 2
+ if use_int8:
+ abs_max = np.abs(data).max()
+ scale = abs_max / 127.0 if abs_max > 0 else 1.0
+ quantized = np.clip(np.round(data / scale), -128, 127).astype(np.int8)
+ dtype_flag = 1 # int8
+ else:
+ dtype_flag = 0 # fp32
+
+ # nameLen + name
+ f.write(struct.pack("