CS336 Lecture 6: GPU 性能、Kernel 手写与 Nsight 分析 | Feixiang Tao
CS336 Spring 2025 2026-06-15 · 36 min read

CS336 Lecture 6: GPU 性能、Kernel 手写与 Nsight 分析

CS336 Lecture 6 学习讲义:GPU 性能、Kernel 手写与 Nsight 分析

副标题:从 PyTorch 算子到底层 CUDA / Triton Kernel,再到 NVIDIA Nsight 性能分析

适用对象:希望入门 AI Infra、能独立手写基础 CUDA/Triton kernel 的同学

依据材料:CS336 Spring 2025 Lecture 6


学完这节课你能做什么

  1. 看懂 GPU 执行模型:SM、thread/block/grid、shared memory、arithmetic intensity。
  2. 会写 benchmark / profile:用 torch.profilertime.time() 做可复现的测速。
  3. 能手写 CUDA kernel:从 C++ 源码到 torch.utils.cpp_extension.load_inline 加载。
  4. 能手写 Triton kernel:element-wise(GeLU)和 row-wise reduction(softmax)。
  5. 会用 torch.compile 自动融合算子
  6. 会用 Nsight Systems (nsys) 和 Nsight Compute (ncu) 分析性能瓶颈

运行环境

本 Notebook 使用 Note/lec6/ 下的独立 uv 环境,已安装:

  • Python 3.11
  • PyTorch nightly (cu128,支持 RTX 5060 Blackwell)
  • Triton 3.7
  • JupyterLab、matplotlib、numpy、nvtx
  • 系统自带 nsys / ncu

如果当前没有 GPU,所有 CUDA 相关 cell 会自动跳过,Triton 相关代码需要在 GPU 环境下运行。

Part 0:环境检查

先确认 PyTorch、CUDA、Triton、Nsight 工具是否就绪。


import sys
import subprocess

import torch
import triton

print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}")
    print(f"SM count: {props.multi_processor_count}")
    print(f"Total memory: {props.total_memory / 1e9:.2f} GB")
    print(f"Compute capability: {props.major}.{props.minor}")

print(f"Triton: {triton.__version__}")

for cmd, label in [("nsys", "Nsight Systems"), ("ncu", "Nsight Compute")]:
    try:
        out = subprocess.run([cmd, "--version"], capture_output=True, text=True, check=True).stdout.strip()
        print(f"{label}: {out.split(chr(10))[0]}")
    except Exception as e:
        print(f"{label}: not available ({e})")

Part 1:GPU 硬件与执行模型

1.1 硬件层次

层级A100 示例作用
Streaming Multiprocessors (SMs)108 个实际做计算的单位
DRAM (HBM)80 GB大容量显存,带宽高但延迟高
L2 Cache40 MB全局共享,比 DRAM 快
L1 Cache / Shared Memory192 KB / SM每个 SM 内,可被同 block 线程共享

1.2 执行层次

GPU 编程的核心抽象是 “对大量索引并行执行同一个函数 f(i)”

  • Thread:执行 f(i) 的最小单元。
  • Thread Block(CTA):一组线程,运行在同一个 SM 上,内部可共享 shared memory 并同步。
  • Grid:所有 thread block 的集合。

1.3 关键概念

  • Wave quantization:如果 grid 里的 block 数不能整除 SM 数,最后一 wave 会空闲部分 SM。
  • Arithmetic intensityFLOPs / byte,越高越偏向 compute-bound,越低越偏向 memory-bound。
  • 一般规律:矩阵乘法是 compute-bound,大部分其他操作(激活、归一化、softmax)是 memory-bound。

def print_gpu_specs():
    if not torch.cuda.is_available():
        print("No CUDA available")
        return
    num = torch.cuda.device_count()
    print(f"Devices: {num}")
    for i in range(num):
        p = torch.cuda.get_device_properties(i)
        print(f"  [{i}] {p.name}")
        print(f"       SMs: {p.multi_processor_count}, Memory: {p.total_memory/1e9:.2f} GB")
        print(f"       Compute capability: {p.major}.{p.minor}")
        print(f"       L2 cache size: {p.L2_cache_size / 1024**2:.1f} MB")

print_gpu_specs()

1.4 思考题

  1. 为什么矩阵乘法是 compute-bound,而 GeLU 是 memory-bound?
  2. 对于一个 element-wise 操作,增加 block 数是否能无限加速?瓶颈在哪里?

Part 2:Benchmarking —— 先会测速,再谈优化

优化的第一步永远是 测量。没有测量,所有优化都是盲猜。

一个好的 benchmark 函数应该:

  1. Warmup:先跑几次,排除编译、缓存冷启动影响。
  2. Synchronize:CUDA 是异步的,必须 torch.cuda.synchronize() 才能拿到真实 wall-clock 时间。
  3. 多次试验:取平均,减少噪声。

from src.helpers import benchmark

# 简单示例:sleep 50ms
import time
benchmark("sleep", lambda: time.sleep(0.05), num_warmups=0, num_trials=3)

2.1 用 MLP 做 scaling 实验

