TPU Architecture: Understanding the Bandwidth Hierarchy
After mastering roofline analysis, I dove into Part 2: TPU Architecture from the JAX Scaling Book.
This section is all about understanding the bandwidth hierarchy that governs TPU performance.
🏗️ The TPU Stack
A TPU is fundamentally simple:
- MXU (Matrix Multiply Unit) — Systolic array doing `bf16[8,128] @ bf16[128,128]` every 8 cycles
- VMEM (Vector Memory) — 128 MiB on-chip scratchpad
- HBM (High-Bandwidth Memory) — 16-96 GB depending on generation
- ICI (Inter-Chip Interconnect) — Nearest-neighbor torus topology
- DCN (Datacenter Network) — Host-to-host networking
The Golden Rule: Communication speed decreases ~10x at each level.
📊 The Bandwidth Hierarchy
| Link | Bandwidth | Use Case |
|---|---|---|
| HBM ↔ VMEM | 1-3 TB/s | Loading weights/activations |
| ICI | 90-180 GB/s per link | Cross-chip sharding |
| PCIe | 16-32 GB/s | Host ↔ TPU |
| DCN | 6-12 GB/s | Cross-pod communication |
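The ~10x dropoff per level is easy to see with a quick sketch. The bandwidth numbers below are mid-range picks from the table above, not exact per-generation specs:

```python
# Approximate bandwidth per link level (bytes/s), mid-range values
# from the table above -- illustrative, not exact hardware specs.
BANDWIDTH = {
    "HBM<->VMEM": 2.0e12,   # 1-3 TB/s
    "ICI":        1.0e11,   # 90-180 GB/s per link
    "PCIe":       2.0e10,   # 16-32 GB/s
    "DCN":        1.0e10,   # 6-12 GB/s
}

def transfer_time(num_bytes, link):
    """Lower-bound time to move num_bytes across a single link."""
    return num_bytes / BANDWIDTH[link]

# Moving the same 1 GB across each level shows the hierarchy:
for link in BANDWIDTH:
    print(f"{link:12s} {transfer_time(1e9, link) * 1e3:8.2f} ms")
```

The same 1 GB that crosses HBM in half a millisecond takes ~100 ms over DCN — this ratio is the whole game.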
📝 My Solutions to the Problems
Problem 1: Bounding LLM Latency
Question: How long to load a 200B parameter model (bf16) across 32 TPU v4p?
My Answer:
- Size: `200e9 params × 2 bytes = 4e11 bytes`
- Total HBM BW: `32 × 1.2e12 = 3.84e13 bytes/s`
- Time: `4e11 / 3.84e13` ≈ **10 ms**
Insight: Even massive models load in milliseconds when parallelized across enough chips.
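A quick sanity check of the arithmetic (the 1.2e12 bytes/s per-chip HBM bandwidth is the figure assumed in the solution above):

```python
# Problem 1: time to load 200B bf16 params across 32 TPU v4p chips,
# assuming ~1.2e12 bytes/s HBM bandwidth per chip (as in the solution).
params = 200e9
bytes_total = params * 2              # bf16 = 2 bytes/param -> 4e11 bytes
hbm_bw_per_chip = 1.2e12
n_chips = 32
t = bytes_total / (hbm_bw_per_chip * n_chips)
print(f"{t * 1e3:.1f} ms")            # ~10.4 ms
```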
Problem 2: TPU Pod Specs
Question: Calculate total FLOPs, memory, and hosts for full v5e and v5p pods.
My Answers:
TPU v5e (16×16 = 256 chips):
- FLOPs/s: `5.04e16` (50 PetaFLOPs)
- HBM: 4 TB
- Hosts: 32 (8 chips per host)
TPU v5p (16×20×28 = 8960 chips):
- FLOPs/s: `4.11e18` (4 ExaFLOPs!)
- HBM: 860 TB
- Hosts: 2240
Insight: A single v5p pod is one of the most powerful supercomputers in the world. And Google has many.
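The totals can be rebuilt from per-chip specs. The per-chip FLOPs/HBM figures and chips-per-host counts below are my working assumptions (backed out from the answers above), not official datasheet values:

```python
# Recompute pod totals from assumed per-chip specs:
# v5e: ~1.97e14 bf16 FLOPs/s, 16 GB HBM, 8 chips per host (4x2)
# v5p: ~4.59e14 bf16 FLOPs/s, 96 GB HBM, 4 chips per host (2x2x1)
def pod_specs(chips, flops_per_chip, hbm_gb_per_chip, chips_per_host):
    return {
        "flops": chips * flops_per_chip,
        "hbm_tb": chips * hbm_gb_per_chip / 1e3,   # decimal TB
        "hosts": chips // chips_per_host,
    }

v5e = pod_specs(16 * 16, 1.97e14, 16, 8)
v5p = pod_specs(16 * 20 * 28, 4.59e14, 96, 4)
print(v5e)  # ~5.04e16 FLOPs/s, ~4 TB, 32 hosts
print(v5p)  # ~4.11e18 FLOPs/s, ~860 TB, 2240 hosts
```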
Problem 3: PCIe Operational Intensity
Question: What batch size needed to saturate a TPU v6e over PCIe?
Setup: bf16[B,D] @ bf16[D,4D] with PCIe BW = 1.5e10 bytes/s
My Answer:
- Critical arithmetic intensity: `9.20e14 / 1.5e10 ≈ 6.13e4`
- Threshold: B > ~61,000 tokens
Insight: PCIe is ~100x slower than HBM. If your weights are on the host, you need massive batches to avoid starvation.
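The threshold falls straight out of the intensity argument. This sketch assumes the v6e figure of 9.2e14 bf16 FLOPs/s and the stated 1.5e10 bytes/s PCIe bandwidth:

```python
# Problem 3: batch size needed to keep a TPU v6e compute-bound when
# weights stream over PCIe. Assumed: 9.2e14 bf16 FLOPs/s, 1.5e10 B/s PCIe.
tpu_flops = 9.2e14
pcie_bw = 1.5e10
critical_intensity = tpu_flops / pcie_bw   # ~6.13e4 FLOPs/byte

# For bf16[B,D] @ bf16[D,4D]: FLOPs = 8*B*D^2, PCIe traffic ~= 8*D^2
# bytes (the weights), so arithmetic intensity over PCIe ~= B.
print(f"compute-bound once B > {critical_intensity:.0f}")
```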
Problem 4: General Matmul Latency
Question: Time for int8[16K,4K] @ int8[B,4K] on TPU v5e?
My Answer:
- Compute-bound threshold: B > ~250
- Time (B > 250): `T = B / 2.93e6` seconds
From VMEM: If VMEM BW is ~20x HBM → threshold drops to B > 12
Insight: VMEM’s higher bandwidth lets you be compute-bound at 20x smaller batches.
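Here is the roofline for this matmul as code. The per-chip v5e numbers (3.94e14 int8 FLOPs/s, 8.1e11 bytes/s HBM) are assumptions I'm carrying over from the spec tables, and the simple max(compute, memory) model ignores overlap inefficiencies:

```python
# Problem 4: int8[16K,4K] @ int8[B,4K] on a TPU v5e.
# Assumed per-chip specs: 3.94e14 int8 FLOPs/s, 8.1e11 bytes/s HBM BW.
N, K = 16384, 4096
int8_flops = 3.94e14
hbm_bw = 8.1e11

def matmul_time(B):
    """Roofline estimate: bounded by the slower of compute and HBM."""
    flops = 2 * N * K * B
    mem_bytes = N * K + K * B + N * B    # int8: 1 byte per element
    return max(flops / int8_flops, mem_bytes / hbm_bw)

# Below B ~ 250 the memory term wins; above it, T ~ B / 2.93e6 s.
print(f"B=64:  {matmul_time(64) * 1e3:.3f} ms")
print(f"B=512: {matmul_time(512) * 1e3:.3f} ms")
```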
Problem 5: ICI Bandwidth & Torus Topology
Question: Send bf16[8,128,8K] from TPU{0,0} → TPU{3,3} in a 4×4 slice.
My Answer:
- First byte: 6 hops × 1μs = 6 μs
- Full transfer:
  - Array size: `8 × 128 × 8192 × 2 bytes ≈ 1.7e7 bytes`
  - Can split across 2 disjoint paths (torus!)
  - Time: `1.7e7 / (2 × 4.5e10)` ≈ **189 μs**
Insight: The torus topology effectively doubles bandwidth by enabling multi-path routing.
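Checking the numbers (4.5e10 bytes/s per ICI link and 1 μs per-hop latency, as assumed above). The exact array size of 1.68e7 bytes gives ~186 μs; the 189 μs figure comes from rounding the size up to 1.7e7:

```python
# Problem 5: ship bf16[8,128,8192] from TPU{0,0} to TPU{3,3} in a 4x4
# slice. Assumed: 4.5e10 bytes/s per ICI link, 1 us per-hop latency.
array_bytes = 8 * 128 * 8192 * 2      # bf16 -> 16,777,216 bytes
ici_bw = 4.5e10
hops, hop_latency = 6, 1e-6

first_byte = hops * hop_latency                # latency to first byte
full = array_bytes / (2 * ici_bw)              # 2 disjoint paths in parallel
print(f"first byte: {first_byte * 1e6:.0f} us, full: {full * 1e6:.0f} us")
```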
Problem 6: The Full Pipeline (Hard)
Question: Copy sharded int8[128K,128K] from host DRAM to one TPU, multiply by bf16[8,128K].
My Breakdown:
- DCN transfer (host → chips): 66 ms
- ICI aggregation (chips → TPU{0,0}): 167 ms ⚠️
- HBM → MXU: 19 ms
- Compute: 1.3 ms
Bottleneck: ICI at 167 ms — even for a “small” 15 GB transfer.
Insight: When aggregating data across many chips, ICI bandwidth dominates — orders of magnitude more than compute time.
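With the stage times from the breakdown, the latency picture is just a max (if stages overlap perfectly) versus a sum (if they run back to back) — a rough model, but it shows why ICI dominates:

```python
# Problem 6: stage times from the breakdown above (seconds).
stages = {"DCN": 66e-3, "ICI": 167e-3, "HBM->MXU": 19e-3, "compute": 1.3e-3}

bottleneck = max(stages, key=stages.get)
print(f"bottleneck: {bottleneck} at {stages[bottleneck] * 1e3:.0f} ms")

# Upper bound if nothing overlaps; lower bound is just the ICI stage.
sequential = sum(stages.values())
print(f"no-overlap total: {sequential * 1e3:.1f} ms")   # ~253 ms
```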
🧠 What I Learned
- VMEM is underrated: 20x higher bandwidth means 20x smaller batch size thresholds.
- ICI is the real bottleneck: Even “fast” nearest-neighbor links become slow when aggregating across 16+ chips.
- Topology matters: The torus structure isn’t just elegant — it doubles your effective bandwidth.
At Samsung, when we shard models across TPUs, we’re not just dividing work — we’re navigating a complex hierarchy of bandwidths. The difference between 10ms and 200ms latency often comes down to which link the data crosses.
📂 Code & Notes
All my notes and implementations: Model_scaling_jax
Original Q&A: Google Doc
Next up: Part 3 - Sharding Strategies
Understanding the hierarchy is half the battle.