Sharding Strategies: The Art of Distributed Matrix Multiplication
After understanding TPU architecture, I tackled Part 3: Sharded Matrices from the JAX Scaling Book.
This section is all about the 4 core communication primitives that enable distributed matrix multiplication.
🎯 The Big Idea
When you split a matrix across 64 TPUs and multiply it, you need to:
- Move data between chips (communication)
- Do local math on each chip (computation)
The art is choosing which data to move, when, and how.
🔧 The 4 Core Primitives
| Primitive | What it does | Cost |
|---|---|---|
| AllGather | [Aₓ,B] → [A,B] | V / W_ici |
| ReduceScatter | [A,B]{Uₓ} → [Aₓ,B] | V / W_ici |
| AllReduce | [A,B]{Uₓ} → [A,B] | 2V / W_ici |
| AllToAll | [A,Bₓ] → [Aₓ,B] | V / (4W_ici) |
The Golden Rule: cost depends only on the full array size V (in bytes) and the ICI bandwidth W_ici, NOT on the number of devices!
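The table translates directly into a roofline-style cost model. Here is a minimal sketch in plain Python; the function names and the example sizes are my own, and the bandwidth figure is an illustrative placeholder, not a real TPU spec:

```python
# Roofline-style cost model for the four collectives.
# V is the FULL array size in bytes; W_ici is ICI bandwidth in bytes/s.
# Note that the number of devices never appears in any formula.

def allgather_time(V, W_ici):
    return V / W_ici

def reducescatter_time(V, W_ici):
    return V / W_ici

def allreduce_time(V, W_ici):
    # AllReduce = ReduceScatter followed by AllGather, hence the 2x.
    return 2 * V / W_ici

def alltoall_time(V, W_ici):
    # Bidirectional ring: each shard only travels ~1/4 of the way around.
    return V / (4 * W_ici)

# Example: a 2 GB array over a hypothetical 9e10 byte/s link.
V, W = 2e9, 9e10
print(f"AllGather: {allgather_time(V, W)*1e3:.1f} ms")
print(f"AllReduce: {allreduce_time(V, W)*1e3:.1f} ms")
print(f"AllToAll:  {alltoall_time(V, W)*1e3:.1f} ms")
```

The useful part of writing it out is seeing that AllReduce is exactly two AllGathers and AllToAll exactly a quarter of one, for any array size.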
📝 My Solutions to Key Problems
Problem 4: Strategy Comparison
Question: Multiply X[B,D] · Y[Dₓ,F]. Which is faster?
- AllGather Y, then multiply
- Multiply locally, then AllReduce
My Analysis:
- Strategy 1 (AllGather Y, then multiply):
  - T_comm = 2DF/W_ici, T_comp = 2BDF/W_flop
  - Compute-bound when B > W_flop/W_ici, so best for large batches
- Strategy 2 (multiply locally, then AllReduce):
  - T_comm = 4BF/W_ici, T_comp = 2BDF/(X·W_flop), where X is the number of devices along the sharded axis
  - Compute-bound when D > 2X·W_flop/W_ici, so best for large models
Key Insight: Large batches favor Strategy 1. Large models favor Strategy 2.
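The two compute-bound conditions can be encoded as simple predicates. A sketch in plain Python; the W_flop and W_ici values are illustrative placeholders (not real v5e specs), and `n_dev` stands for the X in the analysis:

```python
# Rules of thumb for X[B,D] @ Y[D_x,F] as Python predicates.
# W_flop (FLOP/s) and W_ici (bytes/s) are made-up example values;
# n_dev is the number of devices along the sharded axis.

def s1_compute_bound(B, W_flop, W_ici):
    """Strategy 1 (AllGather Y, then multiply): 2BDF/W_flop > 2DF/W_ici."""
    return B > W_flop / W_ici

def s2_compute_bound(D, n_dev, W_flop, W_ici):
    """Strategy 2 (multiply, then AllReduce): 2BDF/(n_dev*W_flop) > 4BF/W_ici."""
    return D > 2 * n_dev * W_flop / W_ici

W_flop, W_ici = 2e14, 9e10   # assumed example hardware numbers
print(s1_compute_bound(4096, W_flop, W_ici))      # large batch: True
print(s2_compute_bound(65536, 8, W_flop, W_ici))  # large model dim: True
```

With these placeholder numbers the Strategy 1 threshold is B ≈ 2200 and the Strategy 2 threshold is D ≈ 36K on 8 devices, which matches the batch-vs-model intuition.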
Problem 6: Common Sharding Patterns
Question: Analyze 3 real-world sharding schemes on 4×4 TPU v5e.
Pattern 1: A[Iₓ,Jᵧ] · B[Jᵧ,K] → C[Iₓ,K]
- AllGatherᵧ twice → T_comm = 2JK/(Y·W_ici)
- Compute: T_comp = 2IJK/(XY·W_flop)
Pattern 2: A[Iₓ,J] · B[Jₓ,Kᵧ] → C[Iₓ,Kᵧ] (Training)
- AllGatherₓ on J → T_comm = JK/(Y·W_ici)
- This is Data Parallel + Tensor Parallel + ZeRO
Pattern 3: A[Iₓ,J] · B[J,Kᵧ] → C[Iₓ,Kᵧ] (Inference)
- No communication!
- This is pure Tensor Parallel + Data Parallel
Key Insight: Inference can be communication-free with proper sharding.
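Pattern 3's "no communication" claim is easy to verify on toy matrices: when A is sharded over rows, B over columns, and the contracting dimension J is replicated, every output block is computable from purely local shards. A minimal plain-Python sketch (toy sizes, hypothetical helper names):

```python
# Each "device" (x, y) in an X*Y mesh holds a row shard of A and a column
# shard of B, and computes its output block of C with no cross-device data.

def matmul(A, B):
    """Reference dense matmul on lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def sharded_matmul(A, B, X, Y):
    """Simulate A[I_x,J] @ B[J,K_y] -> C[I_x,K_y]: zero communication."""
    I, K = len(A), len(B[0])
    ib, kb = I // X, K // Y
    C = [[0] * K for _ in range(I)]
    for x in range(X):
        for y in range(Y):
            A_loc = A[x*ib:(x+1)*ib]                     # local row shard
            B_loc = [row[y*kb:(y+1)*kb] for row in B]    # local column shard
            C_loc = matmul(A_loc, B_loc)                 # purely local math
            for i in range(ib):
                for k in range(kb):
                    C[x*ib + i][y*kb + k] = C_loc[i][k]
    return C

A = [[i + j for j in range(6)] for i in range(4)]       # 4x6, rows sharded X=2
B = [[i * j + 1 for j in range(4)] for i in range(6)]   # 6x4, cols sharded Y=2
assert sharded_matmul(A, B, 2, 2) == matmul(A, B)
```

The price, of course, is that each device must hold all of J, which is exactly why this layout suits inference (weights fit after the F-axis is handled elsewhere) rather than training.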
Problem 7: Transformer Block Memory Constraints
Question: Shard In[128,8K] · Wᵢₙ[8K,32K] · Wₒᵤₜ[32K,8K] on 2×2 v5e with 300MB/TPU.
My Strategy:
- Problem: each weight matrix is 8K × 32K in bf16, i.e. 536 MB, far over the 300 MB budget
- Solution: shard the weights as Wᵢₙ[D,Fₓᵧ] and Wₒᵤₜ[Fₓᵧ,D]
- Trick: fuse Wᵢₙ · Wₒᵤₜ first, then ReduceScatter:
  - Wᵢₙ · Wₒᵤₜ → W[D,D]{Uₓᵧ} (local multiply, partial sums)
  - ReduceScatterₓᵧ → W[D,Dₓᵧ]
  - AllGather W, then In · W → Out[B,Dₓᵧ]
Key Insight: Intermediate fusion halves communication vs. sequential matmuls.
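The memory arithmetic behind the "too big!" claim is worth writing out. A sketch in plain Python, assuming bf16 (2 bytes/element) and taking the 300 MB/TPU budget from the problem statement:

```python
# Memory arithmetic for the Problem 7 setup on a 2x2 mesh (4 devices).
D, F, n_dev, BYTES = 8192, 32768, 4, 2   # 8K x 32K weights, bf16

full_weight_mb = D * F * BYTES / 1e6       # one unsharded weight matrix
sharded_weight_mb = full_weight_mb / n_dev # W_in[D, F_xy] split 4 ways
both_weights_mb = 2 * sharded_weight_mb    # W_in and W_out together

print(f"unsharded weight: {full_weight_mb:.0f} MB")   # ~537 MB: over budget
print(f"sharded weight:   {sharded_weight_mb:.0f} MB")
print(f"both, sharded:    {both_weights_mb:.0f} MB")  # fits under 300 MB
```

So a single replicated weight blows the budget, but sharding Fₓᵧ across the four chips leaves both matrices at roughly 268 MB combined, with room left over for activations.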
Problem 10: Why is AllToAll 4× Faster?
Question: Explain the factor-of-4 speedup for AllToAll vs. AllGather in bidirectional rings.
My Analysis:
Unidirectional ring:
- AllGather: every device sends its full shard all the way around the ring → N² scalars total
- AllToAll: each device sends only partial shards → N²/2 scalars
- Ratio: 2×
Bidirectional ring (D devices):
- AllGather: D/2 hops, 2N²/D bytes per hop → N² total
- AllToAll: D/4 hops, N²/D bytes per hop → N²/4 total
- Ratio: 4×
Key Insight: Bidirectional AllToAll benefits from halved hops AND halved bytes/hop.
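The D/4 hop count can be checked by brute force: on a bidirectional ring each AllToAll message takes the shorter direction, so the mean source-to-destination distance approaches D/4. A small plain-Python sketch (my own helper, not from the book):

```python
# Mean shortest-path hop distance on a bidirectional ring of n_dev devices.
# AllToAll messages each travel this mean distance (~n_dev/4), while in an
# AllGather every shard must traverse ~n_dev/2 hops to reach all devices.

def mean_hops(n_dev):
    total = pairs = 0
    for src in range(n_dev):
        for dst in range(n_dev):
            if src == dst:
                continue
            # Shorter of the clockwise and counter-clockwise routes.
            hops = min((dst - src) % n_dev, (src - dst) % n_dev)
            total += hops
            pairs += 1
    return total / pairs

print(mean_hops(64))   # close to 64 / 4 = 16
```

For a 64-device ring the mean comes out just over 16 hops, i.e. D/4, which is half the D/2 an AllGather shard must travel; combined with the halved bytes per hop, that is the factor of 4.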
🧠 What I Learned
- Communication ≠ scaling bottleneck: AllGather on 4 chips takes the same time as on 64 chips (bandwidth-bound).
- Strategy beats scale: choosing AllGather-then-multiply vs. multiply-then-AllReduce can make a 2× difference in performance.
- Inference is “free”: with A[Iₓ,J] · B[J,Kᵧ], there’s zero cross-chip communication.
At Samsung, when we shard models, we’re not just dividing work — we’re choosing a communication strategy that determines whether we hit 10ms or 200ms latency.
📂 Code & Notes
All my notes and implementations: Model_scaling_jax
Original Q&A: Google Doc
Next up: Part 4 - Transformer Math
Sharding isn’t about splitting — it’s about orchestrating.