下面定义一个简单 MLP:Linear -> GeLU -> Linear -> GeLU -> ... -> Linear -> GeLU

我们通过改变 step / layers / batch / dim 来观察时间如何变化。


import torch.nn as nn


class MLP(nn.Module):
    # Simple MLP: linear -> GeLU -> linear -> GeLU -> ... -> linear -> GeLU
    def __init__(self, dim: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
            x = torch.nn.functional.gelu(x)
        return x


def run_mlp(dim: int, num_layers: int, batch_size: int, num_steps: int):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MLP(dim, num_layers).to(device)
    x = torch.randn(batch_size, dim, device=device)

    def fn():
        for _ in range(num_steps):
            y = model(x).mean()
            y.backward()
    return fn


# baseline
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}")
baseline = benchmark("mlp_baseline", run_mlp(dim=256, num_layers=4, batch_size=256, num_steps=2))

# scaling experiments
import matplotlib.pyplot as plt


def scaling_experiment(name, vary_fn, scales):
    # vary_fn(scale) -> (label, fn)
    results = []
    for s in scales:
        label, fn = vary_fn(s)
        t = benchmark(label, fn, num_warmups=2, num_trials=3)
        results.append((s, t))
    return results


scales = [1, 2, 3, 4, 5]

# 1) scale num_steps
step_results = scaling_experiment(
    "steps",
    lambda s: (f"mlp({s}x steps)", run_mlp(256, 4, 256, 2 * s)),
    scales,
)

# 2) scale num_layers
layer_results = scaling_experiment(
    "layers",
    lambda s: (f"mlp({s}x layers)", run_mlp(256, 4 * s, 256, 2)),
    scales,
)

# 3) scale batch_size
batch_results = scaling_experiment(
    "batch",
    lambda s: (f"mlp({s}x batch)", run_mlp(256, 4, 256 * s, 2)),
    scales,
)

# 4) scale dim
dim_results = scaling_experiment(
    "dim",
    lambda s: (f"mlp({s}x dim)", run_mlp(256 * s, 4, 256, 2)),
    scales,
)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, (title, data) in zip(axes.flat, [
    ("Scale steps", step_results),
    ("Scale layers", layer_results),
    ("Scale batch", batch_results),
    ("Scale dim", dim_results),
]):
    xs, ys = zip(*data)
    ax.plot(xs, ys, marker='o')
    ax.set_xlabel("Scale factor")
    ax.set_ylabel("Time (ms)")
    ax.set_title(title)
    ax.grid(True)
plt.tight_layout()
plt.show()

2.2 观察与结论

  • 增加 num_steps:时间近似线性增长,因为每次都要完整跑前向+反向。
  • 增加 num_layers:线性增长,但常数比 steps 大(layer 越多,需要存储的激活和梯度越多)。
  • 增加 batch_size:通常次线性或接近线性,取决于是否达到 compute-bound。
  • 增加 dim:往往会超线性增长,因为计算量 O(d³) 级别(矩阵乘法主导)。

实践建议:每次改动模型或训练代码后,都跑一组 scaling 实验,确认性能变化符合预期。

Part 3:Profiling —— 知道时间花在哪里

Benchmarking 只看总时间;Profiling 告诉我们 总时间 = 哪些 kernel / 哪些 Python 调用 组成的。

PyTorch 内置了 torch.profiler,可以输出每个 op 的 CPU/CUDA 时间。


from src.helpers import profile_table


def add_fn():
    a = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
    b = torch.randn(1024, 1024, device=a.device)
    return a + b


profile_table("add", add_fn)

def matmul_fn():
    a = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
    b = torch.randn(1024, 1024, device=a.device)
    return a @ b


profile_table("matmul", matmul_fn)

3.1 解读 kernel 名称

例如 cutlass_80_simt_sgemm_256x128_8x4_nn_align1

  • cutlass:NVIDIA 的 CUDA linear algebra 库。
  • 80:对应 SM 8.0(Ampere)。
  • simt_sgemm:单精度通用矩阵乘。
  • 256x128:tile 大小。
  • nn:A 和 B 都不转置(non-transpose)。

通过 kernel 名称,可以反推出 PyTorch 实际调了哪个底层实现。


# 用 with_stack=True 可以导出火焰图
if torch.cuda.is_available():
    profile_table(
        "mlp",
        run_mlp(dim=2048, num_layers=64, batch_size=1024, num_steps=2),
        with_stack=True,
    )
else:
    print("CUDA not available, skip flame graph profiling.")

3.2 火焰图导出

torch.profiler 支持 export_stacks(path, metric) 导出文本栈,可以用其他工具转成 SVG。

prof.export_stacks("var/stacks_mlp.txt", "self_cuda_time_total")

导出后可用 stackvis 或 FlameGraph 工具渲染:

# 需要安装 FlameGraph 工具
git clone https://github.com/brendangregg/FlameGraph.git
./FlameGraph/flamegraph.pl var/stacks_mlp.txt > var/stacks_mlp.svg

Part 4:Kernel Fusion —— 减少读写就是加速

