CS336 Lecture 6: GPU 性能、Kernel 手写与 Nsight 分析
CS336 Lecture 6 学习讲义:GPU 性能、Kernel 手写与 Nsight 分析
副标题:从 PyTorch 算子到底层 CUDA / Triton Kernel,再到 NVIDIA Nsight 性能分析
适用对象:希望入门 AI Infra、能独立手写基础 CUDA/Triton kernel 的同学
学完这节课你能做什么
- 看懂 GPU 执行模型:SM、thread/block/grid、shared memory、arithmetic intensity。
- 会写 benchmark / profile:用
torch.profiler和time.time()做可复现的测速。 - 能手写 CUDA kernel:从 C++ 源码到
torch.utils.cpp_extension.load_inline加载。 - 能手写 Triton kernel:element-wise(GeLU)和 row-wise reduction(softmax)。
- 会用
torch.compile自动融合算子。 - 会用 Nsight Systems (
nsys) 和 Nsight Compute (ncu) 分析性能瓶颈。
运行环境
本 Notebook 使用 Note/lec6/ 下的独立 uv 环境,已安装:
- Python 3.11
- PyTorch nightly (cu128,支持 RTX 5060 Blackwell)
- Triton 3.7
- JupyterLab、matplotlib、numpy、nvtx
- 系统自带
nsys/ncu
如果当前没有 GPU,所有 CUDA 相关 cell 会自动跳过,Triton 相关代码需要在 GPU 环境下运行。
Part 0:环境检查
先确认 PyTorch、CUDA、Triton、Nsight 工具是否就绪。
import sys
import subprocess
import torch
import triton
print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA version: {torch.version.cuda}")
props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name}")
print(f"SM count: {props.multi_processor_count}")
print(f"Total memory: {props.total_memory / 1e9:.2f} GB")
print(f"Compute capability: {props.major}.{props.minor}")
print(f"Triton: {triton.__version__}")
for cmd, label in [("nsys", "Nsight Systems"), ("ncu", "Nsight Compute")]:
try:
out = subprocess.run([cmd, "--version"], capture_output=True, text=True, check=True).stdout.strip()
print(f"{label}: {out.split(chr(10))[0]}")
except Exception as e:
print(f"{label}: not available ({e})")
Part 1:GPU 硬件与执行模型
1.1 硬件层次
| 层级 | A100 示例 | 作用 |
|---|---|---|
| Streaming Multiprocessors (SMs) | 108 个 | 实际做计算的单位 |
| DRAM (HBM) | 80 GB | 大容量显存,带宽高但延迟高 |
| L2 Cache | 40 MB | 全局共享,比 DRAM 快 |
| L1 Cache / Shared Memory | 192 KB / SM | 每个 SM 内,可被同 block 线程共享 |
1.2 执行层次
GPU 编程的核心抽象是 “对大量索引并行执行同一个函数 f(i)”。
- Thread:执行
f(i)的最小单元。 - Thread Block(CTA):一组线程,运行在同一个 SM 上,内部可共享 shared memory 并同步。
- Grid:所有 thread block 的集合。
1.3 关键概念
- Wave quantization:如果 grid 里的 block 数不能整除 SM 数,最后一 wave 会空闲部分 SM。
- Arithmetic intensity:
FLOPs / byte,越高越偏向 compute-bound,越低越偏向 memory-bound。 - 一般规律:矩阵乘法是 compute-bound,大部分其他操作(激活、归一化、softmax)是 memory-bound。
def print_gpu_specs():
if not torch.cuda.is_available():
print("No CUDA available")
return
num = torch.cuda.device_count()
print(f"Devices: {num}")
for i in range(num):
p = torch.cuda.get_device_properties(i)
print(f" [{i}] {p.name}")
print(f" SMs: {p.multi_processor_count}, Memory: {p.total_memory/1e9:.2f} GB")
print(f" Compute capability: {p.major}.{p.minor}")
print(f" L2 cache size: {p.L2_cache_size / 1024**2:.1f} MB")
print_gpu_specs()
1.4 思考题
- 为什么矩阵乘法是 compute-bound,而 GeLU 是 memory-bound?
- 对于一个 element-wise 操作,增加 block 数是否能无限加速?瓶颈在哪里?
Part 2:Benchmarking —— 先会测速,再谈优化
优化的第一步永远是 测量。没有测量,所有优化都是盲猜。
一个好的 benchmark 函数应该:
- Warmup:先跑几次,排除编译、缓存冷启动影响。
- Synchronize:CUDA 是异步的,必须
torch.cuda.synchronize()才能拿到真实 wall-clock 时间。 - 多次试验:取平均,减少噪声。
from src.helpers import benchmark
# 简单示例:sleep 50ms
import time
benchmark("sleep", lambda: time.sleep(0.05), num_warmups=0, num_trials=3)
2.1 用 MLP 做 scaling 实验
下面定义一个简单 MLP:Linear -> GeLU -> Linear -> GeLU -> ... -> Linear -> GeLU。
我们通过改变 step / layers / batch / dim 来观察时间如何变化。
import torch.nn as nn
class MLP(nn.Module):
# Simple MLP: linear -> GeLU -> linear -> GeLU -> ... -> linear -> GeLU
def __init__(self, dim: int, num_layers: int):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = torch.nn.functional.gelu(x)
return x
def run_mlp(dim: int, num_layers: int, batch_size: int, num_steps: int):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MLP(dim, num_layers).to(device)
x = torch.randn(batch_size, dim, device=device)
def fn():
for _ in range(num_steps):
y = model(x).mean()
y.backward()
return fn
# baseline
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}")
baseline = benchmark("mlp_baseline", run_mlp(dim=256, num_layers=4, batch_size=256, num_steps=2))
# scaling experiments
import matplotlib.pyplot as plt
def scaling_experiment(name, vary_fn, scales):
# vary_fn(scale) -> (label, fn)
results = []
for s in scales:
label, fn = vary_fn(s)
t = benchmark(label, fn, num_warmups=2, num_trials=3)
results.append((s, t))
return results
scales = [1, 2, 3, 4, 5]
# 1) scale num_steps
step_results = scaling_experiment(
"steps",
lambda s: (f"mlp({s}x steps)", run_mlp(256, 4, 256, 2 * s)),
scales,
)
# 2) scale num_layers
layer_results = scaling_experiment(
"layers",
lambda s: (f"mlp({s}x layers)", run_mlp(256, 4 * s, 256, 2)),
scales,
)
# 3) scale batch_size
batch_results = scaling_experiment(
"batch",
lambda s: (f"mlp({s}x batch)", run_mlp(256, 4, 256 * s, 2)),
scales,
)
# 4) scale dim
dim_results = scaling_experiment(
"dim",
lambda s: (f"mlp({s}x dim)", run_mlp(256 * s, 4, 256, 2)),
scales,
)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, (title, data) in zip(axes.flat, [
("Scale steps", step_results),
("Scale layers", layer_results),
("Scale batch", batch_results),
("Scale dim", dim_results),
]):
xs, ys = zip(*data)
ax.plot(xs, ys, marker='o')
ax.set_xlabel("Scale factor")
ax.set_ylabel("Time (ms)")
ax.set_title(title)
ax.grid(True)
plt.tight_layout()
plt.show()
2.2 观察与结论
- 增加
num_steps:时间近似线性增长,因为每次都要完整跑前向+反向。 - 增加
num_layers:线性增长,但常数比 steps 大(layer 越多,需要存储的激活和梯度越多)。 - 增加
batch_size:通常次线性或接近线性,取决于是否达到 compute-bound。 - 增加
dim:往往会超线性增长,因为计算量 O(d³) 级别(矩阵乘法主导)。
实践建议:每次改动模型或训练代码后,都跑一组 scaling 实验,确认性能变化符合预期。
Part 3:Profiling —— 知道时间花在哪里
Benchmarking 只看总时间;Profiling 告诉我们 总时间 = 哪些 kernel / 哪些 Python 调用 组成的。
PyTorch 内置了 torch.profiler,可以输出每个 op 的 CPU/CUDA 时间。
from src.helpers import profile_table
def add_fn():
a = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
b = torch.randn(1024, 1024, device=a.device)
return a + b
profile_table("add", add_fn)
def matmul_fn():
a = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
b = torch.randn(1024, 1024, device=a.device)
return a @ b
profile_table("matmul", matmul_fn)
3.1 解读 kernel 名称
例如 cutlass_80_simt_sgemm_256x128_8x4_nn_align1:
cutlass:NVIDIA 的 CUDA linear algebra 库。80:对应 SM 8.0(Ampere)。simt_sgemm:单精度通用矩阵乘。256x128:tile 大小。nn:A 和 B 都不转置(non-transpose)。
通过 kernel 名称,可以反推出 PyTorch 实际调了哪个底层实现。
# 用 with_stack=True 可以导出火焰图
if torch.cuda.is_available():
profile_table(
"mlp",
run_mlp(dim=2048, num_layers=64, batch_size=1024, num_steps=2),
with_stack=True,
)
else:
print("CUDA not available, skip flame graph profiling.")
3.2 火焰图导出
torch.profiler 支持 export_stacks(path, metric) 导出文本栈,可以用其他工具转成 SVG。
prof.export_stacks("var/stacks_mlp.txt", "self_cuda_time_total")
导出后可用 stackvis 或 FlameGraph 工具渲染:
# 需要安装 FlameGraph 工具
git clone https://github.com/brendangregg/FlameGraph.git
./FlameGraph/flamegraph.pl var/stacks_mlp.txt > var/stacks_mlp.svg
Part 4:Kernel Fusion —— 减少读写就是加速
4.1 Warehouse / Factory 类比
- DRAM = 仓库:大、慢、便宜。
- SRAM = 工厂:小、快、贵。
每个 PyTorch 算子都要:从仓库读原料 → 在工厂加工 → 写回仓库。如果有多个算子,就会反复搬运。
Kernel fusion:把多个算子合并成一个 kernel,原料只读一次、结果只写一次。
4.2 GeLU 案例
PyTorch 的 F.gelu 是融合实现;我们可以手写一个非融合版本对比。
def manual_gelu(x: torch.Tensor) -> torch.Tensor:
# 手写 GeLU(非融合,会产生多个 kernel)
return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))
def pytorch_gelu(x: torch.Tensor) -> torch.Tensor:
# PyTorch 融合 GeLU
return torch.nn.functional.gelu(x, approximate="tanh")
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device)
print("Correctness:", torch.allclose(pytorch_gelu(x), manual_gelu(x), atol=1e-5))
benchmark("manual_gelu", lambda: manual_gelu(x))
benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
profile_table("manual_gelu", lambda: manual_gelu(x))
profile_table("pytorch_gelu", lambda: pytorch_gelu(x))
4.3 结论
manual_gelu会触发多个 CUDA kernel:乘、加、tanh、乘…pytorch_gelu通常只触发一个融合 kernel。- 在 memory-bound 场景下,融合版明显更快。
Part 5:手写 CUDA Kernel
CUDA 是 C/C++ 的扩展,核心思想是:写一个线程要执行的函数,然后启动 N 个线程。
5.1 线程索引
blockIdx.x:当前线程所在的 block 编号。threadIdx.x:当前线程在 block 内的编号。blockDim.x:每个 block 的线程数。- 全局线程 ID =
blockIdx.x * blockDim.x + threadIdx.x。
5.2 用 load_inline 编译 CUDA
torch.utils.cpp_extension.load_inline 可以让我们在 Python 里写一段 CUDA C++ 代码,即时编译成一个 Python 模块。
我们的 src/gelu.cu 已经写好了,直接加载。
if torch.cuda.is_available():
from torch.utils.cpp_extension import load_inline
cuda_mod = load_inline(
name="inline_gelu",
cuda_sources=[open("src/gelu.cu").read()],
cpp_sources=["torch::Tensor gelu(torch::Tensor x);"],
functions=["gelu"],
extra_cflags=["-O2"],
verbose=False,
)
cuda_gelu = cuda_mod.gelu
print("CUDA gelu module loaded.")
else:
print("CUDA not available, skip CUDA kernel loading.")
if torch.cuda.is_available():
x_f = x.float()
y_cuda = cuda_gelu(x_f)
y_ref = pytorch_gelu(x_f)
print("CUDA gelu correct:", torch.allclose(y_cuda, y_ref, atol=1e-5))
print("Max diff:", (y_cuda - y_ref).abs().max().item())
if torch.cuda.is_available():
benchmark("manual_gelu", lambda: manual_gelu(x))
benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
benchmark("cuda_gelu", lambda: cuda_gelu(x.float()))
5.3 gelu.cu 逐行解析
__global__ void gelu_kernel(const float* __restrict__ x,
float* __restrict__ y,
int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; // 全局线程 ID
int stride = blockDim.x * gridDim.x; // 所有线程一次能跨多少元素
for (int i = idx; i < n; i += stride) { // grid-stride loop
float xi = x[i];
float a = 0.79788456f * (xi + 0.044715f * xi * xi * xi);
y[i] = 0.5f * xi * (1.0f + tanhf(a));
}
}
__global__:CPU 调用、GPU 执行的函数。__restrict__:告诉编译器指针不别名,便于优化。- grid-stride loop:即使 block 数不足以覆盖所有元素,每个线程也能处理多个位置。
Part 6:手写 Triton Kernel
Triton 由 OpenAI 开发,目标是用 Python 写高性能 GPU kernel,同时隐藏线程级细节。
| 能力 | CUDA | Triton |
|---|---|---|
| Memory coalescing | 手动 | 自动 |
| Shared memory 管理 | 手动 | 自动 |
| SM 内调度 | 手动 | 自动 |
| SM 间调度(grid/block) | 手动 | 手动 |
Triton 让我们专注于 “每个 block 做什么”,而不是每个 thread 做什么。
import triton
import triton.language as tl
@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0) # block ID
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE) # 这个 block 负责的元素索引
mask = offsets < n # 边界 mask
x = tl.load(x_ptr + offsets, mask=mask) # 从 global memory 读取
# GeLU tanh 近似
a = 0.79788456 * (x + 0.044715 * x * x * x)
exp = tl.exp(2 * a)
tanh = (exp - 1) / (exp + 1)
y = 0.5 * x * (1 + tanh)
tl.store(y_ptr + offsets, y, mask=mask) # 写回 global memory
def triton_gelu(x: torch.Tensor) -> torch.Tensor:
y = torch.empty_like(x)
n = x.numel()
BLOCK_SIZE = 1024
grid = (triton.cdiv(n, BLOCK_SIZE),)
triton_gelu_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)
return y
if torch.cuda.is_available():
y_triton = triton_gelu(x)
print("Triton gelu correct:", torch.allclose(y_triton, pytorch_gelu(x), atol=1e-5))
print("Max diff:", (y_triton - pytorch_gelu(x)).abs().max().item())
if torch.cuda.is_available():
benchmark("manual_gelu", lambda: manual_gelu(x))
benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
benchmark("cuda_gelu", lambda: cuda_gelu(x.float()))
benchmark("triton_gelu", lambda: triton_gelu(x))
6.1 看 Triton 生成的 PTX
Triton 编译后会生成 PTX(类似 GPU 汇编)。我们可以查看关键指令:
ld.global.*/st.global.*:读写 global memory(DRAM)。%ctaid.x:block 索引。%tid.x:thread 索引。- thread coarsening:一个 thread 处理多个元素。
# 获取 Triton kernel 的 PTX(Triton API 版本间差异较大,这里用环境变量方式导出)
import os
if torch.cuda.is_available():
# 设置环境变量导出编译产物
os.environ["TRITON_DUMP_ASM"] = "1"
os.environ["TRITON_CACHE_DIR"] = os.path.abspath("var/triton_cache")
os.makedirs("var/triton_cache", exist_ok=True)
# 触发编译
_ = triton_gelu(x)
print("PTX 已导出到 var/triton_cache,可用以下命令查看:")
print(" find var/triton_cache -name '*.ptx' | head -1 | xargs cat | head -80")
else:
print("CUDA not available, skip PTX dump.")
Part 7:torch.compile —— 不用手写 kernel 也能融合
PyTorch 2.0 引入了 torch.compile,可以把 Python 写的多个算子自动融合成更高效的 kernel。
对于 element-wise 操作,它通常能接近手写 Triton 的性能。
compiled_gelu = torch.compile(manual_gelu)
if torch.cuda.is_available():
print("compiled_gelu correct:", torch.allclose(compiled_gelu(x), pytorch_gelu(x), atol=1e-5))
if torch.cuda.is_available():
benchmark("manual_gelu", lambda: manual_gelu(x))
benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
benchmark("cuda_gelu", lambda: cuda_gelu(x.float()))
benchmark("triton_gelu", lambda: triton_gelu(x))
benchmark("compiled_gelu", lambda: compiled_gelu(x))
7.1 结论
| 实现 | 难度 | 速度 | 适用场景 |
|---|---|---|---|
| PyTorch fused | 最简单 | 最快之一 | 日常首选 |
| manual | 简单 | 最慢 | 理解算子组成 |
| CUDA | 最难 | 可控 | 需要极致优化 |
| Triton | 中等 | 接近 PyTorch | 自定义融合算子 |
| torch.compile | 简单 | 接近 Triton | 快速验证自动优化 |
Part 8:Triton Softmax —— row-wise reduction
Softmax 是典型的 row-wise reduction 操作:对矩阵每一行做 exp(x - max) / sum(exp(...))。
8.1 Naive 实现的访存代价
x_max = x.max(dim=1) # MN reads, M writes
x = x - x_max # MN + M reads, MN writes
num = exp(x) # MN reads, MN writes
den = num.sum(dim=1) # MN reads, M writes
y = num / den # MN reads, MN writes
总共约 5MN + M reads, 3MN + 2M writes。理想情况下只需 MN reads + MN writes。
Triton 可以把整个 row 放进一个 block,一次完成 max、sub、exp、sum、div。
def manual_softmax(x: torch.Tensor) -> torch.Tensor:
x_max = x.max(dim=1, keepdim=True)[0]
x = x - x_max
num = torch.exp(x)
den = num.sum(dim=1, keepdim=True)
return num / den
def pytorch_softmax(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.softmax(x, dim=-1)
@triton.jit
def triton_softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
# 读取一整行(超出 n_cols 的部分用 -inf 填充)
x_row = tl.load(
x_ptr + row_idx * x_row_stride + col_offsets,
mask=col_offsets < n_cols,
other=float("-inf"),
)
# 在线内做 softmax
x_row = x_row - tl.max(x_row, axis=0)
num = tl.exp(x_row)
den = tl.sum(num, axis=0)
tl.store(
y_ptr + row_idx * y_row_stride + col_offsets,
num / den,
mask=col_offsets < n_cols,
)
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
M, N = x.shape
y = torch.empty_like(x)
BLOCK_SIZE = triton.next_power_of_2(N)
triton_softmax_kernel[(M,)](
x, y,
x.stride(0), y.stride(0),
N, BLOCK_SIZE=BLOCK_SIZE,
)
return y
if torch.cuda.is_available():
x2 = torch.randn(1024, 1024, device="cuda")
print("manual correct:", torch.allclose(manual_softmax(x2), pytorch_softmax(x2), atol=1e-5))
print("triton correct:", torch.allclose(triton_softmax(x2), pytorch_softmax(x2), atol=1e-5))
if torch.cuda.is_available():
benchmark("manual_softmax", lambda: manual_softmax(x2))
benchmark("pytorch_softmax", lambda: pytorch_softmax(x2))
benchmark("triton_softmax", lambda: triton_softmax(x2))
8.2 Softmax 关键技巧
- ** numerical stability**:先减 max 再 exp,避免指数爆炸。
- 一个 block 处理一行:利用 Triton 自动管理 shared memory,减少多次全局读写。
- mask 处理变长列:
other=float("-inf")让 pad 位置不影响 max。
Part 9:Triton MatMul —— Tiling 与 Shared Memory
矩阵乘法 C = A @ B 是 GPU 上优化最充分的算法之一。
9.1 Naive 访存分析
- C 的每个元素需要 K 次 A 和 K 次 B 的读取。
- 总读取量约
M * K * N * 2,写入M * N。 - 但计算量也是
2 * M * N * KFLOPs,所以 arithmetic intensity 高,属于 compute-bound。
9.2 Tiling 核心思想
把 A 和 B 分成小 block(tile),一次加载到 shared memory:
- 从 DRAM 读 A 的一个 tile 和 B 的一个 tile 到 SRAM。
- 在这个 tile 上做 mini matmul,累加到局部寄存器。
- 重复直到遍历完 K 维。
- 把结果写回 DRAM。
这样 A 和 B 的每个元素从 DRAM 读取次数大大减少。
9.3 L2 Cache 友好的数据布局
Triton 官方 tutorial 里比较了 row-major vs grouped ordering:
- Row-major:按行优先顺序处理 block,可能导致 B 被重复加载。
- Grouped:让相邻 block 共享更多 A/B 数据,减少 L2 miss。
# 先 benchmark PyTorch matmul 作为 baseline
if torch.cuda.is_available():
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
benchmark("pytorch_matmul", lambda: a @ b)
else:
print("CUDA not available, skip matmul benchmark.")
9.4 进阶:Triton matmul 参考实现(可选阅读)
下面的代码来自 Triton 官方 tutorial,做了简化。它是理解 tiling 和 shared memory 的最佳示例。
# 参考 Triton 官方 matmul tutorial 的简化版
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def triton_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
assert a.shape[1] == b.shape[0]
M, K = a.shape
K2, N = b.shape
assert K == K2
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M=128, BLOCK_SIZE_N=256,
BLOCK_SIZE_K=64, GROUP_SIZE_M=8,
)
return c
if torch.cuda.is_available():
a16 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b16 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c_ref = a16 @ b16
c_tri = triton_matmul(a16, b16)
print("triton_matmul correct:", torch.allclose(c_tri, c_ref, atol=1e-2))
benchmark("pytorch_matmul", lambda: a16 @ b16)
benchmark("triton_matmul", lambda: triton_matmul(a16, b16))
Part 9.5:分块(Tiling)入门教学与实践
前面 Part 9 提到了 matmul 的 tiling 思想,但那是”看源码学 tiling”。本节从更基础的角度入手:为什么需要分块?分块到底省了什么?如何自己动手写一个最简单的 tiled kernel?
学习目标
- 理解 naive matmul 的访存代价。
- 理解 tiling 如何减少 DRAM 读取次数。
- 能补全一个填空式 Triton tiled matmul kernel。
- 会分析 tile size 对 shared memory 和 occupancy 的影响。
9.5.1 为什么需要分块?
考虑矩阵乘法 ,其中 ,。
Naive 实现:每个 thread 负责计算 的一个元素 。
- 计算 需要读取 的第 行和 的第 列,共 个元素。
- 整个 有 个元素,所以总读取量约为 。
问题是:
- 的第 行被所有计算 的 threads 重复读取。
- 的第 列被所有计算 的 threads 重复读取。
如果 很大,这些重复读取会大量占用 HBM 带宽。
9.5.2 分块核心思想
把矩阵分成大小为 和 的小块(tile):
- 从 HBM 读取 的一个 tile 到 shared memory。
- 从 HBM 读取 的一个 tile 到 shared memory。
- 用 shared memory 里的数据计算一个 partial sum,累加到寄存器。
- 沿 维滑动,重复步骤 1-3。
- 最后把累加结果写回 。
关键节省:
- 的每个 tile 只从 HBM 读取一次,但服务了 个输出元素。
- 的每个 tile 只从 HBM 读取一次,但服务了 个输出元素。
9.5.3 形象类比
想象你在做一道大菜:
- Naive:每次只从冰箱拿一个食材,用完再回去拿。
- Tiling:一次性从冰箱拿出一小篮子食材,在灶台上用完,再回去拿下一篮。
灶台(shared memory)比冰箱(HBM)小很多,但速度快得多。
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def draw_tiling_concept():
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Left: naive access pattern
ax = axes[0]
ax.set_title("Naive MatMul: each thread reads full row/column", fontsize=12, fontweight="bold")
for r in range(4):
for c in range(4):
rect = patches.Rectangle((c, r), 0.9, 0.9, facecolor="#e3f2fd", edgecolor="#1976d2")
ax.add_patch(rect)
ax.text(c + 0.45, r + 0.45, f"C_{{{r},{c}}}", ha="center", va="center", fontsize=8)
# Highlight row of A and col of B for C[1,2]
ax.annotate("", xy=(4.2, 1.5), xytext=(0, 1.5), arrowprops=dict(arrowstyle="->", color="#d32f2f", lw=2))
ax.text(4.5, 1.5, "A row i", color="#d32f2f", va="center")
ax.annotate("", xy=(2.5, 4.2), xytext=(2.5, 0), arrowprops=dict(arrowstyle="->", color="#388e3c", lw=2))
ax.text(2.5, 4.5, "B col j", color="#388e3c", ha="center")
ax.set_xlim(-0.5, 5.5)
ax.set_ylim(-0.5, 5)
ax.set_aspect("equal")
ax.axis("off")
# Right: tiled access pattern
ax = axes[1]
ax.set_title("Tiled MatMul: load a tile once, reuse for multiple outputs", fontsize=12, fontweight="bold")
tile_colors = [["#c5e1a5", "#c5e1a5", "#b39ddb", "#b39ddb"],
["#c5e1a5", "#c5e1a5", "#b39ddb", "#b39ddb"],
["#ffcc80", "#ffcc80", "#ef9a9a", "#ef9a9a"],
["#ffcc80", "#ffcc80", "#ef9a9a", "#ef9a9a"]]
for r in range(4):
for c in range(4):
rect = patches.Rectangle((c, r), 0.9, 0.9, facecolor=tile_colors[r][c], edgecolor="#333")
ax.add_patch(rect)
ax.text(c + 0.45, r + 0.45, f"C_{{{r},{c}}}", ha="center", va="center", fontsize=8)
ax.text(2, 4.5, "Each 2x2 tile of A is loaded once and reused", ha="center", fontsize=9)
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.5, 5)
ax.set_aspect("equal")
ax.axis("off")
plt.tight_layout()
plt.show()
draw_tiling_concept()
9.5.4 练习:补全一个 Triton Tiled MatMul Kernel
下面的代码是一个骨架,核心部分(加载 A/B tile、做 dot product)留空了。
你需要补全:
- 计算当前 block 负责输出 的哪个 tile。
- 沿 维循环时,加载 和 的 tile。
- 用
tl.dot计算 partial sum 并累加。
提示:
BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K是编译时常量。tl.load(ptrs, mask=..., other=0.0)可以处理边界。tl.dot(a, b, acc)对 FP16/BF16 支持很好。
import triton
import triton.language as tl
@triton.jit
def tiled_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""
每个 block 负责计算 C 的一个 (BLOCK_SIZE_M, BLOCK_SIZE_N) tile。
"""
# TODO 1: 获取当前 block 在 grid 中的 (m, n) 坐标
pid = tl.program_id(0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n # TODO: 改成正确计算
pid_n = pid % num_pid_n # TODO: 改成正确计算
# 当前 block 负责的 C tile 的左上角偏移
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
# 计算 A tile 和 B tile 的指针
a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
# 初始化 accumulator
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# 沿 K 维循环
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# TODO 2: 加载 A tile 和 B tile,注意边界 mask
a = tl.load(a_ptrs, mask=offs_m[:, None] < M and offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K and offs_n[None, :] < N, other=0.0)
# TODO 3: 用 tl.dot 计算 partial sum 并累加到 acc
acc += tl.dot(a, b, acc)
# 指针沿 K 维滑动到下一个 tile
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# 写回 C
c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)
def tiled_matmul(a: torch.Tensor, b: torch.Tensor, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32):
assert a.shape[1] == b.shape[0]
M, K = a.shape
K2, N = b.shape
assert K == K2
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
tiled_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return c
# 测试正确性
if torch.cuda.is_available():
M, N, K = 256, 256, 256
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c_ref = a @ b
c_tri = tiled_matmul(a, b, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32)
print("Correctness:", torch.allclose(c_tri, c_ref, atol=1e-2, rtol=1e-2))
print("Max diff:", (c_tri - c_ref).abs().max().item())
else:
print("CUDA not available.")
9.5.5 练习 1:改变 BLOCK_SIZE,观察性能变化
不同的 BLOCK_SIZE 会影响:
- shared memory 使用量
- occupancy
- L2 cache 命中率
- 边界检查 overhead
尝试几组参数,记录运行时间。
if torch.cuda.is_available():
M, N, K = 1024, 1024, 1024
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c_ref = a @ b
configs = [
(16, 16, 16),
(32, 32, 32),
(64, 64, 64),
(64, 64, 32),
(128, 128, 32),
]
results = []
for BM, BN, BK in configs:
try:
torch.cuda.empty_cache()
c_tri = tiled_matmul(a, b, BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK)
correct = torch.allclose(c_tri, c_ref, atol=1e-2, rtol=1e-2)
t = benchmark(f"tiled_matmul({BM}x{BN}x{BK})", lambda: tiled_matmul(a, b, BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK), num_warmups=2, num_trials=5)
results.append((BM, BN, BK, t, correct))
print(f"{BM}x{BN}x{BK}: {t:.3f} ms, correct={correct}")
except Exception as e:
print(f"{BM}x{BN}x{BK} failed: {e}")
# PyTorch baseline
t_ref = benchmark("pytorch_matmul", lambda: a @ b, num_warmups=2, num_trials=5)
print(f"\nPyTorch: {t_ref:.3f} ms")
9.5.6 练习 2:分析 Shared Memory 用量与 Occupancy
一个 block 的 shared memory 用量约为:
shmem_per_block = BLOCK_SIZE_M * BLOCK_SIZE_K * sizeof(dtype) + BLOCK_SIZE_K * BLOCK_SIZE_N * sizeof(dtype)
对于 FP16(2 bytes):
- 32x32x32: 32322 + 32322 = 4 KB
- 64x64x64: 64642 + 64642 = 16 KB
- 128x128x32: 128322 + 321282 = 16 KB
RTX 5060 每个 SM 的 shared memory 大约 100 KB。据此估算:
- 每个 SM 最多能同时驻留多少个 64x64x64 的 block?
- 如果 shared memory 变成瓶颈,occupancy 会受到什么影响?
def shmem_for_matmul(BM, BN, BK, dtype_bytes=2):
"""估算一个 block 的 shared memory 用量(近似)。"""
a_tile = BM * BK * dtype_bytes
b_tile = BK * BN * dtype_bytes
return a_tile + b_tile
print("Shared memory usage per block:")
for BM, BN, BK in [(32, 32, 32), (64, 64, 64), (128, 128, 32), (128, 128, 64)]:
shmem = shmem_for_matmul(BM, BN, BK)
print(f" {BM}x{BN}x{BK}: {shmem / 1024:.1f} KB")
# 假设 SM 有 100 KB shared memory
sm_smem_kb = 100
print(f"\nWith {sm_smem_kb} KB shared memory per SM:")
for BM, BN, BK in [(64, 64, 64), (128, 128, 32), (128, 128, 64)]:
shmem_kb = shmem_for_matmul(BM, BN, BK) / 1024
max_blocks_by_smem = sm_smem_kb // shmem_kb
print(f" {BM}x{BN}x{BK}: max {max_blocks_by_smem} blocks per SM (by shared memory)")
9.5.7 练习 3:对比 Tiled vs PyTorch MatMul
上面的 tiled kernel 只是一个教学实现,没有做很多高级优化(如:
- 向量加载/存储
- 多级 tiling(register-level + shared-level)
- L2 cache 友好的 pid 分组
- warp-level 优化
所以通常比 PyTorch/CUTLASS 慢。这很正常,重点是理解 tiling 的思想。
思考题:
- 为什么增大 BLOCK_SIZE 不总是加速?
- 为什么 128x128x32 可能比 64x64x64 快或慢?
- 如果要继续优化这个 kernel,你会先改哪里?
if torch.cuda.is_available():
sizes = [512, 1024, 2048]
for size in sizes:
a = torch.randn(size, size, device="cuda", dtype=torch.float16)
b = torch.randn(size, size, device="cuda", dtype=torch.float16)
t_ref = benchmark(f"pytorch_matmul({size})", lambda: a @ b, num_warmups=2, num_trials=5)
t_tri = benchmark(f"tiled_matmul({size}, 64x64x32)",
lambda: tiled_matmul(a, b, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32),
num_warmups=2, num_trials=5)
print(f"size={size}: PyTorch={t_ref:.2f} ms, Tiled={t_tri:.2f} ms, ratio={t_tri/t_ref:.2f}x")
else:
print("CUDA not available.")
9.5.8 小结
分块(tiling)是 GPU kernel 优化的核心技巧:
- 问题:naive matmul 重复读取 HBM,memory-bound。
- 思想:把数据切成小块加载到 shared memory,复用多次。
- 关键参数:BLOCK_SIZE_M/N/K 影响 shared memory、occupancy、边界 overhead。
- 现代框架:PyTorch/CUTLASS 的 matmul 本质上就是极其优化的 tiled kernel,通常还有多级 tiling 和 Tensor Core。
下一步:可以尝试把上面的 tiled matmul 改成用 Tensor Core(FP16 下 tl.dot 已经会尝试用 Tensor Core),或者加入 L2 cache 友好的 pid 分组。
Part 10:Nsight Systems / Nsight Compute 实操
10.1 Nsight Systems (nsys)
nsys profile 记录整个应用的 CPU + GPU timeline,适合看:
- 哪些 kernel 占了最多时间。
- CPU 和 GPU 之间有没有空闲间隙(launch latency)。
- CUDA API 调用顺序。
基本命令:
nsys profile --stats=true -o output_name python your_script.py
10.2 Nsight Compute (ncu)
ncu 分析单个 kernel 的微观指标:
sm__warps_active.avg.pct_of_peak_sustained_elapsed:SM 利用率 / occupancy。dram__bytes.sum.per_second:显存带宽利用率。smsp__sass_thread_inst_executed_op_fadd_pred_on.sum:实际执行的 FLOPs。
10.3 实践:分析 GeLU
# 生成一个可独立运行的 profile 脚本
profile_script = r'''
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
def manual_gelu(x):
return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))
def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4096 * 4096, device=device)
for _ in range(3):
y = manual_gelu(x)
if device == "cuda":
torch.cuda.synchronize()
if __name__ == "__main__":
main()
'''
os.makedirs("scripts", exist_ok=True)
with open("scripts/profile_gelu.py", "w") as f:
f.write(profile_script)
print("scripts/profile_gelu.py written.")
# 在 Notebook 中运行 nsys profile(注意:会花一些时间)
import subprocess
if torch.cuda.is_available():
cmd = [
"nsys", "profile", "--stats=true", "--no-wait",
"-o", "var/nsys_gelu",
".venv/bin/python", "scripts/profile_gelu.py",
]
print("Running:", " ".join(cmd))
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
print(result.stdout[-3000:])
print(result.stderr[-1000:])
else:
print("CUDA not available, skip nsys profile.")
10.4 如何读 nsys 报告
nsys 输出通常包含:
- CUDA API Statistics:每个 CUDA API 调用的次数和总时间。
- CUDA Kernel Statistics:每个 kernel 的执行次数、平均/最大/总时间。
- Memory Operation Statistics:HtoD/DtoH 传输次数。
重点关注:
- Time(%) 最高的 kernel:这是你的瓶颈。
- CPU 是否在等 GPU:看 timeline 里有没有大片空白。
- kernel launch overhead:小 kernel 太多会导致 launch 开销占比高。
生成的 .nsys-rep 可以用 Nsight Systems GUI 打开:
nsys-ui var/nsys_gelu.nsys-rep
# 用 ncu 分析单个 kernel(需要知道 kernel 名称,可用 nsys 或 torch.profiler 获取)
if torch.cuda.is_available():
cmd = [
"ncu",
"--kernel-name", "regex:gelsu", # GeLU backward 名称示例
"--metrics", "sm__warps_active.avg.pct_of_peak_sustained_elapsed,dram__bytes.sum.per_second",
".venv/bin/python", "scripts/profile_gelu.py",
]
print("Example ncu command:")
print(" ".join(cmd))
print("\\nNote: run this in terminal if you want actual metrics.")
10.5 用 nvtx 标注代码区间
nvtx 可以在代码里打标记,nsys timeline 上就会显示这些区间,方便定位。
import nvtx
if torch.cuda.is_available():
with nvtx.annotate("gelu_loop", color="green"):
for _ in range(5):
_ = manual_gelu(x)
torch.cuda.synchronize()
print("nvtx annotated region executed.")
Part 11:总结与练习
11.1 核心原则
- Measure first:先 benchmark / profile,再优化。
- Minimize reads/writes:kernel fusion 和 tiling 是核心手段。
- Know your bottleneck:compute-bound 提 FLOPs,memory-bound 提访存效率。
- Use the right tool:
- 日常开发:
torch.compile+ PyTorch fused ops。 - 自定义融合:Triton。
- 极致优化或特殊硬件:CUDA。
- 性能分析:
nsys+ncu+torch.profiler。
- 日常开发:
11.2 自测练习
练习 1:手写 ReLU 的 CUDA 和 Triton kernel
要求:
- CUDA:
relu.cu,输入输出都是 float32。 - Triton:
triton_relu_kernel,支持任意长度向量。 - 与
torch.relu对比正确性和速度。
练习 2:手写 row-wise sum 的 Triton kernel
输入 shape (M, N),输出 shape (M,),对每行求和。
提示:和 softmax 类似,但不需要 max/exp。
练习 3:用 nsys 分析训练循环
对下面这个训练循环跑 nsys profile:
model = MLP(512, 8).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(128, 512, device="cuda")
for _ in range(20):
opt.zero_grad()
loss = model(x).mean()
loss.backward()
opt.step()
回答:
- 时间占比最高的前 3 个 kernel 是什么?
- 有没有 kernel launch gap?
- 优化器 step 是否占大量时间?
11.3 进阶阅读
- Horace He’s GPU performance intro
- Triton fused softmax tutorial
- Triton matmul tutorial
- CUDA MODE YouTube playlist
- GPU Puzzles
附录:交互式 GPU 架构可视化
本讲还包含一个交互式的 GPU 架构可视化页面,覆盖 SM、Warp Scheduler、SIMT、Occupancy 等主题:
也可以下载原始 Notebook 运行: