{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "07da4766",
   "metadata": {},
   "source": [
    "\n",
    "# Roofline Model and Dtype: Why Isn't FP16 Necessarily More Compute-Bound?\n",
    "\n",
    "## The Core Question\n",
    "\n",
    "The traditional roofline model uses\n",
    "\n",
    "$$\n",
    "AI_{std} = \\frac{\\text{abstract FLOPs}}{\\text{HBM bytes}}\n",
    "$$\n",
    "\n",
    "and concludes that going from FP32 to FP16 halves the bytes while leaving the FLOPs unchanged, so $AI_{std}$ doubles and the workload becomes more compute-bound.\n",
    "\n",
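    "**But this conclusion can be wrong**, because:\n",
    "\n",
    "1. **A FLOP is not a uniform unit of circuit work**: an FP16 FMA needs far fewer transistor switches than an FP32 FMA, uses much less energy, and has lower latency.\n",
    "2. **Peak throughput can grow much faster than memory bandwidth**: on Ampere/Hopper, the FP16 Tensor Core peak can be 8x to 16x the FP32 peak. The memory bandwidth is the same for every dtype, so the ridge point $P_{peak} / B_{mem}$ grows by the same 8x to 16x.\n",
    "3. So even though $AI_{std}$ doubles, the workload may end up *farther* from that dtype's ridge point, i.e. deeper in the memory-bound region.\n",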
    "\n",
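    "Here is point 3 as a back-of-the-envelope calculation, written as a minimal sketch. Since $\\rho = AI_{std} / (P_{peak} / B_{mem})$, the ratio $\\rho_{FP16} / \\rho_{FP32}$ is (AI ratio × bandwidth ratio) / (peak ratio). The peak ratios below are illustrative assumptions, not measurements. `rho_ratio` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def rho_ratio(ai_ratio: float, peak_ratio: float, bw_ratio: float = 1.0) -> float:\n",
    "    # rho = AI / (P_peak / B_mem), so rho_new / rho_old = ai_ratio * bw_ratio / peak_ratio\n",
    "    return ai_ratio * bw_ratio / peak_ratio\n",
    "\n",
    "\n",
    "# Hypothetical FP32 -> FP16 switch on the same GPU: bytes halve (AI x2), bandwidth unchanged.\n",
    "for peak_ratio in (2.0, 8.0, 16.0):\n",
    "    print(f\"peak x{peak_ratio:g}: rho_FP16 / rho_FP32 = {rho_ratio(2.0, peak_ratio):.3f}\")\n",
    "```\n",
    "\n",
    "This prints 1.000, 0.250 and 0.125. FP16 only moves toward compute-bound if its peak grows by less than 2x.\n",
    "\n",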
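    "## Goals of This Experiment\n",
    "\n",
    "We measure three quantities side by side:\n",
    "\n",
    "| Quantity | Definition | Meaning |\n",
    "|---|---|---|\n",
    "| $AI_{std}$ | FLOPs / HBM bytes | Traditional roofline intensity |\n",
    "| $I_{bit}$ | circuit-weighted FLOPs / HBM bits | Intensity normalized by bit width and circuit cost |\n",
    "| $\\rho_{roofline}$ | $AI_{std} / (P_{peak,dtype} / B_{mem})$ | Distance of the workload from that dtype's ridge point |\n",
    "\n",
    "The smaller $\\rho$, the closer the workload is to memory-bound; the larger $\\rho$, the closer it is to compute-bound. $\\rho = 1$ is exactly the ridge point.\n",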
    "\n",
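    "The ridge point has to be computed in consistent units. Peaks are usually quoted in TFLOP/s and bandwidth in GB/s, and dividing FLOP/s by byte/s gives FLOP/byte. Below is a minimal sketch. `ridge_point` and `rho` are hypothetical helper names, and the 20 TFLOP/s and 400 GB/s inputs are placeholders, not RTX 5060 specifications.\n",
    "\n",
    "```python\n",
    "def ridge_point(peak_tflops: float, bw_gbps: float) -> float:\n",
    "    # Arithmetic intensity (FLOP/byte) at which the compute roof meets the memory roof.\n",
    "    return (peak_tflops * 1e12) / (bw_gbps * 1e9)\n",
    "\n",
    "\n",
    "def rho(ai_std: float, peak_tflops: float, bw_gbps: float) -> float:\n",
    "    # rho < 1: left of the ridge (memory-bound); rho > 1: right of it (compute-bound).\n",
    "    return ai_std / ridge_point(peak_tflops, bw_gbps)\n",
    "\n",
    "\n",
    "print(ridge_point(20.0, 400.0))  # 50.0 FLOP/byte\n",
    "print(rho(25.0, 20.0, 400.0))    # 0.5 -> memory-bound\n",
    "```\n",
    "\n",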
    "## Setup\n",
    "\n",
    "- GPU: RTX 5060 (Blackwell, sm_120)\n",
    "- Workloads: GEMM, MLP forward/backward\n",
    "- Dtypes: FP32, TF32, FP16, BF16\n",
    "- Measurement tools: `torch.profiler` plus `nvtx` annotations, and `ncu` where more precise memory metrics are needed\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d6d99c0",
   "metadata": {},
   "source": [
    "\n",
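    "## Part 1: Environment Setup and Helper Functions\n",
    "\n",
    "We first define the timing helper and a few small utilities. The FLOP estimators and the peak-throughput measurement follow in Parts 2 and 3.\n"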
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b9f3e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import time\n",
    "import warnings\n",
    "from collections import defaultdict\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Device: {device}\")\n",
    "if device == \"cuda\":\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    print(f\"CUDA: {torch.version.cuda}\")\n",
    "    print(f\"PyTorch: {torch.__version__}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f142714",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def benchmark(fn, num_warmups=3, num_trials=5):\n",
    "    # Warm up, then time each trial with a monotonic host clock.\n",
    "    # torch.cuda.synchronize() ensures the queued GPU work is included in each measurement.\n",
    "    for _ in range(num_warmups):\n",
    "        fn()\n",
    "    if device == \"cuda\":\n",
    "        torch.cuda.synchronize()\n",
    "    times = []\n",
    "    for _ in range(num_trials):\n",
    "        start = time.perf_counter()\n",
    "        fn()\n",
    "        if device == \"cuda\":\n",
    "            torch.cuda.synchronize()\n",
    "        times.append((time.perf_counter() - start) * 1000)  # ms\n",
    "    return sum(times) / len(times), times\n",
    "\n",
    "\n",
    "def reset_cuda():\n",
    "    if device == \"cuda\":\n",
    "        torch.cuda.empty_cache()\n",
    "        torch.cuda.reset_peak_memory_stats()\n",
    "\n",
    "\n",
    "def bytes_to_str(b):\n",
    "    for unit in [\"B\", \"KB\", \"MB\", \"GB\"]:\n",
    "        if b < 1024:\n",
    "            return f\"{b:.2f} {unit}\"\n",
    "        b /= 1024\n",
    "    return f\"{b:.2f} TB\"\n"
   ]
  },
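  {
   "cell_type": "markdown",
   "id": "3f6a2c91",
   "metadata": {},
   "source": [
    "The `benchmark` helper above uses the host clock plus `torch.cuda.synchronize()`, so each trial also includes launch and synchronization overhead. One alternative is to time back-to-back trials with CUDA events. This sketch assumes a recent PyTorch: `torch.cuda.Event(enable_timing=True)` records the events, and `elapsed_time` returns milliseconds. Without a GPU it falls back to `time.perf_counter`. `benchmark_events` is a hypothetical name, not part of this notebook.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def benchmark_events(fn, num_warmups: int = 3, num_trials: int = 5) -> float:\n",
    "    # Mean time per call in ms, measured on the GPU timeline when CUDA is available.\n",
    "    for _ in range(num_warmups):\n",
    "        fn()\n",
    "    if not torch.cuda.is_available():\n",
    "        start = time.perf_counter()\n",
    "        for _ in range(num_trials):\n",
    "            fn()\n",
    "        return (time.perf_counter() - start) * 1000 / num_trials\n",
    "\n",
    "    torch.cuda.synchronize()\n",
    "    start = torch.cuda.Event(enable_timing=True)\n",
    "    end = torch.cuda.Event(enable_timing=True)\n",
    "    start.record()\n",
    "    for _ in range(num_trials):\n",
    "        fn()\n",
    "    end.record()\n",
    "    end.synchronize()  # block until the GPU has reached the end event\n",
    "    return start.elapsed_time(end) / num_trials\n",
    "\n",
    "\n",
    "a = torch.randn(256, 256, device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"{benchmark_events(lambda: a @ a):.3f} ms per call\")\n",
    "```\n",
    "\n",
    "Recording one event pair around all trials amortizes launch overhead, which suits throughput measurements. To see run-to-run variance you would need one event pair per trial.\n"
   ]
  },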
  {
   "cell_type": "markdown",
   "id": "d73b4647",
   "metadata": {},
   "source": [
    "\n",
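    "## Part 2: Measuring Peak Compute Throughput per Dtype\n",
    "\n",
    "We use a large GEMM to approximate the peak throughput achievable with each dtype.\n",
    "\n",
    "Notes:\n",
    "- FP32: plain `torch.matmul` with TF32 disabled\n",
    "- TF32: `torch.matmul` on FP32 tensors with `torch.backends.cuda.matmul.allow_tf32 = True` (`torch.backends.cudnn.allow_tf32` only affects cuDNN ops such as convolutions)\n",
    "- FP16/BF16: `torch.matmul` on half-precision tensors\n",
    "- Large matrices help push the GPU into the compute-bound regime\n"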
   ]
  },
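  {
   "cell_type": "markdown",
   "id": "8e0b4d57",
   "metadata": {},
   "source": [
    "FP32 and TF32 use the same torch dtype, so only global state tells them apart: `torch.backends.cuda.matmul.allow_tf32` controls cuBLAS matmuls, and `torch.backends.cudnn.allow_tf32` controls cuDNN ops such as convolutions. Because these flags are global, a setting can leak from one measurement into the next. Below is a minimal sketch of a context manager that sets both flags and restores the previous values on exit. `tf32_mode` is a hypothetical helper, not a PyTorch API. `torch.set_float32_matmul_precision` is another documented way to control the matmul side.\n",
    "\n",
    "```python\n",
    "from contextlib import contextmanager\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def tf32_mode(enabled: bool):\n",
    "    # Remember the current global flags so they can be restored on exit.\n",
    "    prev_matmul = torch.backends.cuda.matmul.allow_tf32\n",
    "    prev_cudnn = torch.backends.cudnn.allow_tf32\n",
    "    torch.backends.cuda.matmul.allow_tf32 = enabled  # cuBLAS matmuls\n",
    "    torch.backends.cudnn.allow_tf32 = enabled        # cuDNN ops (e.g. convolutions)\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        torch.backends.cuda.matmul.allow_tf32 = prev_matmul\n",
    "        torch.backends.cudnn.allow_tf32 = prev_cudnn\n",
    "\n",
    "\n",
    "with tf32_mode(True):\n",
    "    print(\"inside:\", torch.backends.cuda.matmul.allow_tf32)  # True\n",
    "print(\"restored:\", torch.backends.cuda.matmul.allow_tf32)\n",
    "```\n"
   ]
  },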
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a52f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def measure_peak_matmul(dtype: torch.dtype, m: int = 8192, k: int = 8192, n: int = 8192) -> Tuple[float, float]:\n",
    "    # Achievable matmul throughput for one dtype. The caller sets the TF32 flags:\n",
    "    # FP32 and TF32 share torch.float32, so they cannot be told apart from `dtype` alone.\n",
    "    reset_cuda()\n",
    "\n",
    "    a = torch.randn(m, k, device=device, dtype=dtype)\n",
    "    b = torch.randn(k, n, device=device, dtype=dtype)\n",
    "\n",
    "    def fn():\n",
    "        return torch.matmul(a, b)\n",
    "\n",
    "    elapsed_ms, _ = benchmark(fn, num_warmups=5, num_trials=10)\n",
    "\n",
    "    # FLOPs = 2 * M * K * N (one multiply and one add per term of each dot product)\n",
    "    flops = 2 * m * k * n\n",
    "    tflops = flops / (elapsed_ms / 1000) / 1e12\n",
    "    return tflops, elapsed_ms\n",
    "\n",
    "\n",
    "# torch dtypes cannot distinguish FP32 from TF32, so each config also carries\n",
    "# the TF32 flag value to use: (name, torch dtype, allow_tf32)\n",
    "DTYPE_CONFIGS = [\n",
    "    (\"FP32\", torch.float32, False),\n",
    "    (\"TF32\", torch.float32, True),\n",
    "    (\"FP16\", torch.float16, False),\n",
    "    (\"BF16\", torch.bfloat16, False),\n",
    "]\n",
    "\n",
    "\n",
    "def measure_peak_all():\n",
    "    results = {}\n",
    "    for name, dtype, allow_tf32 in DTYPE_CONFIGS:\n",
    "        # Set the TF32 flags here; measure_peak_matmul leaves them untouched.\n",
    "        torch.backends.cuda.matmul.allow_tf32 = allow_tf32\n",
    "        torch.backends.cudnn.allow_tf32 = allow_tf32\n",
    "\n",
    "        try:\n",
    "            tflops, elapsed = measure_peak_matmul(dtype)\n",
    "            results[name] = {\"tflops\": tflops, \"elapsed_ms\": elapsed}\n",
    "            print(f\"{name}: {tflops:.2f} TFLOP/s ({elapsed:.2f} ms)\")\n",
    "        except Exception as e:\n",
    "            print(f\"{name} failed: {e}\")\n",
    "            results[name] = {\"tflops\": 0, \"elapsed_ms\": 0}\n",
    "    return results\n",
    "\n",
    "\n",
    "peak_results = measure_peak_all()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b20ecfd",
   "metadata": {},
   "source": [
    "\n",
    "## Part 3: Defining the Workloads\n",
    "\n",
    "We measure two typical workloads:\n",
    "\n",
    "1. **GEMM**: $C = A \\times B$ with shapes $(M, K) \\times (K, N)$\n",
    "2. **MLP**: a stack of Linear + GeLU layers, forward + backward\n",
    "\n",
    "For each workload and each dtype, we measure:\n",
    "- elapsed time\n",
    "- HBM bytes (estimated from PyTorch's CUDA allocator statistics)\n",
    "- FLOPs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ef1901",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def gemm_workload(dtype: torch.dtype, m: int = 2048, k: int = 2048, n: int = 2048):\n",
    "    a = torch.randn(m, k, device=device, dtype=dtype)\n",
    "    b = torch.randn(k, n, device=device, dtype=dtype)\n",
    "\n",
    "    def fn():\n",
    "        c = torch.matmul(a, b)\n",
    "        return c\n",
    "\n",
    "    flops = 2 * m * k * n\n",
    "    return fn, flops\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, dim: int, num_layers: int):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "            x = torch.nn.functional.gelu(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def mlp_workload(dtype: torch.dtype, dim: int = 1024, num_layers: int = 4, batch_size: int = 256):\n",
    "    model = MLP(dim, num_layers).to(device).to(dtype)\n",
    "    x = torch.randn(batch_size, dim, device=device, dtype=dtype, requires_grad=True)\n",
    "\n",
    "    def fn():\n",
    "        y = model(x).mean()\n",
    "        y.backward()\n",
    "        return y\n",
    "\n",
    "    # FLOP estimate per layer: Linear forward 2*batch*dim^2, backward ~4*batch*dim^2 (input and weight grads), GeLU ~8*batch*dim\n",
    "    per_layer_flops = 2 * batch_size * dim * dim  # forward\n",
    "    per_layer_flops += 4 * batch_size * dim * dim  # backward\n",
    "    per_layer_flops += 8 * batch_size * dim        # GeLU approx\n",
    "    flops = per_layer_flops * num_layers\n",
    "    return fn, flops\n"
   ]
  },
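  {
   "cell_type": "markdown",
   "id": "9d4b7e12",
   "metadata": {},
   "source": [
    "The FLOP model in `mlp_workload` can be checked by hand. This minimal sketch repeats the same per-layer accounting at the shape used in Part 5:\n",
    "\n",
    "- 2·batch·dim² for the forward pass\n",
    "- about 4·batch·dim² for the backward pass (input and weight gradients)\n",
    "- a rough 8·batch·dim for GeLU\n",
    "\n",
    "As in the original estimate, biases and the GeLU backward are ignored. `mlp_flops` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def mlp_flops(dim: int, num_layers: int, batch_size: int) -> int:\n",
    "    fwd = 2 * batch_size * dim * dim  # Y = X W^T\n",
    "    bwd = 4 * batch_size * dim * dim  # dX = dY W and dW = dY^T X\n",
    "    gelu = 8 * batch_size * dim       # rough per-element cost of GeLU\n",
    "    return num_layers * (fwd + bwd + gelu)\n",
    "\n",
    "\n",
    "total = mlp_flops(dim=2048, num_layers=4, batch_size=512)\n",
    "print(f\"{total / 1e9:.2f} GFLOP per step\")  # 51.57 GFLOP per step\n",
    "```\n",
    "\n",
    "Here the GeLU term contributes well under 0.1% of the total, so the GEMMs dominate the estimate.\n"
   ]
  },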
  {
   "cell_type": "markdown",
   "id": "c485426b",
   "metadata": {},
   "source": [
    "\n",
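    "## Part 4: Profiling Workloads and Estimating HBM Bytes\n",
    "\n",
    "`torch.profiler` can report CUDA memory usage, and we trace each run with `record_shapes=True` and `profile_memory=True`. The byte count used later comes from the CUDA caching allocator's peak statistics (`torch.cuda.memory_stats()`).\n",
    "\n",
    "Note: this is an allocator-level statistic. It measures a memory footprint, not actual HBM traffic; for example, a buffer read several times is counted only once. More precise numbers require `ncu`, but this is enough for trend analysis.\n"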
   ]
  },
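  {
   "cell_type": "markdown",
   "id": "2c8e5a47",
   "metadata": {},
   "source": [
    "Allocator statistics measure footprint rather than traffic, so a useful cross-check for GEMM is the compulsory-traffic model. If A and B are each read from DRAM once and C is written once, the minimum traffic is $(MK + KN + MN) \\cdot s$ bytes, where $s$ is the element size. The resulting $AI_{std}$ is therefore an upper bound. This is a standard simplification, not something `torch.profiler` reports. `gemm_min_bytes` and `gemm_ai_upper_bound` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "BYTES_PER_ELEM = {\"FP32\": 4, \"TF32\": 4, \"FP16\": 2, \"BF16\": 2}  # TF32 keeps FP32 storage\n",
    "\n",
    "\n",
    "def gemm_min_bytes(m: int, k: int, n: int, elem_bytes: int) -> int:\n",
    "    # Read A (m x k) and B (k x n) once, write C (m x n) once.\n",
    "    return (m * k + k * n + m * n) * elem_bytes\n",
    "\n",
    "\n",
    "def gemm_ai_upper_bound(m: int, k: int, n: int, elem_bytes: int) -> float:\n",
    "    # 2*m*k*n FLOPs over the compulsory bytes; real kernels move at least this much.\n",
    "    return 2 * m * k * n / gemm_min_bytes(m, k, n, elem_bytes)\n",
    "\n",
    "\n",
    "for name, s in BYTES_PER_ELEM.items():\n",
    "    print(f\"{name}: {gemm_ai_upper_bound(4096, 4096, 4096, s):.1f} FLOP/byte\")\n",
    "```\n",
    "\n",
    "For square matrices this reduces to $2N / (3s)$. At $N = 4096$ that is about 683 FLOP/byte in FP32/TF32 and about 1365 in FP16/BF16. Halving $s$ therefore gives exactly the 2x increase in $AI_{std}$ described in the introduction.\n"
   ]
  },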
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68931c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from torch.profiler import ProfilerActivity, profile\n",
    "\n",
    "\n",
    "def profile_workload(name: str, fn, dtype_name: str):\n",
    "    # Profile one workload; return wall-clock time, summed CUDA time and an HBM-bytes estimate.\n",
    "    # FP32 and TF32 share torch.float32, so the TF32 flags select between them.\n",
    "    if dtype_name == \"FP32\":\n",
    "        torch.backends.cuda.matmul.allow_tf32 = False\n",
    "        torch.backends.cudnn.allow_tf32 = False\n",
    "    elif dtype_name == \"TF32\":\n",
    "        torch.backends.cuda.matmul.allow_tf32 = True\n",
    "        torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    reset_cuda()\n",
    "\n",
    "    # warmup\n",
    "    for _ in range(3):\n",
    "        fn()\n",
    "    if device == \"cuda\":\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "    with profile(\n",
    "        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n",
    "        profile_memory=True,\n",
    "        record_shapes=True,\n",
    "        with_stack=False,\n",
    "    ) as prof:\n",
    "        fn()\n",
    "        if device == \"cuda\":\n",
    "            torch.cuda.synchronize()\n",
    "\n",
    "    # Total CUDA time: sum *self* device time, because device_time_total of an op\n",
    "    # already includes the kernels it launched, which would then be counted more than once.\n",
    "    events = prof.key_averages()\n",
    "    cuda_time_us = sum(e.self_device_time_total for e in events)\n",
    "\n",
    "    # HBM-bytes proxy: peak bytes held by the caching allocator since reset_cuda().\n",
    "    # active_bytes and allocated_bytes describe almost the same memory, so adding them double-counts.\n",
    "    mem_stats = torch.cuda.memory_stats() if device == \"cuda\" else {}\n",
    "    hbm_bytes = mem_stats.get(\"allocated_bytes.all.peak\", 0)\n",
    "\n",
    "    # Wall-clock timing, inside an NVTX range so the run is easy to find in Nsight Systems.\n",
    "    if device == \"cuda\":\n",
    "        with torch.cuda.nvtx.range(f\"{name}/{dtype_name}\"):\n",
    "            elapsed_ms, _ = benchmark(fn, num_warmups=2, num_trials=5)\n",
    "    else:\n",
    "        elapsed_ms, _ = benchmark(fn, num_warmups=2, num_trials=5)\n",
    "\n",
    "    return {\n",
    "        \"elapsed_ms\": elapsed_ms,\n",
    "        \"cuda_time_us\": cuda_time_us,\n",
    "        \"hbm_bytes\": hbm_bytes,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8867bde",
   "metadata": {},
   "source": [
    "\n",
    "## Part 5: Running the Experiment\n",
    "\n",
    "For both GEMM and MLP, under each of FP32/TF32/FP16/BF16, we measure:\n",
    "\n",
    "- $T$ (elapsed time)\n",
    "- $F$ (FLOPs)\n",
    "- $B$ (HBM bytes, approximate)\n",
    "- $P_{peak}$ (the measured peak TFLOP/s of that dtype, from Part 2)\n",
    "\n",
    "We then compute:\n",
    "\n",
    "$$\n",
    "AI_{std} = \\frac{F}{B}\n",
    "$$\n",
    "\n",
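    "$$\n",
    "I_{bit} = \\frac{F \\times w_{dtype}}{N_{elem} \\times b_{dtype}} = \\frac{F \\times w_{dtype}}{8B}\n",
    "$$\n",
    "\n",
    "Here $w_{dtype}$ is an empirical circuit-level weight per FLOP (FP32 = 1.0, TF32 ≈ 0.55, FP16 ≈ 0.22, BF16 ≈ 0.20, matching `DTYPE_INFO` below), and $b_{dtype}$ is the number of bits per element. $N_{elem} = B / (b_{dtype}/8)$ is the number of elements moved, so the denominator is the HBM traffic in bits.\n",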
    "\n",
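    "$$\n",
    "\\rho_{roofline} = \\frac{AI_{std}}{P_{peak} / B_{mem}}\n",
    "$$\n",
    "\n",
    "$B_{mem}$ is the measured or nominal DRAM bandwidth. \"HBM\" is used loosely here: the RTX 5060 uses GDDR7, not HBM. The code uses a deliberately low estimate of 128 GB/s, well below the nominal bandwidth of a GDDR7 card, so the absolute $\\rho$ values are only indicative. Ratios of $\\rho$ between dtypes do not depend on $B_{mem}$ at all, because every dtype shares the same memory bus. This is a scaling analysis, not an absolute measurement.\n"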
   ]
  },
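  {
   "cell_type": "markdown",
   "id": "6e1d9b30",
   "metadata": {},
   "source": [
    "Instead of assuming $B_{mem}$, one can estimate an effective value with a large device-to-device copy. Copying $N$ bytes reads $N$ and writes $N$, so the effective bandwidth is about $2N / t$. Below is a minimal sketch. `measure_copy_bandwidth_gbps` is a hypothetical helper. The 128 MiB buffer (chosen to be larger than a typical consumer-GPU L2 cache) and the trial count are arbitrary. The result is an achievable figure that usually sits below the datasheet value.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def measure_copy_bandwidth_gbps(num_bytes: int = 128 * 1024**2, trials: int = 10) -> float:\n",
    "    # Effective DRAM bandwidth in GB/s: every copy reads num_bytes and writes num_bytes.\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    src = torch.empty(num_bytes, dtype=torch.uint8, device=device)\n",
    "    dst = torch.empty_like(src)\n",
    "    dst.copy_(src)  # warmup\n",
    "    if device == \"cuda\":\n",
    "        torch.cuda.synchronize()\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(trials):\n",
    "        dst.copy_(src)\n",
    "    if device == \"cuda\":\n",
    "        torch.cuda.synchronize()\n",
    "    elapsed_s = time.perf_counter() - start\n",
    "    return 2 * num_bytes * trials / elapsed_s / 1e9\n",
    "\n",
    "\n",
    "print(f\"effective bandwidth: {measure_copy_bandwidth_gbps():.1f} GB/s\")\n",
    "```\n",
    "\n",
    "The measured number could then replace the `HBM_BANDWIDTH_GBPS` constant used in the next cell.\n"
   ]
  },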
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95f5eaef",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Bit-level circuit weight per dtype (relative cost of one FLOP, FP32 = 1.0)\n",
    "# These are empirical estimates for trend analysis; real values vary by architecture\n",
    "DTYPE_INFO = {\n",
    "    \"FP32\": {\"bits\": 32, \"circuit_weight\": 1.00},\n",
    "    \"TF32\": {\"bits\": 32, \"circuit_weight\": 0.55},  # 10-bit mantissa; still stored in 32 bits, but the compute circuit is closer to FP16\n",
    "    \"FP16\": {\"bits\": 16, \"circuit_weight\": 0.22},\n",
    "    \"BF16\": {\"bits\": 16, \"circuit_weight\": 0.20},\n",
    "}\n",
    "\n",
    "# Deliberately low DRAM bandwidth estimate, used for the ridge-point calculation\n",
    "HBM_BANDWIDTH_GBPS = 128  # GB/s; placeholder. The RTX 5060 uses GDDR7, so a measured value is preferable\n",
    "\n",
    "\n",
    "def run_experiment(workload_fn, workload_name: str, shapes: dict):\n",
    "    results = []\n",
    "    for dtype_name, dtype, _ in DTYPE_CONFIGS:\n",
    "        print(\"\\n--- \" + workload_name + \" / \" + dtype_name + \" ---\")\n",
    "\n",
    "        if dtype_name == \"FP32\":\n",
    "            torch.backends.cuda.matmul.allow_tf32 = False\n",
    "            torch.backends.cudnn.allow_tf32 = False\n",
    "        elif dtype_name == \"TF32\":\n",
    "            torch.backends.cuda.matmul.allow_tf32 = True\n",
    "            torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "        fn, flops = workload_fn(dtype, **shapes)\n",
    "        profile = profile_workload(workload_name, fn, dtype_name)\n",
    "\n",
    "        bytes_used = profile[\"hbm_bytes\"]\n",
    "        elapsed_s = profile[\"elapsed_ms\"] / 1000\n",
    "\n",
    "        # compute intensities\n",
    "        ai_std = flops / bytes_used if bytes_used > 0 else 0\n",
    "        tflops = flops / elapsed_s / 1e12\n",
    "        peak_tflops = peak_results[dtype_name][\"tflops\"]\n",
    "        # Ridge point in FLOP/byte: convert TFLOP/s and GB/s to FLOP/s and byte/s first\n",
    "        ridge_point = (peak_tflops * 1e12) / (HBM_BANDWIDTH_GBPS * 1e9)\n",
    "        rho = ai_std / ridge_point if ridge_point > 0 else 0\n",
    "\n",
    "        # bit-level intensity: circuit-weighted FLOPs per bit of HBM traffic\n",
    "        # (bytes_used * 8 = number of elements moved * bits per element)\n",
    "        w = DTYPE_INFO[dtype_name][\"circuit_weight\"]\n",
    "        i_bit = (flops * w) / (bytes_used * 8) if bytes_used > 0 else 0\n",
    "\n",
    "        results.append({\n",
    "            \"dtype\": dtype_name,\n",
    "            \"workload\": workload_name,\n",
    "            \"elapsed_ms\": profile[\"elapsed_ms\"],\n",
    "            \"tflops\": tflops,\n",
    "            \"peak_tflops\": peak_tflops,\n",
    "            \"flops\": flops,\n",
    "            \"bytes\": bytes_used,\n",
    "            \"ai_std\": ai_std,\n",
    "            \"i_bit\": i_bit,\n",
    "            \"ridge_point\": ridge_point,\n",
    "            \"rho\": rho,\n",
    "        })\n",
    "\n",
    "        print(f\"  elapsed: {profile['elapsed_ms']:.2f} ms\")\n",
    "        print(f\"  achieved: {tflops:.2f} TFLOP/s / peak: {peak_tflops:.2f} TFLOP/s\")\n",
    "        print(f\"  HBM bytes: {bytes_to_str(bytes_used)}\")\n",
    "        print(f\"  AI_std: {ai_std:.2f} FLOP/byte\")\n",
    "        print(f\"  I_bit: {i_bit:.2f}\")\n",
    "        print(f\"  ridge point: {ridge_point:.2f} FLOP/byte\")\n",
    "        print(f\"  rho: {rho:.3f}\")\n",
    "\n",
    "    return results\n",
    "\n",
    "\n",
    "# GEMM experiment\n",
    "gemm_results = run_experiment(gemm_workload, \"GEMM\", {\"m\": 4096, \"k\": 4096, \"n\": 4096})\n",
    "\n",
    "# MLP experiment\n",
    "mlp_results = run_experiment(mlp_workload, \"MLP\", {\"dim\": 2048, \"num_layers\": 4, \"batch_size\": 512})\n",
    "\n",
    "all_results = gemm_results + mlp_results\n"
   ]
  },
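  {
   "cell_type": "markdown",
   "id": "b5f2a8c6",
   "metadata": {},
   "source": [
    "Every dtype shares the same $B_{mem}$, so the ratio of $\\rho$ between two dtypes of the same workload does not depend on the assumed bandwidth. Below is a small sketch that normalizes $\\rho$ to FP32 for a list shaped like `gemm_results`, using the same `dtype` and `rho` keys that `run_experiment` produces. `rho_relative_to` is a hypothetical helper, and the sample values are made up purely to show the output format.\n",
    "\n",
    "```python\n",
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def rho_relative_to(results: List[Dict], baseline: str = \"FP32\") -> Dict[str, float]:\n",
    "    # Divide each dtype's rho by the baseline dtype's rho (one workload at a time).\n",
    "    base = next(r[\"rho\"] for r in results if r[\"dtype\"] == baseline)\n",
    "    return {r[\"dtype\"]: r[\"rho\"] / base for r in results}\n",
    "\n",
    "\n",
    "# Made-up values, only to illustrate the output shape.\n",
    "sample = [{\"dtype\": \"FP32\", \"rho\": 0.5}, {\"dtype\": \"FP16\", \"rho\": 0.125}]\n",
    "print(rho_relative_to(sample))  # {'FP32': 1.0, 'FP16': 0.25}\n",
    "```\n",
    "\n",
    "On the real data, `rho_relative_to(gemm_results)` and `rho_relative_to(mlp_results)` give the per-dtype ratios directly. A value below 1 means that dtype sits farther below its ridge point than FP32 does.\n"
   ]
  },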
  {
   "cell_type": "markdown",
   "id": "83792b02",
   "metadata": {},
   "source": [
    "\n",
    "## Part 6: Visualizing the Results\n",
    "\n",
    "We draw four plots:\n",
    "\n",
    "1. **Achieved vs. peak TFLOP/s**: utilization.\n",
    "2. **$AI_{std}$ by dtype**: the traditional roofline metric.\n",
    "3. **$I_{bit}$ by dtype**: the bit-level intensity.\n",
    "4. **$\\rho$ (distance to the ridge point) by dtype**: the key plot for the conclusion.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e6a256",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_results(results: List[Dict], title: str):\n",
    "    dtypes = [r[\"dtype\"] for r in results]\n",
    "    x = np.arange(len(dtypes))\n",
    "    width = 0.35\n",
    "\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
    "\n",
    "    # Plot 1: Achieved vs Peak TFLOP/s\n",
    "    ax = axes[0, 0]\n",
    "    achieved = [r[\"tflops\"] for r in results]\n",
    "    peak = [r[\"peak_tflops\"] for r in results]\n",
    "    ax.bar(x - width/2, achieved, width, label=\"Achieved\", color=\"#4a90d9\")\n",
    "    ax.bar(x + width/2, peak, width, label=\"Peak\", color=\"#90caf9\")\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(dtypes)\n",
    "    ax.set_ylabel(\"TFLOP/s\")\n",
    "    ax.set_title(f\"{title}: Achieved vs Peak TFLOP/s\")\n",
    "    ax.legend()\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "    # Plot 2: AI_std\n",
    "    ax = axes[0, 1]\n",
    "    ai_std = [r[\"ai_std\"] for r in results]\n",
    "    ax.bar(dtypes, ai_std, color=\"#81c784\")\n",
    "    ax.set_ylabel(\"FLOPs / byte\")\n",
    "    ax.set_title(f\"{title}: Traditional Arithmetic Intensity (AI_std)\")\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "    # Plot 3: I_bit\n",
    "    ax = axes[1, 0]\n",
    "    i_bit = [r[\"i_bit\"] for r in results]\n",
    "    ax.bar(dtypes, i_bit, color=\"#ffb74d\")\n",
    "    ax.set_ylabel(\"Bit-level intensity\")\n",
    "    ax.set_title(f\"{title}: Bit-level Intensity (I_bit)\")\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "    # Plot 4: rho\n",
    "    ax = axes[1, 1]\n",
    "    rho = [r[\"rho\"] for r in results]\n",
    "    colors = [\"#e57373\" if v < 1 else \"#64b5f6\" for v in rho]\n",
    "    ax.bar(dtypes, rho, color=colors)\n",
    "    ax.axhline(y=1.0, color=\"red\", linestyle=\"--\", label=\"Ridge point (rho=1)\")\n",
    "    ax.set_ylabel(\"rho = AI_std / (P_peak / B_mem)\")\n",
    "    ax.set_title(f\"{title}: Distance to Ridge Point (rho)\")\n",
    "    ax.legend()\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_results(gemm_results, \"GEMM\")\n",
    "plot_results(mlp_results, \"MLP\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "441a2209",
   "metadata": {},
   "source": [
    "\n",
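    "## Part 7: Traditional Roofline Plot\n",
    "\n",
    "We plot each dtype's roofline (and hence its ridge point) together with each workload's measured $AI_{std}$ on one set of axes.\n",
    "\n",
    "- x-axis: $AI_{std}$ (FLOPs/byte)\n",
    "- y-axis: performance (TFLOP/s)\n",
    "- each dtype's roofline is determined by its $P_{peak}$ and $B_{mem}$\n",
    "- each point is the measured position of the workload under that dtype\n"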
   ]
  },
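  {
   "cell_type": "markdown",
   "id": "4a7c3e19",
   "metadata": {},
   "source": [
    "Each roofline is the attainable performance $P(AI) = \\min(P_{peak}, B_{mem} \\cdot AI)$. This sketch uses the notebook's units (TFLOP/s and GB/s). GB/s × FLOP/byte gives GFLOP/s, hence the division by 1000; `plot_roofline` applies the same conversion. `attainable_tflops` is a hypothetical helper, and the inputs are placeholders.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def attainable_tflops(ai: np.ndarray, peak_tflops: float, bw_gbps: float) -> np.ndarray:\n",
    "    # Memory roof: GB/s * FLOP/byte = GFLOP/s, divided by 1000 to get TFLOP/s.\n",
    "    memory_roof = bw_gbps * ai / 1000.0\n",
    "    return np.minimum(peak_tflops, memory_roof)\n",
    "\n",
    "\n",
    "ai = np.array([1.0, 10.0, 100.0, 1000.0])  # FLOP/byte\n",
    "print(attainable_tflops(ai, peak_tflops=20.0, bw_gbps=400.0))  # memory-bound until AI = 50\n",
    "```\n"
   ]
  },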
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a6e487f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_roofline(results: List[Dict], title: str):\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "    ai_range = np.logspace(\n",
    "        np.log10(min(r[\"ai_std\"] for r in results) * 0.5),\n",
    "        np.log10(max(r[\"ai_std\"] for r in results) * 2),\n",
    "        100,\n",
    "    )\n",
    "\n",
    "    colors = {\"FP32\": \"#1976d2\", \"TF32\": \"#388e3c\", \"FP16\": \"#f57c00\", \"BF16\": \"#7b1fa2\"}\n",
    "\n",
    "    for r in results:\n",
    "        dtype = r[\"dtype\"]\n",
    "        peak = r[\"peak_tflops\"]\n",
    "        ridge = r[\"ridge_point\"]\n",
    "        roof = np.minimum(peak, HBM_BANDWIDTH_GBPS * ai_range / 1000)  # GB/s * FLOP/byte / 1000 = TFLOP/s\n",
    "        ax.plot(ai_range, roof, label=f\"{dtype} roofline\", color=colors.get(dtype, \"#333\"), alpha=0.6)\n",
    "        ax.scatter([r[\"ai_std\"]], [r[\"tflops\"]], color=colors.get(dtype, \"#333\"), s=100, zorder=5)\n",
    "        ax.annotate(\n",
    "            dtype,\n",
    "            (r[\"ai_std\"], r[\"tflops\"]),\n",
    "            textcoords=\"offset points\",\n",
    "            xytext=(5, 5),\n",
    "            fontsize=9,\n",
    "        )\n",
    "\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.set_xlabel(\"Arithmetic Intensity AI_std (FLOPs/byte)\")\n",
    "    ax.set_ylabel(\"Performance (TFLOP/s)\")\n",
    "    ax.set_title(f\"{title}: Roofline Plot\")\n",
    "    ax.legend()\n",
    "    ax.grid(True, which=\"both\", linestyle=\"--\", alpha=0.5)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_roofline(gemm_results, \"GEMM\")\n",
    "plot_roofline(mlp_results, \"MLP\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8feb5ac",
   "metadata": {},
   "source": [
    "\n",
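    "## Part 8: Discussion and Conclusions\n",
    "\n",
    "### Expected Findings\n",
    "\n",
    "1. **GEMM**: large matrix multiplications have a high $AI_{std}$, so FP16's $\\rho$ may be close to or above 1, i.e. near compute-bound.\n",
    "2. **MLP**: activations and weights are accessed more often relative to the arithmetic, so $AI_{std}$ is lower. FP16's $\\rho$ may even be smaller than FP32's, which means the FP16 peak is harder to keep fed.\n",
    "3. **TF32 sits in an interesting spot**:\n",
    "   - its bytes are the same as FP32 (32 bits per element);\n",
    "   - but its compute is closer to FP16;\n",
    "   - so its ridge point can be much higher while $AI_{std}$ stays the same, and $\\rho$ may drop significantly.\n",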
    "\n",
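    "The TF32 point follows directly from the definition of $\\rho$. TF32 keeps FP32 storage, so the same workload has the same $AI_{std}$ and the same $B_{mem}$; only the peak changes:\n",
    "\n",
    "$$\n",
    "\\frac{\\rho_{TF32}}{\\rho_{FP32}} = \\frac{AI_{std} / (P_{peak,TF32} / B_{mem})}{AI_{std} / (P_{peak,FP32} / B_{mem})} = \\frac{P_{peak,FP32}}{P_{peak,TF32}}\n",
    "$$\n",
    "\n",
    "For example, if the measured TF32 peak were 4x the FP32 peak (a hypothetical figure), $\\rho$ would drop to a quarter.\n",
    "\n",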
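    "### Limitations\n",
    "\n",
    "- HBM bytes come from PyTorch allocator statistics, not precise HBM traffic.\n",
    "- `circuit_weight` is an empirical estimate and needs calibration with lower-level profiling (ncu, power meters).\n",
    "- Tensor Core layout transforms and memory-bound operators such as softmax and reductions are not considered.\n",
    "\n",
    "### Next Steps\n",
    "\n",
    "- Measure `dram__bytes.sum` precisely with `ncu`.\n",
    "- Estimate energy per FLOP from `nvidia-smi` power readings.\n",
    "- Extend the experiment to a transformer layer (attention is a typical memory-bound case).\n"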
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