4.1 Warehouse / Factory 类比

  • DRAM = 仓库:大、慢、便宜。
  • SRAM = 工厂:小、快、贵。

每个 PyTorch 算子都要:从仓库读原料 → 在工厂加工 → 写回仓库。如果有多个算子,就会反复搬运。

Kernel fusion:把多个算子合并成一个 kernel,原料只读一次、结果只写一次。

4.2 GeLU 案例

PyTorch 的 F.gelu 是融合实现;我们可以手写一个非融合版本对比。


def manual_gelu(x: torch.Tensor) -> torch.Tensor:
    # 手写 GeLU(非融合,会产生多个 kernel)
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))


def pytorch_gelu(x: torch.Tensor) -> torch.Tensor:
    # PyTorch 融合 GeLU
    return torch.nn.functional.gelu(x, approximate="tanh")


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device)
print("Correctness:", torch.allclose(pytorch_gelu(x), manual_gelu(x), atol=1e-5))

benchmark("manual_gelu", lambda: manual_gelu(x))
benchmark("pytorch_gelu", lambda: pytorch_gelu(x))

profile_table("manual_gelu", lambda: manual_gelu(x))
profile_table("pytorch_gelu", lambda: pytorch_gelu(x))

4.3 结论

  • manual_gelu 会触发多个 CUDA kernel:乘、加、tanh、乘…
  • pytorch_gelu 通常只触发一个融合 kernel。
  • 在 memory-bound 场景下,融合版明显更快。

Part 5:手写 CUDA Kernel

CUDA 是 C/C++ 的扩展,核心思想是:写一个线程要执行的函数,然后启动 N 个线程

5.1 线程索引

  • blockIdx.x:当前线程所在的 block 编号。
  • threadIdx.x:当前线程在 block 内的编号。
  • blockDim.x:每个 block 的线程数。
  • 全局线程 ID = blockIdx.x * blockDim.x + threadIdx.x

5.2 用 load_inline 编译 CUDA

torch.utils.cpp_extension.load_inline 可以让我们在 Python 里写一段 CUDA C++ 代码,即时编译成一个 Python 模块。

我们的 src/gelu.cu 已经写好了,直接加载。


if torch.cuda.is_available():
    from torch.utils.cpp_extension import load_inline

    cuda_mod = load_inline(
        name="inline_gelu",
        cuda_sources=[open("src/gelu.cu").read()],
        cpp_sources=["torch::Tensor gelu(torch::Tensor x);"],
        functions=["gelu"],
        extra_cflags=["-O2"],
        verbose=False,
    )
    cuda_gelu = cuda_mod.gelu
    print("CUDA gelu module loaded.")
else:
    print("CUDA not available, skip CUDA kernel loading.")

if torch.cuda.is_available():
    x_f = x.float()
    y_cuda = cuda_gelu(x_f)
    y_ref = pytorch_gelu(x_f)
    print("CUDA gelu correct:", torch.allclose(y_cuda, y_ref, atol=1e-5))
    print("Max diff:", (y_cuda - y_ref).abs().max().item())

if torch.cuda.is_available():
    benchmark("manual_gelu", lambda: manual_gelu(x))
    benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
    benchmark("cuda_gelu", lambda: cuda_gelu(x.float()))

5.3 gelu.cu 逐行解析

__global__ void gelu_kernel(const float* __restrict__ x,
                            float* __restrict__ y,
                            int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // 全局线程 ID
    int stride = blockDim.x * gridDim.x;              // 所有线程一次能跨多少元素
    for (int i = idx; i < n; i += stride) {           // grid-stride loop
        float xi = x[i];
        float a = 0.79788456f * (xi + 0.044715f * xi * xi * xi);
        y[i] = 0.5f * xi * (1.0f + tanhf(a));
    }
}
  • __global__:CPU 调用、GPU 执行的函数。
  • __restrict__:告诉编译器指针不别名,便于优化。
  • grid-stride loop:即使 block 数不足以覆盖所有元素,每个线程也能处理多个位置。

Part 6:手写 Triton Kernel

Triton 由 OpenAI 开发,目标是用 Python 写高性能 GPU kernel,同时隐藏线程级细节。

能力CUDATriton
Memory coalescing手动自动
Shared memory 管理手动自动
SM 内调度手动自动
SM 间调度(grid/block)手动手动

Triton 让我们专注于 “每个 block 做什么”,而不是每个 thread 做什么。


import triton
import triton.language as tl


@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)               # block ID
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # 这个 block 负责的元素索引
    mask = offsets < n                        # 边界 mask

    x = tl.load(x_ptr + offsets, mask=mask)   # 从 global memory 读取

    # GeLU tanh 近似
    a = 0.79788456 * (x + 0.044715 * x * x * x)
    exp = tl.exp(2 * a)
    tanh = (exp - 1) / (exp + 1)
    y = 0.5 * x * (1 + tanh)

    tl.store(y_ptr + offsets, y, mask=mask)   # 写回 global memory


def triton_gelu(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)
    triton_gelu_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)
    return y

