从 CPU 到 GPU:LLM 推理引擎的 CUDA 优化之路

Series: LLM Inference Engine
Part 1: 从零到生成 — 纯 C++ 手搓推理引擎
Part 2: 从 CPU 到 GPU — CUDA 优化之路 (本篇)
Charlie Chaplin, Modern Times (1936)
Charlie Chaplin,《摩登时代》, 1936 — 从手工到机器的加速

引言

上一篇我们用纯 C++ 从零搭建了一个 LLM 推理引擎,能跑 Qwen2.5-0.5B。当时在 CPU(OpenBLAS + INT8 量化)上跑到了 33 tok/s。

这篇文章记录接下来的优化过程:把引擎搬上 GPU,一步步从 33 tok/s 做到 504 tok/s。每一步优化都有明确的 profiling 数据驱动,不靠猜。

优化路径总览:

阶段 Decode 吞吐 提升
CPU BLAS + INT8 33 tok/s baseline
CUDA + cuBLAS + INT8 110 tok/s 3.3×
+ FP16 inference 153 tok/s 4.6×
+ Flash Attention 161 tok/s 4.9×
+ GPU-resident sampling 161 tok/s -
+ CUDA caching allocator 504 tok/s 15.3×

最后一步(caching allocator)带来了最大的提升,但如果不是 nsys profiling 揭示了真正的瓶颈,我根本不会想到去优化内存分配。这是这篇文章最想传达的一点:先 profile,再优化

Step 1: CUDA 后端 — 把计算搬上 GPU

上一篇的 Backend 抽象在这里发挥了作用。新建一个 BackendCUDA 子类,实现同样的接口,背后换成 CUDA kernel:

class BackendCUDA : public BackendCPU {
public:
    void linear(float* out, const float* inp, const float* weight,
                const float* bias, std::size_t M, std::size_t N,
                std::size_t K) override;
    void rms_norm(...) override;
    void rope(...) override;
    void causal_attention(...) override;
    // ... 所有算子都有 GPU 实现
};

核心设计:

效果:从 CPU BLAS 的 33 tok/s 到 CUDA 的 110 tok/s,3.3 倍。

Step 2: FP16 全链路推理

第一版 CUDA 后端的数据流是:FP32 中间结果 → 转 FP16 → cuBLAS GEMM → FP32 输出。中间有大量 FP32↔FP16 转换。

优化思路:让所有中间 Tensor 都是 FP16。embedding 输出转 FP16 后,整个 pipeline(RMSNorm、RoPE、Attention、MLP)都在 FP16 上跑,只在最后输出 logits 时转回 FP32。

为每个 kernel 写了 FP16 版本。以 RMSNorm 为例,内部用 FP32 做累加避免精度损失:

__global__ void rms_norm_fp16_kernel(half* out, const half* x,
                                      const half* weight, size_t n,
                                      float eps) {
    // 用 FP32 累加,避免精度损失
    float local_sum = 0.0f;
    for (size_t i = tid; i < n; i += stride) {
        float v = __half2float(x[i]);
        local_sum += v * v;
    }
    // ... reduction ...
    float rms = sqrtf(sdata[0] / n + eps);

    for (size_t i = tid; i < n; i += stride) {
        out[i] = __float2half(
            __half2float(x[i]) / rms * __half2float(weight[i]));
    }
}

这里有一个坑:Sampler 也需要知道 logits 的数据类型。我们把 logits 从 FP16 转成 FP32 再传给 sampler,但 sampler 内部有一个 fp16 标志位告诉 GPU sampling kernel 用什么精度读取。由于 logits 已经是 FP32 了,这个标志位应该永远是 0——但代码错误地根据 backend.is_fp16() 设置了它。

结果:FP16 模式下输出全是乱码。INT8 FP32 正常,FP16 乱码——因为 sampling kernel 把 FP32 数据当 FP16 读了,每对字节被错误解释为一个 half float。

修复一行代码:

// Before (bug):
int fp16 = backend_.is_fp16() ? 1 : 0;
return cuda_sample(d_logits_, logits.size(), ..., fp16);

// After (fix):
// Logits are always FP32 (converted in transformer_forward)
return cuda_sample(d_logits_, logits.size(), ..., /*fp16=*/0);

教训:数据类型在管道中的每一个环节都要对齐。一个地方转了 FP32,下游如果还按 FP16 读,就是灾难。

FP16 全链路打通后:110 tok/s → 153 tok/s,+39%。带宽减半 + Tensor Core 利用率提升。

Step 3: Flash Attention

旧的 attention kernel 把完整的 scores[attend_len] 数组存在 shared memory 里。这有两个问题:

  1. Shared memory 大小随序列长度线性增长,~12K tokens 就爆了(48KB 限制)
  2. 必须先算完所有 QKT score,才能做 softmax,才能算 V 加权——三步串行

Flash Attention 的核心洞察:softmax 可以增量计算。每次只处理一小块(128 个 KV position),用两个标量 running_maxrunning_sum 来修正历史累积值:

