{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d1d673f7",
   "metadata": {},
   "source": [
    "\n",
    "# CS336 Lecture 6 学习讲义：GPU 性能、Kernel 手写与 Nsight 分析\n",
    "\n",
    "> 副标题：从 PyTorch 算子到底层 CUDA / Triton Kernel，再到 NVIDIA Nsight 性能分析\n",
    ">\n",
    "> 适用对象：希望入门 AI Infra、能独立手写基础 CUDA/Triton kernel 的同学\n",
    ">\n",
    "> 依据材料：[CS336 Spring 2025 Lecture 6](https://cs336.stanford.edu/spring2025-lectures/?trace=var/traces/lecture_06.json)\n",
    "\n",
    "---\n",
    "\n",
    "## 学完这节课你能做什么\n",
    "\n",
    "1. **看懂 GPU 执行模型**：SM、thread/block/grid、shared memory、arithmetic intensity。\n",
    "2. **会写 benchmark / profile**：用 `torch.profiler` 和 `time.time()` 做可复现的测速。\n",
    "3. **能手写 CUDA kernel**：从 C++ 源码到 `torch.utils.cpp_extension.load_inline` 加载。\n",
    "4. **能手写 Triton kernel**：element-wise（GeLU）和 row-wise reduction（softmax）。\n",
    "5. **会用 `torch.compile` 自动融合算子**。\n",
    "6. **会用 Nsight Systems (`nsys`) 和 Nsight Compute (`ncu`) 分析性能瓶颈**。\n",
    "\n",
    "---\n",
    "\n",
    "## 运行环境\n",
    "\n",
    "本 Notebook 使用 `Note/lec6/` 下的独立 uv 环境，已安装：\n",
    "\n",
    "- Python 3.11\n",
    "- PyTorch nightly (cu128，支持 RTX 5060 Blackwell)\n",
    "- Triton 3.7\n",
    "- JupyterLab、matplotlib、numpy、nvtx\n",
    "- 系统自带 `nsys` / `ncu`\n",
    "\n",
    "如果当前没有 GPU，所有 CUDA 相关 cell 会自动跳过，Triton 相关代码需要在 GPU 环境下运行。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6715c6e5",
   "metadata": {},
   "source": [
    "\n",
    "## Part 0：环境检查\n",
    "\n",
    "先确认 PyTorch、CUDA、Triton、Nsight 工具是否就绪。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "549da693",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import sys\n",
    "import subprocess\n",
    "\n",
    "import torch\n",
    "import triton\n",
    "\n",
    "print(f\"Python: {sys.version.split()[0]}\")\n",
    "print(f\"PyTorch: {torch.__version__}\")\n",
    "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"CUDA version: {torch.version.cuda}\")\n",
    "    props = torch.cuda.get_device_properties(0)\n",
    "    print(f\"GPU: {props.name}\")\n",
    "    print(f\"SM count: {props.multi_processor_count}\")\n",
    "    print(f\"Total memory: {props.total_memory / 1e9:.2f} GB\")\n",
    "    print(f\"Compute capability: {props.major}.{props.minor}\")\n",
    "\n",
    "print(f\"Triton: {triton.__version__}\")\n",
    "\n",
    "for cmd, label in [(\"nsys\", \"Nsight Systems\"), (\"ncu\", \"Nsight Compute\")]:\n",
    "    try:\n",
    "        out = subprocess.run([cmd, \"--version\"], capture_output=True, text=True, check=True).stdout.strip()\n",
    "        print(f\"{label}: {out.split(chr(10))[0]}\")\n",
    "    except Exception as e:\n",
    "        print(f\"{label}: not available ({e})\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c24d8232",
   "metadata": {},
   "source": [
    "\n",
    "## Part 1：GPU 硬件与执行模型\n",
    "\n",
    "### 1.1 硬件层次\n",
    "\n",
    "| 层级 | A100 示例 | 作用 |\n",
    "|------|----------|------|\n",
    "| Streaming Multiprocessors (SMs) | 108 个 | 实际做计算的单位 |\n",
    "| DRAM (HBM) | 80 GB | 大容量显存，带宽高但延迟高 |\n",
    "| L2 Cache | 40 MB | 全局共享，比 DRAM 快 |\n",
    "| L1 Cache / Shared Memory | 192 KB / SM | 每个 SM 内，可被同 block 线程共享 |\n",
    "\n",
    "### 1.2 执行层次\n",
    "\n",
    "GPU 编程的核心抽象是 **\"对大量索引并行执行同一个函数 f(i)\"**。\n",
    "\n",
    "- **Thread**：执行 `f(i)` 的最小单元。\n",
    "- **Thread Block（CTA）**：一组线程，运行在同一个 SM 上，内部可共享 shared memory 并同步。\n",
    "- **Grid**：所有 thread block 的集合。\n",
    "\n",
    "### 1.3 关键概念\n",
    "\n",
    "- **Wave quantization**：如果 grid 里的 block 数不能整除 SM 数，最后一 wave 会空闲部分 SM。\n",
    "- **Arithmetic intensity**：`FLOPs / byte`，越高越偏向 compute-bound，越低越偏向 memory-bound。\n",
    "- **一般规律**：矩阵乘法是 compute-bound，大部分其他操作（激活、归一化、softmax）是 memory-bound。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06dffc56",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def print_gpu_specs():\n",
    "    if not torch.cuda.is_available():\n",
    "        print(\"No CUDA available\")\n",
    "        return\n",
    "    num = torch.cuda.device_count()\n",
    "    print(f\"Devices: {num}\")\n",
    "    for i in range(num):\n",
    "        p = torch.cuda.get_device_properties(i)\n",
    "        print(f\"  [{i}] {p.name}\")\n",
    "        print(f\"       SMs: {p.multi_processor_count}, Memory: {p.total_memory/1e9:.2f} GB\")\n",
    "        print(f\"       Compute capability: {p.major}.{p.minor}\")\n",
    "        print(f\"       L2 cache size: {p.L2_cache_size / 1024**2:.1f} MB\")\n",
    "\n",
    "print_gpu_specs()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "231fe2bc",
   "metadata": {},
   "source": [
    "\n",
    "### 1.4 思考题\n",
    "\n",
    "1. 为什么矩阵乘法是 compute-bound，而 GeLU 是 memory-bound？\n",
    "2. 对于一个 element-wise 操作，增加 block 数是否能无限加速？瓶颈在哪里？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db5845a7",
   "metadata": {},
   "source": [
    "\n",
    "## Part 2：Benchmarking —— 先会测速，再谈优化\n",
    "\n",
    "优化的第一步永远是 **测量**。没有测量，所有优化都是盲猜。\n",
    "\n",
    "一个好的 benchmark 函数应该：\n",
    "1. **Warmup**：先跑几次，排除编译、缓存冷启动影响。\n",
    "2. **Synchronize**：CUDA 是异步的，必须 `torch.cuda.synchronize()` 才能拿到真实 wall-clock 时间。\n",
    "3. **多次试验**：取平均，减少噪声。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50867150",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from src.helpers import benchmark\n",
    "\n",
    "# 简单示例：sleep 50ms\n",
    "import time\n",
    "benchmark(\"sleep\", lambda: time.sleep(0.05), num_warmups=0, num_trials=3)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc57b0f3",
   "metadata": {},
   "source": [
    "\n",
    "### 2.1 用 MLP 做 scaling 实验\n",
    "\n",
    "下面定义一个简单 MLP：`Linear -> GeLU -> Linear -> GeLU -> ... -> Linear -> GeLU`。\n",
    "\n",
    "我们通过改变 step / layers / batch / dim 来观察时间如何变化。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea55e0b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    # Simple MLP: linear -> GeLU -> linear -> GeLU -> ... -> linear -> GeLU\n",
    "    def __init__(self, dim: int, num_layers: int):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "            x = torch.nn.functional.gelu(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def run_mlp(dim: int, num_layers: int, batch_size: int, num_steps: int):\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    model = MLP(dim, num_layers).to(device)\n",
    "    x = torch.randn(batch_size, dim, device=device)\n",
    "\n",
    "    def fn():\n",
    "        for _ in range(num_steps):\n",
    "            y = model(x).mean()\n",
    "            y.backward()\n",
    "    return fn\n",
    "\n",
    "\n",
    "# baseline\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Running on {device}\")\n",
    "baseline = benchmark(\"mlp_baseline\", run_mlp(dim=256, num_layers=4, batch_size=256, num_steps=2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e6e556",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# scaling experiments\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def scaling_experiment(name, vary_fn, scales):\n",
    "    # vary_fn(scale) -> (label, fn)\n",
    "    results = []\n",
    "    for s in scales:\n",
    "        label, fn = vary_fn(s)\n",
    "        t = benchmark(label, fn, num_warmups=2, num_trials=3)\n",
    "        results.append((s, t))\n",
    "    return results\n",
    "\n",
    "\n",
    "scales = [1, 2, 3, 4, 5]\n",
    "\n",
    "# 1) scale num_steps\n",
    "step_results = scaling_experiment(\n",
    "    \"steps\",\n",
    "    lambda s: (f\"mlp({s}x steps)\", run_mlp(256, 4, 256, 2 * s)),\n",
    "    scales,\n",
    ")\n",
    "\n",
    "# 2) scale num_layers\n",
    "layer_results = scaling_experiment(\n",
    "    \"layers\",\n",
    "    lambda s: (f\"mlp({s}x layers)\", run_mlp(256, 4 * s, 256, 2)),\n",
    "    scales,\n",
    ")\n",
    "\n",
    "# 3) scale batch_size\n",
    "batch_results = scaling_experiment(\n",
    "    \"batch\",\n",
    "    lambda s: (f\"mlp({s}x batch)\", run_mlp(256, 4, 256 * s, 2)),\n",
    "    scales,\n",
    ")\n",
    "\n",
    "# 4) scale dim\n",
    "dim_results = scaling_experiment(\n",
    "    \"dim\",\n",
    "    lambda s: (f\"mlp({s}x dim)\", run_mlp(256 * s, 4, 256, 2)),\n",
    "    scales,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6231fbca",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
    "for ax, (title, data) in zip(axes.flat, [\n",
    "    (\"Scale steps\", step_results),\n",
    "    (\"Scale layers\", layer_results),\n",
    "    (\"Scale batch\", batch_results),\n",
    "    (\"Scale dim\", dim_results),\n",
    "]):\n",
    "    xs, ys = zip(*data)\n",
    "    ax.plot(xs, ys, marker='o')\n",
    "    ax.set_xlabel(\"Scale factor\")\n",
    "    ax.set_ylabel(\"Time (ms)\")\n",
    "    ax.set_title(title)\n",
    "    ax.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5556aa5b",
   "metadata": {},
   "source": [
    "\n",
    "### 2.2 观察与结论\n",
    "\n",
    "- 增加 `num_steps`：时间近似线性增长，因为每次都要完整跑前向+反向。\n",
    "- 增加 `num_layers`：线性增长，但常数比 steps 大（layer 越多，需要存储的激活和梯度越多）。\n",
    "- 增加 `batch_size`：通常次线性或接近线性，取决于是否达到 compute-bound。\n",
    "- 增加 `dim`：往往会超线性增长，因为计算量 O(d³) 级别（矩阵乘法主导）。\n",
    "\n",
    "> **实践建议**：每次改动模型或训练代码后，都跑一组 scaling 实验，确认性能变化符合预期。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2633b714",
   "metadata": {},
   "source": [
    "\n",
    "## Part 3：Profiling —— 知道时间花在哪里\n",
    "\n",
    "Benchmarking 只看总时间；Profiling 告诉我们 **总时间 = 哪些 kernel / 哪些 Python 调用** 组成的。\n",
    "\n",
    "PyTorch 内置了 `torch.profiler`，可以输出每个 op 的 CPU/CUDA 时间。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "483f3642",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from src.helpers import profile_table\n",
    "\n",
    "\n",
    "def add_fn():\n",
    "    a = torch.randn(1024, 1024, device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    b = torch.randn(1024, 1024, device=a.device)\n",
    "    return a + b\n",
    "\n",
    "\n",
    "profile_table(\"add\", add_fn)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "336feccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def matmul_fn():\n",
    "    a = torch.randn(1024, 1024, device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    b = torch.randn(1024, 1024, device=a.device)\n",
    "    return a @ b\n",
    "\n",
    "\n",
    "profile_table(\"matmul\", matmul_fn)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dbe49e8",
   "metadata": {},
   "source": [
    "\n",
    "### 3.1 解读 kernel 名称\n",
    "\n",
    "例如 `cutlass_80_simt_sgemm_256x128_8x4_nn_align1`：\n",
    "\n",
    "- `cutlass`：NVIDIA 的 CUDA linear algebra 库。\n",
    "- `80`：对应 SM 8.0（Ampere）。\n",
    "- `simt_sgemm`：单精度通用矩阵乘。\n",
    "- `256x128`：tile 大小。\n",
    "- `nn`：A 和 B 都不转置（non-transpose）。\n",
    "\n",
    "通过 kernel 名称，可以反推出 PyTorch 实际调了哪个底层实现。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8363b29c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 用 with_stack=True 可以导出火焰图\n",
    "if torch.cuda.is_available():\n",
    "    profile_table(\n",
    "        \"mlp\",\n",
    "        run_mlp(dim=2048, num_layers=64, batch_size=1024, num_steps=2),\n",
    "        with_stack=True,\n",
    "    )\n",
    "else:\n",
    "    print(\"CUDA not available, skip flame graph profiling.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6404fc9",
   "metadata": {},
   "source": [
    "\n",
    "### 3.2 火焰图导出\n",
    "\n",
    "`torch.profiler` 支持 `export_stacks(path, metric)` 导出文本栈，可以用其他工具转成 SVG。\n",
    "\n",
    "```python\n",
    "prof.export_stacks(\"var/stacks_mlp.txt\", \"self_cuda_time_total\")\n",
    "```\n",
    "\n",
    "导出后可用 `stackvis` 或 FlameGraph 工具渲染：\n",
    "\n",
    "```bash\n",
    "# 需要安装 FlameGraph 工具\n",
    "git clone https://github.com/brendangregg/FlameGraph.git\n",
    "./FlameGraph/flamegraph.pl var/stacks_mlp.txt > var/stacks_mlp.svg\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1338b04f",
   "metadata": {},
   "source": [
    "\n",
    "## Part 4：Kernel Fusion —— 减少读写就是加速\n",
    "\n",
    "### 4.1 Warehouse / Factory 类比\n",
    "\n",
    "- **DRAM = 仓库**：大、慢、便宜。\n",
    "- **SRAM = 工厂**：小、快、贵。\n",
    "\n",
    "每个 PyTorch 算子都要：从仓库读原料 → 在工厂加工 → 写回仓库。如果有多个算子，就会反复搬运。\n",
    "\n",
    "**Kernel fusion**：把多个算子合并成一个 kernel，原料只读一次、结果只写一次。\n",
    "\n",
    "### 4.2 GeLU 案例\n",
    "\n",
    "PyTorch 的 `F.gelu` 是融合实现；我们可以手写一个非融合版本对比。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1faf1ddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def manual_gelu(x: torch.Tensor) -> torch.Tensor:\n",
    "    # 手写 GeLU（非融合，会产生多个 kernel）\n",
    "    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))\n",
    "\n",
    "\n",
    "def pytorch_gelu(x: torch.Tensor) -> torch.Tensor:\n",
    "    # PyTorch 融合 GeLU\n",
    "    return torch.nn.functional.gelu(x, approximate=\"tanh\")\n",
    "\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "x = torch.randn(1024, 1024, device=device)\n",
    "print(\"Correctness:\", torch.allclose(pytorch_gelu(x), manual_gelu(x), atol=1e-5))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9342d3e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "benchmark(\"manual_gelu\", lambda: manual_gelu(x))\n",
    "benchmark(\"pytorch_gelu\", lambda: pytorch_gelu(x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34fe5779",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "profile_table(\"manual_gelu\", lambda: manual_gelu(x))\n",
    "profile_table(\"pytorch_gelu\", lambda: pytorch_gelu(x))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85894123",
   "metadata": {},
   "source": [
    "\n",
    "### 4.3 结论\n",
    "\n",
    "- `manual_gelu` 会触发多个 CUDA kernel：乘、加、tanh、乘...\n",
    "- `pytorch_gelu` 通常只触发一个融合 kernel。\n",
    "- 在 memory-bound 场景下，融合版明显更快。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01e01c26",
   "metadata": {},
   "source": [
    "\n",
    "## Part 5：手写 CUDA Kernel\n",
    "\n",
    "CUDA 是 C/C++ 的扩展，核心思想是：**写一个线程要执行的函数，然后启动 N 个线程**。\n",
    "\n",
    "### 5.1 线程索引\n",
    "\n",
    "- `blockIdx.x`：当前线程所在的 block 编号。\n",
    "- `threadIdx.x`：当前线程在 block 内的编号。\n",
    "- `blockDim.x`：每个 block 的线程数。\n",
    "- 全局线程 ID = `blockIdx.x * blockDim.x + threadIdx.x`。\n",
    "\n",
    "### 5.2 用 `load_inline` 编译 CUDA\n",
    "\n",
    "`torch.utils.cpp_extension.load_inline` 可以让我们在 Python 里写一段 CUDA C++ 代码，即时编译成一个 Python 模块。\n",
    "\n",
    "我们的 `src/gelu.cu` 已经写好了，直接加载。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df103674",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    from torch.utils.cpp_extension import load_inline\n",
    "\n",
    "    cuda_mod = load_inline(\n",
    "        name=\"inline_gelu\",\n",
    "        cuda_sources=[open(\"src/gelu.cu\").read()],\n",
    "        cpp_sources=[\"torch::Tensor gelu(torch::Tensor x);\"],\n",
    "        functions=[\"gelu\"],\n",
    "        extra_cflags=[\"-O2\"],\n",
    "        verbose=False,\n",
    "    )\n",
    "    cuda_gelu = cuda_mod.gelu\n",
    "    print(\"CUDA gelu module loaded.\")\n",
    "else:\n",
    "    print(\"CUDA not available, skip CUDA kernel loading.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b39c3088",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    x_f = x.float()\n",
    "    y_cuda = cuda_gelu(x_f)\n",
    "    y_ref = pytorch_gelu(x_f)\n",
    "    print(\"CUDA gelu correct:\", torch.allclose(y_cuda, y_ref, atol=1e-5))\n",
    "    print(\"Max diff:\", (y_cuda - y_ref).abs().max().item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e898550",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    benchmark(\"manual_gelu\", lambda: manual_gelu(x))\n",
    "    benchmark(\"pytorch_gelu\", lambda: pytorch_gelu(x))\n",
    "    benchmark(\"cuda_gelu\", lambda: cuda_gelu(x.float()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9ca7b7a",
   "metadata": {},
   "source": [
    "\n",
    "### 5.3 `gelu.cu` 逐行解析\n",
    "\n",
    "```cpp\n",
    "__global__ void gelu_kernel(const float* __restrict__ x,\n",
    "                            float* __restrict__ y,\n",
    "                            int n) {\n",
    "    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // 全局线程 ID\n",
    "    int stride = blockDim.x * gridDim.x;              // 所有线程一次能跨多少元素\n",
    "    for (int i = idx; i < n; i += stride) {           // grid-stride loop\n",
    "        float xi = x[i];\n",
    "        float a = 0.79788456f * (xi + 0.044715f * xi * xi * xi);\n",
    "        y[i] = 0.5f * xi * (1.0f + tanhf(a));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "- `__global__`：CPU 调用、GPU 执行的函数。\n",
    "- `__restrict__`：告诉编译器指针不别名，便于优化。\n",
    "- grid-stride loop：即使 block 数不足以覆盖所有元素，每个线程也能处理多个位置。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ca40105",
   "metadata": {},
   "source": [
    "\n",
    "## Part 6：手写 Triton Kernel\n",
    "\n",
    "Triton 由 OpenAI 开发，目标是用 Python 写高性能 GPU kernel，同时隐藏线程级细节。\n",
    "\n",
    "| 能力 | CUDA | Triton |\n",
    "|------|------|--------|\n",
    "| Memory coalescing | 手动 | 自动 |\n",
    "| Shared memory 管理 | 手动 | 自动 |\n",
    "| SM 内调度 | 手动 | 自动 |\n",
    "| SM 间调度（grid/block） | 手动 | 手动 |\n",
    "\n",
    "Triton 让我们专注于 **\"每个 block 做什么\"**，而不是每个 thread 做什么。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "911ba789",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import triton\n",
    "import triton.language as tl\n",
    "\n",
    "\n",
    "@triton.jit\n",
    "def triton_gelu_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):\n",
    "    pid = tl.program_id(axis=0)               # block ID\n",
    "    block_start = pid * BLOCK_SIZE\n",
    "    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # 这个 block 负责的元素索引\n",
    "    mask = offsets < n                        # 边界 mask\n",
    "\n",
    "    x = tl.load(x_ptr + offsets, mask=mask)   # 从 global memory 读取\n",
    "\n",
    "    # GeLU tanh 近似\n",
    "    a = 0.79788456 * (x + 0.044715 * x * x * x)\n",
    "    exp = tl.exp(2 * a)\n",
    "    tanh = (exp - 1) / (exp + 1)\n",
    "    y = 0.5 * x * (1 + tanh)\n",
    "\n",
    "    tl.store(y_ptr + offsets, y, mask=mask)   # 写回 global memory\n",
    "\n",
    "\n",
    "def triton_gelu(x: torch.Tensor) -> torch.Tensor:\n",
    "    y = torch.empty_like(x)\n",
    "    n = x.numel()\n",
    "    BLOCK_SIZE = 1024\n",
    "    grid = (triton.cdiv(n, BLOCK_SIZE),)\n",
    "    triton_gelu_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)\n",
    "    return y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3623ed14",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    y_triton = triton_gelu(x)\n",
    "    print(\"Triton gelu correct:\", torch.allclose(y_triton, pytorch_gelu(x), atol=1e-5))\n",
    "    print(\"Max diff:\", (y_triton - pytorch_gelu(x)).abs().max().item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "675854c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    benchmark(\"manual_gelu\", lambda: manual_gelu(x))\n",
    "    benchmark(\"pytorch_gelu\", lambda: pytorch_gelu(x))\n",
    "    benchmark(\"cuda_gelu\", lambda: cuda_gelu(x.float()))\n",
    "    benchmark(\"triton_gelu\", lambda: triton_gelu(x))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "695da9b0",
   "metadata": {},
   "source": [
    "\n",
    "### 6.1 看 Triton 生成的 PTX\n",
    "\n",
    "Triton 编译后会生成 PTX（类似 GPU 汇编）。我们可以查看关键指令：\n",
    "\n",
    "- `ld.global.*` / `st.global.*`：读写 global memory（DRAM）。\n",
    "- `%ctaid.x`：block 索引。\n",
    "- `%tid.x`：thread 索引。\n",
    "- thread coarsening：一个 thread 处理多个元素。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2fedc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取 Triton kernel 的 PTX（Triton API 版本间差异较大，这里用环境变量方式导出）\n",
    "import os\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    # 设置环境变量导出编译产物\n",
    "    os.environ[\"TRITON_DUMP_ASM\"] = \"1\"\n",
    "    os.environ[\"TRITON_CACHE_DIR\"] = os.path.abspath(\"var/triton_cache\")\n",
    "    os.makedirs(\"var/triton_cache\", exist_ok=True)\n",
    "\n",
    "    # 触发编译\n",
    "    _ = triton_gelu(x)\n",
    "    print(\"PTX 已导出到 var/triton_cache，可用以下命令查看：\")\n",
    "    print(\"  find var/triton_cache -name '*.ptx' | head -1 | xargs cat | head -80\")\n",
    "else:\n",
    "    print(\"CUDA not available, skip PTX dump.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8190aca7",
   "metadata": {},
   "source": [
    "\n",
    "## Part 7：torch.compile —— 不用手写 kernel 也能融合\n",
    "\n",
    "PyTorch 2.0 引入了 `torch.compile`，可以把 Python 写的多个算子自动融合成更高效的 kernel。\n",
    "\n",
    "对于 element-wise 操作，它通常能接近手写 Triton 的性能。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2149eb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "compiled_gelu = torch.compile(manual_gelu)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print(\"compiled_gelu correct:\", torch.allclose(compiled_gelu(x), pytorch_gelu(x), atol=1e-5))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55ab93ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    benchmark(\"manual_gelu\", lambda: manual_gelu(x))\n",
    "    benchmark(\"pytorch_gelu\", lambda: pytorch_gelu(x))\n",
    "    benchmark(\"cuda_gelu\", lambda: cuda_gelu(x.float()))\n",
    "    benchmark(\"triton_gelu\", lambda: triton_gelu(x))\n",
    "    benchmark(\"compiled_gelu\", lambda: compiled_gelu(x))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "914dc083",
   "metadata": {},
   "source": [
    "\n",
    "### 7.1 结论\n",
    "\n",
    "| 实现 | 难度 | 速度 | 适用场景 |\n",
    "|------|------|------|----------|\n",
    "| PyTorch fused | 最简单 | 最快之一 | 日常首选 |\n",
    "| manual | 简单 | 最慢 | 理解算子组成 |\n",
    "| CUDA | 最难 | 可控 | 需要极致优化 |\n",
    "| Triton | 中等 | 接近 PyTorch | 自定义融合算子 |\n",
    "| torch.compile | 简单 | 接近 Triton | 快速验证自动优化 |\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "347a1108",
   "metadata": {},
   "source": [
    "\n",
    "## Part 8：Triton Softmax —— row-wise reduction\n",
    "\n",
    "Softmax 是典型的 **row-wise reduction** 操作：对矩阵每一行做 `exp(x - max) / sum(exp(...))`。\n",
    "\n",
    "### 8.1 Naive 实现的访存代价\n",
    "\n",
    "```\n",
    "x_max = x.max(dim=1)        # MN reads, M writes\n",
    "x = x - x_max               # MN + M reads, MN writes\n",
    "num = exp(x)                # MN reads, MN writes\n",
    "den = num.sum(dim=1)        # MN reads, M writes\n",
    "y = num / den               # MN reads, MN writes\n",
    "```\n",
    "\n",
    "总共约 **5MN + M reads, 3MN + 2M writes**。理想情况下只需 **MN reads + MN writes**。\n",
    "\n",
    "Triton 可以把整个 row 放进一个 block，一次完成 max、sub、exp、sum、div。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddd69707",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def manual_softmax(x: torch.Tensor) -> torch.Tensor:\n",
    "    x_max = x.max(dim=1, keepdim=True)[0]\n",
    "    x = x - x_max\n",
    "    num = torch.exp(x)\n",
    "    den = num.sum(dim=1, keepdim=True)\n",
    "    return num / den\n",
    "\n",
    "\n",
    "def pytorch_softmax(x: torch.Tensor) -> torch.Tensor:\n",
    "    return torch.nn.functional.softmax(x, dim=-1)\n",
    "\n",
    "\n",
    "@triton.jit\n",
    "def triton_softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n",
    "    row_idx = tl.program_id(0)\n",
    "    col_offsets = tl.arange(0, BLOCK_SIZE)\n",
    "\n",
    "    # 读取一整行（超出 n_cols 的部分用 -inf 填充）\n",
    "    x_row = tl.load(\n",
    "        x_ptr + row_idx * x_row_stride + col_offsets,\n",
    "        mask=col_offsets < n_cols,\n",
    "        other=float(\"-inf\"),\n",
    "    )\n",
    "\n",
    "    # 在线内做 softmax\n",
    "    x_row = x_row - tl.max(x_row, axis=0)\n",
    "    num = tl.exp(x_row)\n",
    "    den = tl.sum(num, axis=0)\n",
    "\n",
    "    tl.store(\n",
    "        y_ptr + row_idx * y_row_stride + col_offsets,\n",
    "        num / den,\n",
    "        mask=col_offsets < n_cols,\n",
    "    )\n",
    "\n",
    "\n",
    "def triton_softmax(x: torch.Tensor) -> torch.Tensor:\n",
    "    M, N = x.shape\n",
    "    y = torch.empty_like(x)\n",
    "    BLOCK_SIZE = triton.next_power_of_2(N)\n",
    "    triton_softmax_kernel[(M,)](\n",
    "        x, y,\n",
    "        x.stride(0), y.stride(0),\n",
    "        N, BLOCK_SIZE=BLOCK_SIZE,\n",
    "    )\n",
    "    return y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f07e00",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    x2 = torch.randn(1024, 1024, device=\"cuda\")\n",
    "    print(\"manual correct:\", torch.allclose(manual_softmax(x2), pytorch_softmax(x2), atol=1e-5))\n",
    "    print(\"triton correct:\", torch.allclose(triton_softmax(x2), pytorch_softmax(x2), atol=1e-5))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d01527b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    benchmark(\"manual_softmax\", lambda: manual_softmax(x2))\n",
    "    benchmark(\"pytorch_softmax\", lambda: pytorch_softmax(x2))\n",
    "    benchmark(\"triton_softmax\", lambda: triton_softmax(x2))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "872eb45b",
   "metadata": {},
   "source": [
    "\n",
    "### 8.2 Softmax 关键技巧\n",
    "\n",
    "1. ** numerical stability**：先减 max 再 exp，避免指数爆炸。\n",
    "2. **一个 block 处理一行**：利用 Triton 自动管理 shared memory，减少多次全局读写。\n",
    "3. **mask 处理变长列**：`other=float(\"-inf\")` 让 pad 位置不影响 max。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77f5b6de",
   "metadata": {},
   "source": [
    "\n",
    "## Part 9：Triton MatMul —— Tiling 与 Shared Memory\n",
    "\n",
    "矩阵乘法 `C = A @ B` 是 GPU 上优化最充分的算法之一。\n",
    "\n",
    "### 9.1 Naive 访存分析\n",
    "\n",
    "- C 的每个元素需要 K 次 A 和 K 次 B 的读取。\n",
    "- 总读取量约 `M * K * N * 2`，写入 `M * N`。\n",
    "- 但计算量也是 `2 * M * N * K` FLOPs，所以 arithmetic intensity 高，属于 compute-bound。\n",
    "\n",
    "### 9.2 Tiling 核心思想\n",
    "\n",
    "把 A 和 B 分成小 block（tile），一次加载到 shared memory：\n",
    "\n",
    "1. 从 DRAM 读 A 的一个 tile 和 B 的一个 tile 到 SRAM。\n",
    "2. 在这个 tile 上做 mini matmul，累加到局部寄存器。\n",
    "3. 重复直到遍历完 K 维。\n",
    "4. 把结果写回 DRAM。\n",
    "\n",
    "这样 A 和 B 的每个元素从 DRAM 读取次数大大减少。\n",
    "\n",
    "### 9.3 L2 Cache 友好的数据布局\n",
    "\n",
    "Triton 官方 tutorial 里比较了 row-major vs grouped ordering：\n",
    "\n",
    "- Row-major：按行优先顺序处理 block，可能导致 B 被重复加载。\n",
    "- Grouped：让相邻 block 共享更多 A/B 数据，减少 L2 miss。\n",
    "\n",
    "参考：[Triton matmul tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3435320",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 先 benchmark PyTorch matmul 作为 baseline\n",
    "if torch.cuda.is_available():\n",
    "    a = torch.randn(1024, 1024, device=\"cuda\")\n",
    "    b = torch.randn(1024, 1024, device=\"cuda\")\n",
    "    benchmark(\"pytorch_matmul\", lambda: a @ b)\n",
    "else:\n",
    "    print(\"CUDA not available, skip matmul benchmark.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f415459a",
   "metadata": {},
   "source": [
    "\n",
    "### 9.4 进阶：Triton matmul 参考实现（可选阅读）\n",
    "\n",
    "下面的代码来自 Triton 官方 tutorial，做了简化。它是理解 tiling 和 shared memory 的最佳示例。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1850d614",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 参考 Triton 官方 matmul tutorial 的简化版\n",
    "@triton.jit\n",
    "def matmul_kernel(\n",
    "    a_ptr, b_ptr, c_ptr,\n",
    "    M, N, K,\n",
    "    stride_am, stride_ak,\n",
    "    stride_bk, stride_bn,\n",
    "    stride_cm, stride_cn,\n",
    "    BLOCK_SIZE_M: tl.constexpr,\n",
    "    BLOCK_SIZE_N: tl.constexpr,\n",
    "    BLOCK_SIZE_K: tl.constexpr,\n",
    "    GROUP_SIZE_M: tl.constexpr,\n",
    "):\n",
    "    pid = tl.program_id(axis=0)\n",
    "    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n",
    "    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n",
    "    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n",
    "    group_id = pid // num_pid_in_group\n",
    "    first_pid_m = group_id * GROUP_SIZE_M\n",
    "    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n",
    "    pid_m = first_pid_m + (pid % group_size_m)\n",
    "    pid_n = (pid % num_pid_in_group) // group_size_m\n",
    "\n",
    "    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n",
    "    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n",
    "    offs_k = tl.arange(0, BLOCK_SIZE_K)\n",
    "\n",
    "    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n",
    "    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n",
    "\n",
    "    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n",
    "    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n",
    "        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n",
    "        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n",
    "        accumulator = tl.dot(a, b, accumulator)\n",
    "        a_ptrs += BLOCK_SIZE_K * stride_ak\n",
    "        b_ptrs += BLOCK_SIZE_K * stride_bk\n",
    "\n",
    "    c = accumulator.to(tl.float16)\n",
    "    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n",
    "    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n",
    "    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n",
    "    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n",
    "    tl.store(c_ptrs, c, mask=c_mask)\n",
    "\n",
    "\n",
    "def triton_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n",
    "    assert a.shape[1] == b.shape[0]\n",
    "    M, K = a.shape\n",
    "    K2, N = b.shape\n",
    "    assert K == K2\n",
    "    c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n",
    "    grid = lambda META: (\n",
    "        triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n",
    "    )\n",
    "    matmul_kernel[grid](\n",
    "        a, b, c,\n",
    "        M, N, K,\n",
    "        a.stride(0), a.stride(1),\n",
    "        b.stride(0), b.stride(1),\n",
    "        c.stride(0), c.stride(1),\n",
    "        BLOCK_SIZE_M=128, BLOCK_SIZE_N=256,\n",
    "        BLOCK_SIZE_K=64, GROUP_SIZE_M=8,\n",
    "    )\n",
    "    return c\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5245a53e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "if torch.cuda.is_available():\n",
    "    a16 = torch.randn(1024, 1024, device=\"cuda\", dtype=torch.float16)\n",
    "    b16 = torch.randn(1024, 1024, device=\"cuda\", dtype=torch.float16)\n",
    "    c_ref = a16 @ b16\n",
    "    c_tri = triton_matmul(a16, b16)\n",
    "    print(\"triton_matmul correct:\", torch.allclose(c_tri, c_ref, atol=1e-2))\n",
    "    benchmark(\"pytorch_matmul\", lambda: a16 @ b16)\n",
    "    benchmark(\"triton_matmul\", lambda: triton_matmul(a16, b16))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3fb83c5",
   "metadata": {},
   "source": [
    "## Part 9.5：分块（Tiling）入门教学与实践\n",
    "\n",
    "前面 Part 9 提到了 matmul 的 tiling 思想，但那是\"看源码学 tiling\"。本节从更基础的角度入手：**为什么需要分块？分块到底省了什么？如何自己动手写一个最简单的 tiled kernel？**\n",
    "\n",
    "### 学习目标\n",
    "\n",
    "1. 理解 naive matmul 的访存代价。\n",
    "2. 理解 tiling 如何减少 DRAM 读取次数。\n",
    "3. 能补全一个填空式 Triton tiled matmul kernel。\n",
    "4. 会分析 tile size 对 shared memory 和 occupancy 的影响。\n",
    "\n",
    "### 9.5.1 为什么需要分块？\n",
    "\n",
    "考虑矩阵乘法 $C = A \\times B$，其中 $A \\in \\mathbb{R}^{M \\times K}$，$B \\in \\mathbb{R}^{K \\times N}$。\n",
    "\n",
    "**Naive 实现**：每个 thread 负责计算 $C$ 的一个元素 $C_{ij}$。\n",
    "\n",
    "- 计算 $C_{ij}$ 需要读取 $A$ 的第 $i$ 行和 $B$ 的第 $j$ 列，共 $2K$ 个元素。\n",
    "- 整个 $C$ 有 $M \\times N$ 个元素，所以总读取量约为 $2 M N K$。\n",
    "\n",
    "问题是：\n",
    "- $A$ 的第 $i$ 行被所有计算 $C_{i*}$ 的 threads 重复读取。\n",
    "- $B$ 的第 $j$ 列被所有计算 $C_{*j}$ 的 threads 重复读取。\n",
    "\n",
    "如果 $K$ 很大，这些重复读取会大量占用 HBM 带宽。\n",
    "\n",
    "### 9.5.2 分块核心思想\n",
    "\n",
    "把矩阵分成大小为 $(BM, BK)$ 和 $(BK, BN)$ 的小块（tile）：\n",
    "\n",
    "1. 从 HBM 读取 $A$ 的一个 tile 到 shared memory。\n",
    "2. 从 HBM 读取 $B$ 的一个 tile 到 shared memory。\n",
    "3. 用 shared memory 里的数据计算一个 partial sum，累加到寄存器。\n",
    "4. 沿 $K$ 维滑动，重复步骤 1-3。\n",
    "5. 最后把累加结果写回 $C$。\n",
    "\n",
    "关键节省：\n",
    "- $A$ 的每个 tile 只从 HBM 读取一次，但服务了 $BN$ 个输出元素。\n",
    "- $B$ 的每个 tile 只从 HBM 读取一次，但服务了 $BM$ 个输出元素。\n",
    "\n",
    "### 9.5.3 形象类比\n",
    "\n",
    "想象你在做一道大菜：\n",
    "- **Naive**：每次只从冰箱拿一个食材，用完再回去拿。\n",
    "- **Tiling**：一次性从冰箱拿出一小篮子食材，在灶台上用完，再回去拿下一篮。\n",
    "\n",
    "灶台（shared memory）比冰箱（HBM）小很多，但速度快得多。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "856701c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "\n",
    "def draw_tiling_concept():\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "    # Left: naive access pattern\n",
    "    ax = axes[0]\n",
    "    ax.set_title(\"Naive MatMul: each thread reads full row/column\", fontsize=12, fontweight=\"bold\")\n",
    "    for r in range(4):\n",
    "        for c in range(4):\n",
    "            rect = patches.Rectangle((c, r), 0.9, 0.9, facecolor=\"#e3f2fd\", edgecolor=\"#1976d2\")\n",
    "            ax.add_patch(rect)\n",
    "            ax.text(c + 0.45, r + 0.45, f\"C_{{{r},{c}}}\", ha=\"center\", va=\"center\", fontsize=8)\n",
    "\n",
    "    # Highlight row of A and col of B for C[1,2]\n",
    "    ax.annotate(\"\", xy=(4.2, 1.5), xytext=(0, 1.5), arrowprops=dict(arrowstyle=\"->\", color=\"#d32f2f\", lw=2))\n",
    "    ax.text(4.5, 1.5, \"A row i\", color=\"#d32f2f\", va=\"center\")\n",
    "    ax.annotate(\"\", xy=(2.5, 4.2), xytext=(2.5, 0), arrowprops=dict(arrowstyle=\"->\", color=\"#388e3c\", lw=2))\n",
    "    ax.text(2.5, 4.5, \"B col j\", color=\"#388e3c\", ha=\"center\")\n",
    "\n",
    "    ax.set_xlim(-0.5, 5.5)\n",
    "    ax.set_ylim(-0.5, 5)\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "    # Right: tiled access pattern\n",
    "    ax = axes[1]\n",
    "    ax.set_title(\"Tiled MatMul: load a tile once, reuse for multiple outputs\", fontsize=12, fontweight=\"bold\")\n",
    "    tile_colors = [[\"#c5e1a5\", \"#c5e1a5\", \"#b39ddb\", \"#b39ddb\"],\n",
    "                   [\"#c5e1a5\", \"#c5e1a5\", \"#b39ddb\", \"#b39ddb\"],\n",
    "                   [\"#ffcc80\", \"#ffcc80\", \"#ef9a9a\", \"#ef9a9a\"],\n",
    "                   [\"#ffcc80\", \"#ffcc80\", \"#ef9a9a\", \"#ef9a9a\"]]\n",
    "    for r in range(4):\n",
    "        for c in range(4):\n",
    "            rect = patches.Rectangle((c, r), 0.9, 0.9, facecolor=tile_colors[r][c], edgecolor=\"#333\")\n",
    "            ax.add_patch(rect)\n",
    "            ax.text(c + 0.45, r + 0.45, f\"C_{{{r},{c}}}\", ha=\"center\", va=\"center\", fontsize=8)\n",
    "\n",
    "    ax.text(2, 4.5, \"Each 2x2 tile of A is loaded once and reused\", ha=\"center\", fontsize=9)\n",
    "    ax.set_xlim(-0.5, 4.5)\n",
    "    ax.set_ylim(-0.5, 5)\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "draw_tiling_concept()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee2d9851",
   "metadata": {},
   "source": [
    "### 9.5.4 练习：补全一个 Triton Tiled MatMul Kernel\n",
    "\n",
    "下面的代码是一个**骨架**，核心部分（加载 A/B tile、做 dot product）留空了。\n",
    "\n",
    "你需要补全：\n",
    "1. 计算当前 block 负责输出 $C$ 的哪个 tile。\n",
    "2. 沿 $K$ 维循环时，加载 $A$ 和 $B$ 的 tile。\n",
    "3. 用 `tl.dot` 计算 partial sum 并累加。\n",
    "\n",
    "> 提示：\n",
    "> - `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K` 是编译时常量。\n",
    "> - `tl.load(ptrs, mask=..., other=0.0)` 可以处理边界。\n",
    "> - `tl.dot(a, b, acc)` 对 FP16/BF16 支持很好。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d36b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import triton\n",
    "import triton.language as tl\n",
    "\n",
    "\n",
    "@triton.jit\n",
    "def tiled_matmul_kernel(\n",
    "    a_ptr, b_ptr, c_ptr,\n",
    "    M, N, K,\n",
    "    stride_am, stride_ak,\n",
    "    stride_bk, stride_bn,\n",
    "    stride_cm, stride_cn,\n",
    "    BLOCK_SIZE_M: tl.constexpr,\n",
    "    BLOCK_SIZE_N: tl.constexpr,\n",
    "    BLOCK_SIZE_K: tl.constexpr,\n",
    "):\n",
    "    \"\"\"\n",
    "    每个 block 负责计算 C 的一个 (BLOCK_SIZE_M, BLOCK_SIZE_N) tile。\n",
    "    \"\"\"\n",
    "    # TODO 1: 获取当前 block 在 grid 中的 (m, n) 坐标\n",
    "    pid = tl.program_id(0)\n",
    "    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n",
    "    pid_m = pid // num_pid_n  # TODO: 改成正确计算\n",
    "    pid_n = pid % num_pid_n   # TODO: 改成正确计算\n",
    "\n",
    "    # 当前 block 负责的 C tile 的左上角偏移\n",
    "    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n",
    "    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n",
    "    offs_k = tl.arange(0, BLOCK_SIZE_K)\n",
    "\n",
    "    # 计算 A tile 和 B tile 的指针\n",
    "    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)\n",
    "    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)\n",
    "\n",
    "    # 初始化 accumulator\n",
    "    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n",
    "\n",
    "    # 沿 K 维循环\n",
    "    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n",
    "        # TODO 2: 加载 A tile 和 B tile，注意边界 mask\n",
    "        a = tl.load(a_ptrs, mask=offs_m[:, None] < M and offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n",
    "        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K and offs_n[None, :] < N, other=0.0)\n",
    "\n",
    "        # TODO 3: 用 tl.dot 计算 partial sum 并累加到 acc\n",
    "        acc += tl.dot(a, b, acc)\n",
    "\n",
    "        # 指针沿 K 维滑动到下一个 tile\n",
    "        a_ptrs += BLOCK_SIZE_K * stride_ak\n",
    "        b_ptrs += BLOCK_SIZE_K * stride_bk\n",
    "\n",
    "    # 写回 C\n",
    "    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)\n",
    "    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n",
    "    tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)\n",
    "\n",
    "\n",
    "def tiled_matmul(a: torch.Tensor, b: torch.Tensor, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32):\n",
    "    assert a.shape[1] == b.shape[0]\n",
    "    M, K = a.shape\n",
    "    K2, N = b.shape\n",
    "    assert K == K2\n",
    "    c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n",
    "    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)\n",
    "    tiled_matmul_kernel[grid](\n",
    "        a, b, c,\n",
    "        M, N, K,\n",
    "        a.stride(0), a.stride(1),\n",
    "        b.stride(0), b.stride(1),\n",
    "        c.stride(0), c.stride(1),\n",
    "        BLOCK_SIZE_M=BLOCK_SIZE_M,\n",
    "        BLOCK_SIZE_N=BLOCK_SIZE_N,\n",
    "        BLOCK_SIZE_K=BLOCK_SIZE_K,\n",
    "    )\n",
    "    return c\n",
    "\n",
    "\n",
    "# 测试正确性\n",
    "if torch.cuda.is_available():\n",
    "    M, N, K = 256, 256, 256\n",
    "    a = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
    "    b = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
    "    c_ref = a @ b\n",
    "    c_tri = tiled_matmul(a, b, BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32)\n",
    "    print(\"Correctness:\", torch.allclose(c_tri, c_ref, atol=1e-2, rtol=1e-2))\n",
    "    print(\"Max diff:\", (c_tri - c_ref).abs().max().item())\n",
    "else:\n",
    "    print(\"CUDA not available.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "506db676",
   "metadata": {},
   "source": [
    "### 9.5.5 练习 1：改变 BLOCK_SIZE，观察性能变化\n",
    "\n",
    "不同的 BLOCK_SIZE 会影响：\n",
    "- shared memory 使用量\n",
    "- occupancy\n",
    "- L2 cache 命中率\n",
    "- 边界检查 overhead\n",
    "\n",
    "尝试几组参数，记录运行时间。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d68dd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    M, N, K = 1024, 1024, 1024\n",
    "    a = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
    "    b = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
    "    c_ref = a @ b\n",
    "\n",
    "    configs = [\n",
    "        (16, 16, 16),\n",
    "        (32, 32, 32),\n",
    "        (64, 64, 64),\n",
    "        (64, 64, 32),\n",
    "        (128, 128, 32),\n",
    "    ]\n",
    "\n",
    "    results = []\n",
    "    for BM, BN, BK in configs:\n",
    "        try:\n",
    "            torch.cuda.empty_cache()\n",
    "            c_tri = tiled_matmul(a, b, BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK)\n",
    "            correct = torch.allclose(c_tri, c_ref, atol=1e-2, rtol=1e-2)\n",
    "            t = benchmark(f\"tiled_matmul({BM}x{BN}x{BK})\", lambda: tiled_matmul(a, b, BLOCK_SIZE_M=BM, BLOCK_SIZE_N=BN, BLOCK_SIZE_K=BK), num_warmups=2, num_trials=5)\n",
    "            results.append((BM, BN, BK, t, correct))\n",
    "            print(f\"{BM}x{BN}x{BK}: {t:.3f} ms, correct={correct}\")\n",
    "        except Exception as e:\n",
    "            print(f\"{BM}x{BN}x{BK} failed: {e}\")\n",
    "\n",
    "    # PyTorch baseline\n",
    "    t_ref = benchmark(\"pytorch_matmul\", lambda: a @ b, num_warmups=2, num_trials=5)\n",
    "    print(f\"\\nPyTorch: {t_ref:.3f} ms\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89ad35f5",
   "metadata": {},
   "source": [
    "### 9.5.6 练习 2：分析 Shared Memory 用量与 Occupancy\n",
    "\n",
    "一个 block 的 shared memory 用量约为：\n",
    "\n",
    "```\n",
    "shmem_per_block = BLOCK_SIZE_M * BLOCK_SIZE_K * sizeof(dtype) + BLOCK_SIZE_K * BLOCK_SIZE_N * sizeof(dtype)\n",
    "```\n",
    "\n",
    "对于 FP16（2 bytes）：\n",
    "- 32x32x32: 32*32*2 + 32*32*2 = 4 KB\n",
    "- 64x64x64: 64*64*2 + 64*64*2 = 16 KB\n",
    "- 128x128x32: 128*32*2 + 32*128*2 = 16 KB\n",
    "\n",
    "RTX 5060 每个 SM 的 shared memory 大约 100 KB。据此估算：\n",
    "- 每个 SM 最多能同时驻留多少个 64x64x64 的 block？\n",
    "- 如果 shared memory 变成瓶颈，occupancy 会受到什么影响？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db4718ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shmem_for_matmul(BM, BN, BK, dtype_bytes=2):\n",
    "    \"\"\"估算一个 block 的 shared memory 用量（近似）。\"\"\"\n",
    "    a_tile = BM * BK * dtype_bytes\n",
    "    b_tile = BK * BN * dtype_bytes\n",
    "    return a_tile + b_tile\n",
    "\n",
    "\n",
    "print(\"Shared memory usage per block:\")\n",
    "for BM, BN, BK in [(32, 32, 32), (64, 64, 64), (128, 128, 32), (128, 128, 64)]:\n",
    "    shmem = shmem_for_matmul(BM, BN, BK)\n",
    "    print(f\"  {BM}x{BN}x{BK}: {shmem / 1024:.1f} KB\")\n",
    "\n",
    "# 假设 SM 有 100 KB shared memory\n",
    "sm_smem_kb = 100\n",
    "print(f\"\\nWith {sm_smem_kb} KB shared memory per SM:\")\n",
    "for BM, BN, BK in [(64, 64, 64), (128, 128, 32), (128, 128, 64)]:\n",
    "    shmem_kb = shmem_for_matmul(BM, BN, BK) / 1024\n",
    "    max_blocks_by_smem = sm_smem_kb // shmem_kb\n",
    "    print(f\"  {BM}x{BN}x{BK}: max {max_blocks_by_smem} blocks per SM (by shared memory)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "301bbbc1",
   "metadata": {},
   "source": [
    "### 9.5.7 练习 3：对比 Tiled vs PyTorch MatMul\n",
    "\n",
    "上面的 tiled kernel 只是一个教学实现，没有做很多高级优化（如：\n",
    "- 向量加载/存储\n",
    "- 多级 tiling（register-level + shared-level）\n",
    "- L2 cache 友好的 pid 分组\n",
    "- warp-level 优化\n",
    "\n",
    "所以通常比 PyTorch/CUTLASS 慢。这很正常，重点是理解 tiling 的思想。\n",
    "\n",
    "思考题：\n",
    "1. 为什么增大 BLOCK_SIZE 不总是加速？\n",
    "2. 为什么 128x128x32 可能比 64x64x64 快或慢？\n",
    "3. 如果要继续优化这个 kernel，你会先改哪里？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ae1833b",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    sizes = [512, 1024, 2048]\n",
    "    for size in sizes:\n",
    "        a = torch.randn(size, size, device=\"cuda\", dtype=torch.float16)\n",
    "        b = torch.randn(size, size, device=\"cuda\", dtype=torch.float16)\n",
    "        t_ref = benchmark(f\"pytorch_matmul({size})\", lambda: a @ b, num_warmups=2, num_trials=5)\n",
    "        t_tri = benchmark(f\"tiled_matmul({size}, 64x64x32)\",\n",
    "                          lambda: tiled_matmul(a, b, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32),\n",
    "                          num_warmups=2, num_trials=5)\n",
    "        print(f\"size={size}: PyTorch={t_ref:.2f} ms, Tiled={t_tri:.2f} ms, ratio={t_tri/t_ref:.2f}x\")\n",
    "else:\n",
    "    print(\"CUDA not available.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55b2ca9f",
   "metadata": {},
   "source": [
    "### 9.5.8 小结\n",
    "\n",
    "分块（tiling）是 GPU kernel 优化的核心技巧：\n",
    "\n",
    "1. **问题**：naive matmul 重复读取 HBM，memory-bound。\n",
    "2. **思想**：把数据切成小块加载到 shared memory，复用多次。\n",
    "3. **关键参数**：BLOCK_SIZE_M/N/K 影响 shared memory、occupancy、边界 overhead。\n",
    "4. **现代框架**：PyTorch/CUTLASS 的 matmul 本质上就是极其优化的 tiled kernel，通常还有多级 tiling 和 Tensor Core。\n",
    "\n",
    "下一步：可以尝试把上面的 tiled matmul 改成用 Tensor Core（FP16 下 `tl.dot` 已经会尝试用 Tensor Core），或者加入 L2 cache 友好的 pid 分组。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47fa4959",
   "metadata": {},
   "source": [
    "\n",
    "## Part 10：Nsight Systems / Nsight Compute 实操\n",
    "\n",
    "### 10.1 Nsight Systems (`nsys`)\n",
    "\n",
    "`nsys profile` 记录整个应用的 CPU + GPU timeline，适合看：\n",
    "\n",
    "- 哪些 kernel 占了最多时间。\n",
    "- CPU 和 GPU 之间有没有空闲间隙（launch latency）。\n",
    "- CUDA API 调用顺序。\n",
    "\n",
    "基本命令：\n",
    "\n",
    "```bash\n",
    "nsys profile --stats=true -o output_name python your_script.py\n",
    "```\n",
    "\n",
    "### 10.2 Nsight Compute (`ncu`)\n",
    "\n",
    "`ncu` 分析单个 kernel 的微观指标：\n",
    "\n",
    "- `sm__warps_active.avg.pct_of_peak_sustained_elapsed`：SM 利用率 / occupancy。\n",
    "- `dram__bytes.sum.per_second`：显存带宽利用率。\n",
    "- `smsp__sass_thread_inst_executed_op_fadd_pred_on.sum`：实际执行的 FLOPs。\n",
    "\n",
    "### 10.3 实践：分析 GeLU\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac69325f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 生成一个可独立运行的 profile 脚本\n",
    "profile_script = r'''\n",
    "import sys\n",
    "import os\n",
    "sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))\n",
    "\n",
    "import torch\n",
    "\n",
    "def manual_gelu(x):\n",
    "    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))\n",
    "\n",
    "def main():\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    x = torch.randn(4096 * 4096, device=device)\n",
    "    for _ in range(3):\n",
    "        y = manual_gelu(x)\n",
    "    if device == \"cuda\":\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n",
    "'''\n",
    "os.makedirs(\"scripts\", exist_ok=True)\n",
    "with open(\"scripts/profile_gelu.py\", \"w\") as f:\n",
    "    f.write(profile_script)\n",
    "print(\"scripts/profile_gelu.py written.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef301c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在 Notebook 中运行 nsys profile（注意：会花一些时间）\n",
    "import subprocess\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    cmd = [\n",
    "        \"nsys\", \"profile\", \"--stats=true\", \"--no-wait\",\n",
    "        \"-o\", \"var/nsys_gelu\",\n",
    "        \".venv/bin/python\", \"scripts/profile_gelu.py\",\n",
    "    ]\n",
    "    print(\"Running:\", \" \".join(cmd))\n",
    "    result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)\n",
    "    print(result.stdout[-3000:])\n",
    "    print(result.stderr[-1000:])\n",
    "else:\n",
    "    print(\"CUDA not available, skip nsys profile.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2af1be72",
   "metadata": {},
   "source": [
    "\n",
    "### 10.4 如何读 nsys 报告\n",
    "\n",
    "nsys 输出通常包含：\n",
    "\n",
    "- **CUDA API Statistics**：每个 CUDA API 调用的次数和总时间。\n",
    "- **CUDA Kernel Statistics**：每个 kernel 的执行次数、平均/最大/总时间。\n",
    "- **Memory Operation Statistics**：HtoD/DtoH 传输次数。\n",
    "\n",
    "重点关注：\n",
    "1. **Time(%) 最高的 kernel**：这是你的瓶颈。\n",
    "2. **CPU 是否在等 GPU**：看 timeline 里有没有大片空白。\n",
    "3. **kernel launch overhead**：小 kernel 太多会导致 launch 开销占比高。\n",
    "\n",
    "生成的 `.nsys-rep` 可以用 Nsight Systems GUI 打开：\n",
    "\n",
    "```bash\n",
    "nsys-ui var/nsys_gelu.nsys-rep\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f69ff7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 用 ncu 分析单个 kernel（需要知道 kernel 名称，可用 nsys 或 torch.profiler 获取）\n",
    "if torch.cuda.is_available():\n",
    "    cmd = [\n",
    "        \"ncu\",\n",
    "        \"--kernel-name\", \"regex:gelsu\",  # GeLU backward 名称示例\n",
    "        \"--metrics\", \"sm__warps_active.avg.pct_of_peak_sustained_elapsed,dram__bytes.sum.per_second\",\n",
    "        \".venv/bin/python\", \"scripts/profile_gelu.py\",\n",
    "    ]\n",
    "    print(\"Example ncu command:\")\n",
    "    print(\" \".join(cmd))\n",
    "    print(\"\\\\nNote: run this in terminal if you want actual metrics.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d98af89f",
   "metadata": {},
   "source": [
    "\n",
    "### 10.5 用 nvtx 标注代码区间\n",
    "\n",
    "`nvtx` 可以在代码里打标记，nsys timeline 上就会显示这些区间，方便定位。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60374521",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import nvtx\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    with nvtx.annotate(\"gelu_loop\", color=\"green\"):\n",
    "        for _ in range(5):\n",
    "            _ = manual_gelu(x)\n",
    "    torch.cuda.synchronize()\n",
    "    print(\"nvtx annotated region executed.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d738639",
   "metadata": {},
   "source": [
    "\n",
    "## Part 11：总结与练习\n",
    "\n",
    "### 11.1 核心原则\n",
    "\n",
    "1. **Measure first**：先 benchmark / profile，再优化。\n",
    "2. **Minimize reads/writes**：kernel fusion 和 tiling 是核心手段。\n",
    "3. **Know your bottleneck**：compute-bound 提 FLOPs，memory-bound 提访存效率。\n",
    "4. **Use the right tool**：\n",
    "   - 日常开发：`torch.compile` + PyTorch fused ops。\n",
    "   - 自定义融合：Triton。\n",
    "   - 极致优化或特殊硬件：CUDA。\n",
    "   - 性能分析：`nsys` + `ncu` + `torch.profiler`。\n",
    "\n",
    "### 11.2 自测练习\n",
    "\n",
    "#### 练习 1：手写 ReLU 的 CUDA 和 Triton kernel\n",
    "\n",
    "要求：\n",
    "- CUDA：`relu.cu`，输入输出都是 float32。\n",
    "- Triton：`triton_relu_kernel`，支持任意长度向量。\n",
    "- 与 `torch.relu` 对比正确性和速度。\n",
    "\n",
    "#### 练习 2：手写 row-wise sum 的 Triton kernel\n",
    "\n",
    "输入 shape `(M, N)`，输出 shape `(M,)`，对每行求和。\n",
    "\n",
    "提示：和 softmax 类似，但不需要 max/exp。\n",
    "\n",
    "#### 练习 3：用 nsys 分析训练循环\n",
    "\n",
    "对下面这个训练循环跑 `nsys profile`：\n",
    "\n",
    "```python\n",
    "model = MLP(512, 8).cuda()\n",
    "opt = torch.optim.AdamW(model.parameters(), lr=1e-3)\n",
    "x = torch.randn(128, 512, device=\"cuda\")\n",
    "for _ in range(20):\n",
    "    opt.zero_grad()\n",
    "    loss = model(x).mean()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "```\n",
    "\n",
    "回答：\n",
    "- 时间占比最高的前 3 个 kernel 是什么？\n",
    "- 有没有 kernel launch gap？\n",
    "- 优化器 step 是否占大量时间？\n",
    "\n",
    "### 11.3 进阶阅读\n",
    "\n",
    "- [Horace He's GPU performance intro](https://horace.io/brrr_intro.html)\n",
    "- [Triton fused softmax tutorial](https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html)\n",
    "- [Triton matmul tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html)\n",
    "- [CUDA MODE YouTube playlist](https://www.youtube.com/@CUDAMODE)\n",
    "- [GPU Puzzles](https://github.com/srush/gpu-puzzles)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