if torch.cuda.is_available():
    y_triton = triton_gelu(x)
    print("Triton gelu correct:", torch.allclose(y_triton, pytorch_gelu(x), atol=1e-5))
    print("Max diff:", (y_triton - pytorch_gelu(x)).abs().max().item())

if torch.cuda.is_available():
    benchmark("manual_gelu", lambda: manual_gelu(x))
    benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
    benchmark("cuda_gelu", lambda: cuda_gelu(x.float()))
    benchmark("triton_gelu", lambda: triton_gelu(x))

6.1 看 Triton 生成的 PTX

Triton 编译后会生成 PTX(类似 GPU 汇编)。我们可以查看关键指令:

  • ld.global.* / st.global.*:读写 global memory(DRAM)。
  • %ctaid.x:block 索引。
  • %tid.x:thread 索引。
  • thread coarsening:一个 thread 处理多个元素。
# 获取 Triton kernel 的 PTX(Triton API 版本间差异较大,这里用环境变量方式导出)
import os

if torch.cuda.is_available():
    # 设置环境变量导出编译产物
    os.environ["TRITON_DUMP_ASM"] = "1"
    os.environ["TRITON_CACHE_DIR"] = os.path.abspath("var/triton_cache")
    os.makedirs("var/triton_cache", exist_ok=True)

    # 触发编译
    _ = triton_gelu(x)
    print("PTX 已导出到 var/triton_cache,可用以下命令查看:")
    print("  find var/triton_cache -name '*.ptx' | head -1 | xargs cat | head -80")
else:
    print("CUDA not available, skip PTX dump.")

Part 7:torch.compile —— 不用手写 kernel 也能融合

PyTorch 2.0 引入了 torch.compile,可以把 Python 写的多个算子自动融合成更高效的 kernel。

对于 element-wise 操作,它通常能接近手写 Triton 的性能。


compiled_gelu = torch.compile(manual_gelu)

if torch.cuda.is_available():
    print("compiled_gelu correct:", torch.allclose(compiled_gelu(x), pytorch_gelu(x), atol=1e-5))

if torch.cuda.is_available():
    benchmark("manual_gelu", lambda: manual_gelu(x))
    benchmark("pytorch_gelu", lambda: pytorch_gelu(x))
    benchmark("cuda_gelu", lambda: cuda_gelu(x.float()))
    benchmark("triton_gelu", lambda: triton_gelu(x))
    benchmark("compiled_gelu", lambda: compiled_gelu(x))

7.1 结论

实现难度速度适用场景
PyTorch fused最简单最快之一日常首选
manual简单最慢理解算子组成
CUDA最难可控需要极致优化
Triton中等接近 PyTorch自定义融合算子
torch.compile简单接近 Triton快速验证自动优化

Part 8:Triton Softmax —— row-wise reduction

Softmax 是典型的 row-wise reduction 操作:对矩阵每一行做 exp(x - max) / sum(exp(...))

8.1 Naive 实现的访存代价

x_max = x.max(dim=1)        # MN reads, M writes
x = x - x_max               # MN + M reads, MN writes
num = exp(x)                # MN reads, MN writes
den = num.sum(dim=1)        # MN reads, M writes
y = num / den               # MN reads, MN writes

总共约 5MN + M reads, 3MN + 2M writes。理想情况下只需 MN reads + MN writes

Triton 可以把整个 row 放进一个 block,一次完成 max、sub、exp、sum、div。


def manual_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1, keepdim=True)[0]
    x = x - x_max
    num = torch.exp(x)
    den = num.sum(dim=1, keepdim=True)
    return num / den


def pytorch_softmax(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.softmax(x, dim=-1)


@triton.jit
def triton_softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)

    # 读取一整行(超出 n_cols 的部分用 -inf 填充)
    x_row = tl.load(
        x_ptr + row_idx * x_row_stride + col_offsets,
        mask=col_offsets < n_cols,
        other=float("-inf"),
    )

    # 在线内做 softmax
    x_row = x_row - tl.max(x_row, axis=0)
    num = tl.exp(x_row)
    den = tl.sum(num, axis=0)

    tl.store(
        y_ptr + row_idx * y_row_stride + col_offsets,
        num / den,
        mask=col_offsets < n_cols,
    )


def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(N)
    triton_softmax_kernel[(M,)](
        x, y,
        x.stride(0), y.stride(0),
        N, BLOCK_SIZE=BLOCK_SIZE,
    )
    return y

if torch.cuda.is_available():
    x2 = torch.randn(1024, 1024, device="cuda")
    print("manual correct:", torch.allclose(manual_softmax(x2), pytorch_softmax(x2), atol=1e-5))
    print("triton correct:", torch.allclose(triton_softmax(x2), pytorch_softmax(x2), atol=1e-5))

if torch.cuda.is_available():
    benchmark("manual_softmax", lambda: manual_softmax(x2))
    benchmark("pytorch_softmax", lambda: pytorch_softmax(x2))
    benchmark("triton_softmax", lambda: triton_softmax(x2))

