ReduceOps: Dimension Reduction¶
ReduceOps aggregate values along one or more dimensions, returning a smaller tensor. Tinygrad has 2 primitive ReduceOps — SUM and MAX — from which all others derive.
The Two Primitives¶
SUM¶
from tinygrad import Tensor
x = Tensor([[1, 2, 3],
[4, 5, 6]])
x.sum() # 21 — reduce all
x.sum(axis=0) # [5, 7, 9] — reduce rows
x.sum(axis=1) # [6, 15] — reduce columns
x.sum(axis=1, keepdim=True) # [[6], [15]] — keep dims for broadcasting
axis=0 (sum columns): axis=1 (sum rows):
[[1,2,3], [[1+2+3] = [6]
[4,5,6]] [4+5+6] = [15]]
↓↓↓
[5,7,9]
MAX¶
Derived Reductions¶
| Op | Built From | Example |
|---|---|---|
x.mean() |
x.sum() / x.numel() |
[1,2,3,4] → 2.5 |
x.min() |
-(-x).max() |
[1,5,3,2] → 1 |
x.var() |
((x - mean)**2).mean() |
variance |
x.std() |
x.var().sqrt() |
standard deviation |
x.prod() |
x.log().sum().exp() |
product |
Axes and keepdim¶
x = Tensor.randn(2, 3, 4)
x.sum(axis=0) # (3, 4)
x.sum(axis=1) # (2, 4)
x.sum(axis=2) # (2, 3)
x.sum() # scalar
x.sum(axis=1, keepdim=True) # (2, 1, 4) ← useful for broadcasting
keepdim=True preserves the reduced dimension as size 1, making subsequent broadcasting straightforward.
Common Patterns¶
Pooling¶
# Average pooling (2D)
def avg_pool2d(x, k=2):
b, c, h, w = x.shape
return x.reshape(b, c, h//k, k, w//k, k).mean(axis=(3, 5))
# Max pooling (2D)
def max_pool2d(x, k=2):
b, c, h, w = x.shape
return x.reshape(b, c, h//k, k, w//k, k).max(axis=(3, 5))
# Global average pooling
def gap(x): # (B,C,H,W) → (B,C)
return x.mean(axis=(2, 3))
Normalization¶
# Layer norm
def layer_norm(x, eps=1e-5):
mean = x.mean(axis=-1, keepdim=True)
var = ((x - mean)**2).mean(axis=-1, keepdim=True)
return (x - mean) / (var + eps).sqrt()
# Batch norm (simplified)
def batch_norm(x, eps=1e-5):
mean = x.mean(axis=(0, 2, 3), keepdim=True)
var = x.var(axis=(0, 2, 3), keepdim=True)
return (x - mean) / (var + eps).sqrt()
# RMS norm
def rms_norm(x, eps=1e-6):
rms = (x * x).mean(axis=-1, keepdim=True).add(eps).sqrt()
return x / rms
Attention¶
# Numerically stable softmax
def softmax(x, axis=-1):
max_x = x.max(axis=axis, keepdim=True)
exp_x = (x - max_x).exp()
return exp_x / exp_x.sum(axis=axis, keepdim=True)
# Log-sum-exp
def logsumexp(x, axis=-1):
max_x = x.max(axis=axis, keepdim=True)
return max_x.squeeze(axis) + (x - max_x).exp().sum(axis=axis).log()
Loss Functions¶
mse = ((pred - target)**2).mean()
mae = (pred - target).abs().mean()
def cross_entropy(logits, targets):
log_probs = logits.log_softmax(axis=-1)
return -(targets * log_probs).sum(axis=-1).mean()
Numerical Stability¶
# Unstable softmax (overflow for large x)
def softmax_bad(x):
return x.exp() / x.exp().sum(axis=-1, keepdim=True)
# Stable (subtract max first)
def softmax_stable(x, axis=-1):
x = x - x.max(axis=axis, keepdim=True)
exp_x = x.exp()
return exp_x / exp_x.sum(axis=axis, keepdim=True)
# Always add eps before log/sqrt
safe_log = (x + 1e-8).log()
safe_sqrt = (x.maximum(0) + 1e-8).sqrt()
Performance Notes¶
- ReduceOps break fusion boundaries — elementwise ops before a reduce fuse together, and elementwise ops after a reduce start a new kernel
- Large reductions across many elements are slower; prefer reducing the last (innermost) axis for better cache locality
keepdim=Trueavoids a reshape and is slightly more efficient when you need to broadcast back
Key Takeaways¶
- 2 primitives: SUM and MAX
- All other reductions (MEAN, MIN, VAR, STD, PROD) are derived
axisselects which dimension to collapse;keepdimpreserves shape for broadcasting- Reductions break kernel fusion — each reduce produces a separate kernel
- Numerical stability requires subtracting max before exponentiation in softmax/logsumexp