// 主循环:每次处理 128 个 KV position
for (tile_start = 0; tile_start < attend_len; tile_start += 128) {

    // 1. 只算这 128 个 score → shared memory (512 bytes)
    for (t in tile)
        scores[t] = dot(q, k[t]) * scale;

    // 2. 找这个 tile 的 max
    tile_max = reduce_max(scores);

    // 3. 修正历史(Flash Attention 的精华)
    new_max = max(running_max, tile_max);
    correction = exp(running_max - new_max);
    running_sum *= correction;
    out_acc *= correction;       // 修正之前所有 tile 的累积值

    // 4. 累加当前 tile
    for (t in tile) {
        e = exp(scores[t] - new_max);
        running_sum += e;
        out_acc += e * v[t];     // 加权 V 累积
    }
    running_max = new_max;
}

// 最终归一化
out = out_acc / running_sum;

为什么 correction 能正确修正?数学上:

exp(score - old_max) × exp(old_max - new_max) = exp(score - new_max)

所以把所有历史累积值乘以 correction,等价于用 new_max 重新算了一遍。

Shared memory 从 (attend_len + 256) × 4 变成 常量 1536 字节,不管序列多长。

attend_len 旧 kernel smem/head Flash Attention smem/head
541.2 KB1.5 KB
4482.8 KB1.5 KB
409617.0 KB1.5 KB
1228849.0 KB (OOM)1.5 KB

短序列性能基本持平(瓶颈不在 attention),但解锁了长上下文能力。Decode 吞吐在 50–750 token attend_len 范围内保持稳定在 ~161 tok/s。

Step 4: GPU-Resident Sampling

改完 Flash Attention 之后,我用 nsys 检查数据流,发现了一个不必要的 CPU roundtrip:

GPU: FP16 logits → fp16_to_fp32 → FP32 logits
                                          ↓ cudaMemcpy D2H (593KB)
CPU: std::vector<float> logits
                                          ↓ cudaMemcpy H2D (593KB)
GPU: sample_kernel → token ID

Qwen2.5-0.5B 的 vocab_size=151936,每次 decode 搬运 1.2MB 来回过 PCIe。

修复方案:让 forward() 返回 Tensor(GPU-resident)而不是 std::vector<float>。Sampler 直接从 GPU Tensor 采样:

// Before: forward() 下载到 CPU,sampler 再上传回 GPU
std::vector<float> forward(model, token_ids, pos, kv_cache);
sampler->sample(logits, step);  // 内部: H2D upload + sample kernel

// After: logits 留在 GPU,sampler 直接用 device pointer
Tensor forward(model, token_ids, pos, kv_cache);
sampler->sample(logits, step);  // 直接: sample kernel on device ptr

这个改动本身对 decode 吞吐影响不大(PCIe 传输相对于 kernel 计算时间很短),但代码更干净了——而且为下一步优化铺了路。

Step 5: nsys Profile → 发现真正的瓶颈 → CUDA Caching Allocator

这是整篇文章最重要的一步。不是优化算法,而是用 profiling 工具发现了一个完全意想不到的瓶颈。

跑一次 nsys profile

nsys profile --trace=cuda,osrt --stats=true \
  bazel-bin/bin/qwen/qwen \
    --model models/Qwen2.5-0.5B \
    --prompt "Once upon a time" \
    --max_tokens 50 \
    --backend cuda --quantize int8 --fp16

CUDA API 时间分布:

 Time (%)  Total Time (ns)  Num Calls   Name
 --------  ---------------  ---------   ----
     36.0%       216913486       1341   cudaFree           ← 36%!!
     33.6%       202353991       2963   cudaMemcpy
     17.3%       104017852      17773   cudaLaunchKernel   ← 实际计算
      8.2%        49292731       1341   cudaMalloc         ← 8%

看到这个数据我惊了:cudaMalloc + cudaFree 占了 44% 的 CUDA API 时间(266ms),比实际 GPU kernel 计算(104ms)还多 2.5 倍!

为什么?因为 cudaMalloccudaFree 都是同步操作——它们会阻塞 CPU 线程,等 GPU 上所有 pending kernel 执行完毕。这意味着每次分配/释放都强制同步了 GPU pipeline。

1341 次 malloc + 1341 次 free。哪来的?每次 transformer_forward() 创建 ~12 个临时 Tensor:

Tensor x({seq_len * hidden}, DType::Float16, Device::CUDA);
Tensor norm_out({seq_len * hidden}, DType::Float16, Device::CUDA);
Tensor q({seq_len * num_heads * head_dim}, DType::Float16, Device::CUDA);
Tensor k_new({seq_len * kv_dim}, DType::Float16, Device::CUDA);
Tensor v_new({seq_len * kv_dim}, DType::Float16, Device::CUDA);
Tensor attn_out({seq_len * hidden}, DType::Float16, Device::CUDA);
Tensor gate_up({seq_len * 2 * intermediate}, DType::Float16, Device::CUDA);
Tensor gate({seq_len * intermediate}, DType::Float16, Device::CUDA);
Tensor final_out({hidden}, DType::Float16, Device::CUDA);
Tensor logits({vocab_size}, DType::Float16, Device::CUDA);
// ... plus logits_fp32, x_fp32 for conversions