8.2 Softmax 关键技巧

  1. ** numerical stability**:先减 max 再 exp,避免指数爆炸。
  2. 一个 block 处理一行:利用 Triton 自动管理 shared memory,减少多次全局读写。
  3. mask 处理变长列other=float("-inf") 让 pad 位置不影响 max。

Part 9:Triton MatMul —— Tiling 与 Shared Memory

矩阵乘法 C = A @ B 是 GPU 上优化最充分的算法之一。

9.1 Naive 访存分析

  • C 的每个元素需要 K 次 A 和 K 次 B 的读取。
  • 总读取量约 M * K * N * 2,写入 M * N
  • 但计算量也是 2 * M * N * K FLOPs,所以 arithmetic intensity 高,属于 compute-bound。

9.2 Tiling 核心思想

把 A 和 B 分成小 block(tile),一次加载到 shared memory:

  1. 从 DRAM 读 A 的一个 tile 和 B 的一个 tile 到 SRAM。
  2. 在这个 tile 上做 mini matmul,累加到局部寄存器。
  3. 重复直到遍历完 K 维。
  4. 把结果写回 DRAM。

这样 A 和 B 的每个元素从 DRAM 读取次数大大减少。

9.3 L2 Cache 友好的数据布局

Triton 官方 tutorial 里比较了 row-major vs grouped ordering:

  • Row-major:按行优先顺序处理 block,可能导致 B 被重复加载。
  • Grouped:让相邻 block 共享更多 A/B 数据,减少 L2 miss。

参考:Triton matmul tutorial


# 先 benchmark PyTorch matmul 作为 baseline
if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    benchmark("pytorch_matmul", lambda: a @ b)
else:
    print("CUDA not available, skip matmul benchmark.")

9.4 进阶:Triton matmul 参考实现(可选阅读)

下面的代码来自 Triton 官方 tutorial,做了简化。它是理解 tiling 和 shared memory 的最佳示例。


# 参考 Triton 官方 matmul tutorial 的简化版
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator.to(tl.float16)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=256,
        BLOCK_SIZE_K=64, GROUP_SIZE_M=8,
    )
    return c

if torch.cuda.is_available():
    a16 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b16 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    c_ref = a16 @ b16
    c_tri = triton_matmul(a16, b16)
    print("triton_matmul correct:", torch.allclose(c_tri, c_ref, atol=1e-2))
    benchmark("pytorch_matmul", lambda: a16 @ b16)
    benchmark("triton_matmul", lambda: triton_matmul(a16, b16))

Part 9.5:分块(Tiling)入门教学与实践

前面 Part 9 提到了 matmul 的 tiling 思想,但那是”看源码学 tiling”。本节从更基础的角度入手:为什么需要分块?分块到底省了什么?如何自己动手写一个最简单的 tiled kernel?

学习目标

  1. 理解 naive matmul 的访存代价。
  2. 理解 tiling 如何减少 DRAM 读取次数。
  3. 能补全一个填空式 Triton tiled matmul kernel。
  4. 会分析 tile size 对 shared memory 和 occupancy 的影响。

9.5.1 为什么需要分块?

考虑矩阵乘法 C=A×BC = A \times B,其中 ARM×KA \in \mathbb{R}^{M \times K}BRK×NB \in \mathbb{R}^{K \times N}

Naive 实现:每个 thread 负责计算 CC 的一个元素 CijC_{ij}

  • 计算 CijC_{ij} 需要读取 AA 的第 ii 行和 BB 的第 jj 列,共 2K2K 个元素。
  • 整个 CCM×NM \times N 个元素,所以总读取量约为 2MNK2 M N K

问题是:

  • AA 的第 ii 行被所有计算 CiC_{i*} 的 threads 重复读取。
  • BB 的第 jj 列被所有计算 CjC_{*j} 的 threads 重复读取。

如果 KK 很大,这些重复读取会大量占用 HBM 带宽。

9.5.2 分块核心思想

把矩阵分成大小为 (BM,BK)(BM, BK)(BK,BN)(BK, BN) 的小块(tile):

  1. 从 HBM 读取 AA 的一个 tile 到 shared memory。
  2. 从 HBM 读取 BB 的一个 tile 到 shared memory。
  3. 用 shared memory 里的数据计算一个 partial sum,累加到寄存器。
  4. 沿 KK 维滑动,重复步骤 1-3。
  5. 最后把累加结果写回 CC

关键节省:

  • AA 的每个 tile 只从 HBM 读取一次,但服务了 BNBN 个输出元素。
  • BB 的每个 tile 只从 HBM 读取一次,但服务了 BMBM 个输出元素。

9.5.3 形象类比

想象你在做一道大菜:

  • Naive:每次只从冰箱拿一个食材,用完再回去拿。
  • Tiling:一次性从冰箱拿出一小篮子食材,在灶台上用完,再回去拿下一篮。

灶台(shared memory)比冰箱(HBM)小很多,但速度快得多。

import matplotlib.pyplot as plt
import matplotlib.patches as patches


def draw_tiling_concept():
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: naive access pattern
    ax = axes[0]
    ax.set_title("Naive MatMul: each thread reads full row/column", fontsize=12, fontweight="bold")
    for r in range(4):
        for c in range(4):
            rect = patches.Rectangle((c, r), 0.9, 0.9, facecolor="#e3f2fd", edgecolor="#1976d2")
            ax.add_patch(rect)
            ax.text(c + 0.45, r + 0.45, f"C_{{{r},{c}}}", ha="center", va="center", fontsize=8)

    # Highlight row of A and col of B for C[1,2]
    ax.annotate("", xy=(4.2, 1.5), xytext=(0, 1.5), arrowprops=dict(arrowstyle="->", color="#d32f2f", lw=2))
    ax.text(4.5, 1.5, "A row i", color="#d32f2f", va="center")
    ax.annotate("", xy=(2.5, 4.2), xytext=(2.5, 0), arrowprops=dict(arrowstyle="->", color="#388e3c", lw=2))
    ax.text(2.5, 4.5, "B col j", color="#388e3c", ha="center")

    ax.set_xlim(-0.5, 5.5)
    ax.set_ylim(-0.5, 5)
    ax.set_aspect("equal")
    ax.axis("off")

    # Right: tiled access pattern
    ax = axes[1]
    ax.set_title("Tiled MatMul: load a tile once, reuse for multiple outputs", fontsize=12, fontweight="bold")
    tile_colors = [["#c5e1a5", "#c5e1a5", "#b39ddb", "#b39ddb"],
                   ["#c5e1a5", "#c5e1a5", "#b39ddb", "#b39ddb"],
                   ["#ffcc80", "#ffcc80", "#ef9a9a", "#ef9a9a"],
                   ["#ffcc80", "#ffcc80", "#ef9a9a", "#ef9a9a"]]
    for r in range(4):
        for c in range(4):
            rect = patches.Rectangle((c, r), 0.9, 0.9, facecolor=tile_colors[r][c], edgecolor="#333")
            ax.add_patch(rect)
            ax.text(c + 0.45, r + 0.45, f"C_{{{r},{c}}}", ha="center", va="center", fontsize=8)

    ax.text(2, 4.5, "Each 2x2 tile of A is loaded once and reused", ha="center", fontsize=9)
    ax.set_xlim(-0.5, 4.5)
    ax.set_ylim(-0.5, 5)
    ax.set_aspect("equal")
    ax.axis("off")

    plt.tight_layout()
    plt.show()


draw_tiling_concept()

9.5.4 练习:补全一个 Triton Tiled MatMul Kernel

下面的代码是一个骨架,核心部分(加载 A/B tile、做 dot product)留空了。

你需要补全:

  1. 计算当前 block 负责输出 CC 的哪个 tile。
  2. 沿 KK 维循环时,加载 AABB 的 tile。
  3. tl.dot 计算 partial sum 并累加。

提示:

  • BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K 是编译时常量。
  • tl.load(ptrs, mask=..., other=0.0) 可以处理边界。
  • tl.dot(a, b, acc) 对 FP16/BF16 支持很好。
import triton
import triton.language as tl