函数返回时,这些 Tensor 的析构函数调用 cudaFree。54 次 forward × ~25 个 alloc/free ≈ 1350 次。和 nsys 报告的 1341 完全吻合。

关键洞察:decode 阶段,每次 forward 的 Tensor shape 完全相同(seq_len=1)。每次都 malloc 同样大小的 buffer,用完 free,下一次又 malloc 同样大小——这是纯浪费。

修复:CUDA caching allocator

核心思想极简:freed 的 GPU buffer 不真正释放,而是放入一个 free list(按 size 索引)。下次申请同样大小的 buffer 时,直接从 free list 取,O(1) 查找,零 GPU 同步:

// Free list: byte_size → [ptr, ptr, ...] (exact size matching)
static std::multimap<std::size_t, void*> cuda_free_list_;

// Live allocations: ptr → byte_size (for returning to correct bucket)
static std::unordered_map<void*, std::size_t> cuda_alloc_sizes_;

static void* cuda_pool_alloc(std::size_t bytes) {
    auto it = cuda_free_list_.find(bytes);
    if (it != cuda_free_list_.end()) {
        void* ptr = it->second;
        cuda_free_list_.erase(it);
        cuda_alloc_sizes_.emplace(ptr, bytes);
        return ptr;    // ← 命中!零 GPU 同步
    }
    void* ptr = cuda_malloc(bytes);  // miss: 真正分配
    cuda_alloc_sizes_.emplace(ptr, bytes);
    return ptr;
}

static void cuda_pool_free(void* ptr) {
    auto it = cuda_alloc_sizes_.find(ptr);
    if (it != cuda_alloc_sizes_.end()) {
        cuda_free_list_.emplace(it->second, ptr);  // 回收到 free list
        cuda_alloc_sizes_.erase(it);
    } else {
        cuda_free(ptr);  // 不是 pool 管理的 buffer,真正释放
    }
}

30 行代码。插入到 device_allocate()device_deallocate() 中,对上层完全透明。

第一次 forward 时,所有 Tensor 都 miss(cold start),走 cudaMalloc。函数返回时,buffer 进入 free list。第二次 forward 开始,所有同 size 的 Tensor 都 hit,直接复用——零 cudaMalloc、零 cudaFree、零 GPU 同步。

修复后再跑一次 nsys

 Time (%)  Total Time (ns)  Num Calls   Name
 --------  ---------------  ---------   ----
      7.5%        45833654        751   cudaMalloc    ← -44% calls
      7.3%        44586600        729   cudaFree      ← -79% time!
     16.6%       101269184      17773   cudaLaunchKernel  ← 不变

cudaFree 从 217ms 降到 45ms(-79%)。612 次分配被 pool 命中,完全避免了 GPU 同步。

性能结果:

指标 Before (无 pool) After (caching) 提升
Decode 吞吐 153 tok/s 504 tok/s +3.3×
Decode avg 延迟 6.5 ms 2.0 ms -69%
Decode P50 延迟 6.0 ms 1.6 ms -73%
cudaFree time 217 ms 45 ms -79%

3.3 倍加速,来自 30 行内存池代码。这就是 profiling 的价值——不是猜瓶颈在算子上、在 attention 上、在 GEMM 上,而是让数据告诉你:最大的瓶颈竟然是 cudaFree

总结

从 CPU 的 33 tok/s 到 GPU 的 504 tok/s,15 倍加速,整个过程的关键经验:

1. Profile first, optimize second. 如果不跑 nsys,我可能会花几天优化 CUDA kernel 的 occupancy 或 shared memory 布局——但实际瓶颈是 cudaFree 的 GPU 同步。数据不会骗你。

2. 最大的提升往往来自最简单的修复。 Flash Attention 实现花了几个小时,效果 +5%。caching allocator 写了 30 行,效果 +230%。复杂度和收益不成正比。

3. 全链路思维。 FP16 sampler bug 是一个端到端的类型对齐问题。GPU-resident sampling 是一个数据搬运问题。这些不是 kernel 级别的优化,而是系统级别的。要看整条管道,不能只盯着单个算子。

4. 抽象层的价值不只是代码复用。 Backend 抽象让 CPU/BLAS/CUDA 切换零侵入。Tensor 的 device_allocate/device_deallocate 抽象让 caching allocator 可以透明插入,不改任何上层代码。好的抽象让优化变得局部化。

整个项目现在是 6200 行 C++/CUDA,跑 Qwen2.5-0.5B 在 RTX 4060 上做到 504 tok/s decode。还有很多可以做的——INT4 量化、prefill Q tiling、多请求 batching——但核心的 profiling-driven optimization loop 已经建立起来了。

"Premature optimization is the root of all evil. But so is premature pessimism — profile, then decide."