@triton.jit
def tiled_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    每个 block 负责计算 C 的一个 (BLOCK_SIZE_M, BLOCK_SIZE_N) tile。
    """
    # TODO 1: 获取当前 block 在 grid 中的 (m, n) 坐标
    pid = tl.program_id(0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n  # TODO: 改成正确计算
    pid_n = pid % num_pid_n   # TODO: 改成正确计算

    # 当前 block 负责的 C tile 的左上角偏移
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # 计算 A tile 和 B tile 的指针
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    # 初始化 accumulator
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # 沿 K 维循环
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # TODO 2: 加载 A tile 和 B tile,注意边界 mask
        a = tl.load(a_ptrs, mask=offs_m[:, None] < M and offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K and offs_n[None, :] < N, other=0.0)

        # TODO 3: 用 tl.dot 计算 partial sum 并累加到 acc
        acc += tl.dot(a, b, acc)

        # 指针沿 K 维滑动到下一个 tile
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # 写回 C
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)


def tiled_matmul(a: torch.Tensor, b: torch.Tensor, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32):
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
    tiled_matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )
    return c


# 测试正确性
if torch.cuda.is_available():
    M, N, K = 256, 256, 256
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    c_ref = a @ b
    c_tri = tiled_matmul(a, b, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32)
    print("Correctness:", torch.allclose(c_tri, c_ref, atol=1e-2, rtol=1e-2))
    print("Max diff:", (c_tri - c_ref).abs().max().item())
else:
    print("CUDA not available.")

9.5.5 练习 1:改变 BLOCK_SIZE,观察性能变化

不同的 BLOCK_SIZE 会影响:

  • shared memory 使用量
  • occupancy
  • L2 cache 命中率
  • 边界检查 overhead

尝试几组参数,记录运行时间。

if torch.cuda.is_available():
    M, N, K = 1024, 1024, 1024
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    c_ref = a @ b

    configs = [
        (16, 16, 16),
        (32, 32, 32),
        (64, 64, 64),
        (64, 64, 32),
        (128, 128, 32),
    ]

    results = []
    for BM, BN, BK in configs:
        try:
            torch.cuda.empty_cache()
            c_tri = tiled_matmul(a, b, BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK)
            correct = torch.allclose(c_tri, c_ref, atol=1e-2, rtol=1e-2)
            t = benchmark(f"tiled_matmul({BM}x{BN}x{BK})", lambda: tiled_matmul(a, b, BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK), num_warmups=2, num_trials=5)
            results.append((BM, BN, BK, t, correct))
            print(f"{BM}x{BN}x{BK}: {t:.3f} ms, correct={correct}")
        except Exception as e:
            print(f"{BM}x{BN}x{BK} failed: {e}")

    # PyTorch baseline
    t_ref = benchmark("pytorch_matmul", lambda: a @ b, num_warmups=2, num_trials=5)
    print(f"\nPyTorch: {t_ref:.3f} ms")

9.5.6 练习 2:分析 Shared Memory 用量与 Occupancy

一个 block 的 shared memory 用量约为:

shmem_per_block = BLOCK_SIZE_M * BLOCK_SIZE_K * sizeof(dtype) + BLOCK_SIZE_K * BLOCK_SIZE_N * sizeof(dtype)

对于 FP16(2 bytes):

  • 32x32x32: 32322 + 32322 = 4 KB
  • 64x64x64: 64642 + 64642 = 16 KB
  • 128x128x32: 128322 + 321282 = 16 KB

RTX 5060 每个 SM 的 shared memory 大约 100 KB。据此估算:

  • 每个 SM 最多能同时驻留多少个 64x64x64 的 block?
  • 如果 shared memory 变成瓶颈,occupancy 会受到什么影响?
def shmem_for_matmul(BM, BN, BK, dtype_bytes=2):
    """估算一个 block 的 shared memory 用量(近似)。"""
    a_tile = BM * BK * dtype_bytes
    b_tile = BK * BN * dtype_bytes
    return a_tile + b_tile


print("Shared memory usage per block:")
for BM, BN, BK in [(32, 32, 32), (64, 64, 64), (128, 128, 32), (128, 128, 64)]:
    shmem = shmem_for_matmul(BM, BN, BK)
    print(f"  {BM}x{BN}x{BK}: {shmem / 1024:.1f} KB")

# 假设 SM 有 100 KB shared memory
sm_smem_kb = 100
print(f"\nWith {sm_smem_kb} KB shared memory per SM:")
for BM, BN, BK in [(64, 64, 64), (128, 128, 32), (128, 128, 64)]:
    shmem_kb = shmem_for_matmul(BM, BN, BK) / 1024
    max_blocks_by_smem = sm_smem_kb // shmem_kb
    print(f"  {BM}x{BN}x{BK}: max {max_blocks_by_smem} blocks per SM (by shared memory)")

9.5.7 练习 3:对比 Tiled vs PyTorch MatMul

上面的 tiled kernel 只是一个教学实现,没有做很多高级优化(如:

  • 向量加载/存储
  • 多级 tiling(register-level + shared-level)
  • L2 cache 友好的 pid 分组
  • warp-level 优化

所以通常比 PyTorch/CUTLASS 慢。这很正常,重点是理解 tiling 的思想。

思考题:

  1. 为什么增大 BLOCK_SIZE 不总是加速?
  2. 为什么 128x128x32 可能比 64x64x64 快或慢?
  3. 如果要继续优化这个 kernel,你会先改哪里?
if torch.cuda.is_available():
    sizes = [512, 1024, 2048]
    for size in sizes:
        a = torch.randn(size, size, device="cuda", dtype=torch.float16)
        b = torch.randn(size, size, device="cuda", dtype=torch.float16)
        t_ref = benchmark(f"pytorch_matmul({size})", lambda: a @ b, num_warmups=2, num_trials=5)
        t_tri = benchmark(f"tiled_matmul({size}, 64x64x32)",
                          lambda: tiled_matmul(a, b, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32),
                          num_warmups=2, num_trials=5)
        print(f"size={size}: PyTorch={t_ref:.2f} ms, Tiled={t_tri:.2f} ms, ratio={t_tri/t_ref:.2f}x")
else:
    print("CUDA not available.")

9.5.8 小结

分块(tiling)是 GPU kernel 优化的核心技巧:

  1. 问题:naive matmul 重复读取 HBM,memory-bound。
  2. 思想:把数据切成小块加载到 shared memory,复用多次。
  3. 关键参数:BLOCK_SIZE_M/N/K 影响 shared memory、occupancy、边界 overhead。
  4. 现代框架:PyTorch/CUTLASS 的 matmul 本质上就是极其优化的 tiled kernel,通常还有多级 tiling 和 Tensor Core。

下一步:可以尝试把上面的 tiled matmul 改成用 Tensor Core(FP16 下 tl.dot 已经会尝试用 Tensor Core),或者加入 L2 cache 友好的 pid 分组。

Part 10:Nsight Systems / Nsight Compute 实操

10.1 Nsight Systems (nsys)

nsys profile 记录整个应用的 CPU + GPU timeline,适合看:

  • 哪些 kernel 占了最多时间。
  • CPU 和 GPU 之间有没有空闲间隙(launch latency)。
  • CUDA API 调用顺序。

基本命令:

nsys profile --stats=true -o output_name python your_script.py

10.2 Nsight Compute (ncu)

ncu 分析单个 kernel 的微观指标:

  • sm__warps_active.avg.pct_of_peak_sustained_elapsed:SM 利用率 / occupancy。
  • dram__bytes.sum.per_second:显存带宽利用率。
  • smsp__sass_thread_inst_executed_op_fadd_pred_on.sum:实际执行的 FLOPs。

10.3 实践:分析 GeLU


# 生成一个可独立运行的 profile 脚本
profile_script = r'''
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch

def manual_gelu(x):
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096 * 4096, device=device)
    for _ in range(3):
        y = manual_gelu(x)
    if device == "cuda":
        torch.cuda.synchronize()

if __name__ == "__main__":
    main()
'''
os.makedirs("scripts", exist_ok=True)
with open("scripts/profile_gelu.py", "w") as f:
    f.write(profile_script)
print("scripts/profile_gelu.py written.")
# 在 Notebook 中运行 nsys profile(注意:会花一些时间)
import subprocess

if torch.cuda.is_available():
    cmd = [
        "nsys", "profile", "--stats=true", "--no-wait",
        "-o", "var/nsys_gelu",
        ".venv/bin/python", "scripts/profile_gelu.py",
    ]
    print("Running:", " ".join(cmd))
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    print(result.stdout[-3000:])
    print(result.stderr[-1000:])
else:
    print("CUDA not available, skip nsys profile.")

10.4 如何读 nsys 报告

nsys 输出通常包含:

  • CUDA API Statistics:每个 CUDA API 调用的次数和总时间。
  • CUDA Kernel Statistics:每个 kernel 的执行次数、平均/最大/总时间。
  • Memory Operation Statistics:HtoD/DtoH 传输次数。

重点关注:

  1. Time(%) 最高的 kernel:这是你的瓶颈。
  2. CPU 是否在等 GPU:看 timeline 里有没有大片空白。
  3. kernel launch overhead:小 kernel 太多会导致 launch 开销占比高。

生成的 .nsys-rep 可以用 Nsight Systems GUI 打开:

nsys-ui var/nsys_gelu.nsys-rep

# 用 ncu 分析单个 kernel(需要知道 kernel 名称,可用 nsys 或 torch.profiler 获取)
if torch.cuda.is_available():
    cmd = [
        "ncu",
        "--kernel-name", "regex:gelsu",  # GeLU backward 名称示例
        "--metrics", "sm__warps_active.avg.pct_of_peak_sustained_elapsed,dram__bytes.sum.per_second",
        ".venv/bin/python", "scripts/profile_gelu.py",
    ]
    print("Example ncu command:")
    print(" ".join(cmd))
    print("\\nNote: run this in terminal if you want actual metrics.")

10.5 用 nvtx 标注代码区间

nvtx 可以在代码里打标记,nsys timeline 上就会显示这些区间,方便定位。


import nvtx

if torch.cuda.is_available():
    with nvtx.annotate("gelu_loop", color="green"):
        for _ in range(5):
            _ = manual_gelu(x)
    torch.cuda.synchronize()
    print("nvtx annotated region executed.")

Part 11:总结与练习

11.1 核心原则

  1. Measure first:先 benchmark / profile,再优化。
  2. Minimize reads/writes:kernel fusion 和 tiling 是核心手段。
  3. Know your bottleneck:compute-bound 提 FLOPs,memory-bound 提访存效率。
  4. Use the right tool
    • 日常开发:torch.compile + PyTorch fused ops。
    • 自定义融合:Triton。
    • 极致优化或特殊硬件:CUDA。
    • 性能分析:nsys + ncu + torch.profiler

11.2 自测练习

练习 1:手写 ReLU 的 CUDA 和 Triton kernel

要求:

  • CUDA:relu.cu,输入输出都是 float32。
  • Triton:triton_relu_kernel,支持任意长度向量。
  • torch.relu 对比正确性和速度。

练习 2:手写 row-wise sum 的 Triton kernel

输入 shape (M, N),输出 shape (M,),对每行求和。

提示:和 softmax 类似,但不需要 max/exp。

练习 3:用 nsys 分析训练循环

对下面这个训练循环跑 nsys profile

model = MLP(512, 8).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(128, 512, device="cuda")
for _ in range(20):
    opt.zero_grad()
    loss = model(x).mean()
    loss.backward()
    opt.step()

回答:

  • 时间占比最高的前 3 个 kernel 是什么?
  • 有没有 kernel launch gap?
  • 优化器 step 是否占大量时间?

11.3 进阶阅读


附录:交互式 GPU 架构可视化

本讲还包含一个交互式的 GPU 架构可视化页面,覆盖 SM、Warp Scheduler、SIMT、Occupancy 等主题:

也可以下载原始 Notebook 运行:

END

Comments