从零到生成:纯 C++ 手搓 LLM 推理引擎全记录

Flying Machine by Leonardo da Vinci
Leonardo da Vinci,《飞行器设计手稿》, c. 1487

引言

每天都在用 LLM,但它到底是怎么一步步把一串 token 变成人话的?PyTorch、vLLM、llama.cpp 这些框架把所有细节都藏起来了——你调一个 model.generate(),文字就出来了。但我想搞清楚中间到底发生了什么。

于是我用纯 C++20,不依赖任何 ML 框架,从零写了一个 LLM 推理引擎:vvllm。它能跑 Qwen2.5-0.5B,输入一段 prompt,输出连贯的文本。整个过程——从 mmap 加载权重,到 BPE 编码,到 Transformer 前向传播,到 nucleus sampling——全部手写。

这篇文章按照代码的实际流程,一步步记录这个引擎的构建和优化过程。

问题定义

目标:给定一段文本(prompt),让模型续写出连贯的内容。

模型:Qwen2.5-0.5B(通义千问 0.5B 参数版本)

约束:纯 C++20,不用 PyTorch / TensorFlow / ONNX Runtime。唯一允许的外部库是 OpenBLAS(可选的矩阵乘法加速)和 nlohmann/json(解析配置文件)。

构建系统:Bazel 7.4.1,跨平台支持 aarch64 / x86_64。

总览

整个推理引擎的构建和优化分为 7 步,按照实际开发顺序:

步骤 内容 关键指标
Step 0 基础设施:Backend + Tensor + SafeTensors 能加载 290 个 tensor,BF16→F32
Step 1 BPE Tokenizer 文本 ↔ token ID 互转
Step 2 Transformer 前向传播 24 层完整 forward pass
Step 3 修复 RoPE 旋转位置编码 从输出乱码到正确生成
Step 4 KV Cache decode 阶段 56× 加速
Step 5 OpenBLAS 加速线性投影 线性层 8.6× 加速
Step 6 Sampler & 自回归生成 Temperature + Top-p 采样

Step 0: 基础设施 — Backend + Tensor + SafeTensors

在写任何推理逻辑之前,需要先解决三个基础问题:怎么管理内存、怎么抽象计算、怎么加载模型权重。

Backend 抽象

定义一个抽象的计算后端接口。所有算子(matmul、add、softmax、RoPE 等)都是虚函数,具体实现由子类提供。这样后面可以无缝切换 naive CPU 和 OpenBLAS 后端:

class Backend {
public:
    virtual void matmul(float* out, const float* A, const float* B,
                        std::size_t M, std::size_t N, std::size_t K) = 0;
    virtual void rms_norm(float* out, const float* x, const float* weight,
                          std::size_t n, float eps) = 0;
    virtual void softmax(float* out, const float* x, std::size_t n) = 0;
    virtual void rope(float* q, float* k, std::size_t seq_len,
                      std::size_t num_heads, std::size_t num_kv_heads,
                      std::size_t head_dim, std::size_t pos, float theta) = 0;
    virtual void linear(float* out, const float* inp, const float* weight,
                        const float* bias, std::size_t M, std::size_t N,
                        std::size_t K) = 0;
    // ... add, mul, silu, embedding
};

第一个实现是 naive CPU 后端——所有操作都是最直白的 for 循环。比如 linear projection(推理中最核心的计算):

void BackendCPU::linear(float* out, const float* inp, const float* weight,
                        const float* bias, std::size_t M, std::size_t N,
                        std::size_t K) {
    // out[M, N] = inp[M, K] @ weight^T[K, N] + bias[N]
    for (std::size_t i = 0; i < M; i++) {
        for (std::size_t j = 0; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (std::size_t k = 0; k < K; k++) {
                sum += inp[i * K + k] * weight[j * K + k];
            }
            out[i * N + j] = sum;
        }
    }
}

三重循环,O(M×N×K) 标量乘法。慢,但正确。先让它跑起来,再优化。

SafeTensors 加载器

HuggingFace 的模型权重以 SafeTensors 格式存储。格式很简单:

+------------------+
| 8 bytes          |  header_size (uint64)
+------------------+
| JSON header      |  tensor names, shapes, dtypes, offsets
+------------------+
| raw tensor data  |  二进制浮点数据(BF16 或 F32)
+------------------+

关键设计决策是用 mmap 而不是 read 来加载文件。mmap 把文件映射到虚拟地址空间,操作系统按需从磁盘加载数据页——而不是一次性把整个文件读进内存。对 1GB+ 的权重文件,这减少了启动时间和内存峰值。

mapped_ = mmap(nullptr, file_size_, PROT_READ, MAP_PRIVATE, fd_, 0);

Qwen2.5-0.5B 的权重是 BF16(Brain Float 16)格式存储的。BF16 和 float32 有相同的 8 位指数部分,只是尾数从 23 位截断到 7 位。转换很简单——左移 16 位就行:

float bf16_to_f32(uint16_t bf16) {
    uint32_t bits = static_cast<uint32_t>(bf16) << 16;
    float result;
    std::memcpy(&result, &bits, sizeof(float));
    return result;
}

一个模型有 290 个 tensor(每层有 attention 的 Q/K/V/O 权重、MLP 的 gate/up/down 权重、两个 layernorm 权重,加上 embedding 和 final norm),加载后占 ~1.9GB 内存。

Step 1: BPE Tokenizer — 把文字变成数字

LLM 的输入不是原始文本,而是 token ID 序列。Tokenizer 负责文本和 token ID 之间的转换。Qwen2.5 使用 BPE(Byte Pair Encoding)分词。

编码分三步:

"Hello world"

Step 1: Pre-tokenize(按空格拆分,空格变成 Ġ 前缀)
  → ["Hello", "Ġworld"]

Step 2: BPE merge(反复合并最高优先级的相邻 pair)
  "Ġworld" → ['Ġ', 'w', 'o', 'r', 'l', 'd']
          → ['Ġ', 'wo', 'r', 'l', 'd']     // merge('w','o'), rank=300
          → ['Ġ', 'wor', 'l', 'd']          // merge('wo','r')
          → ['Ġ', 'world']                   // ...
          → ['Ġworld']                        // final merge

Step 3: Lookup(查词表,token → ID)
  ['Hello', 'Ġworld'] → [9707, 1879]

BPE 的核心循环:每一轮扫描所有相邻 token pair,找到 merge rank 最低的那个(优先级最高),合并成一个新 token。重复直到没有可合并的 pair。

// Step 3: Apply BPE merges
while (tokens.size() > 1) {
    int best_rank = -1;
    size_t best_idx = 0;
    for (size_t j = 0; j < tokens.size() - 1; j++) {
        std::string key = tokens[j] + " " + tokens[j + 1];
        auto it = merge_ranks_.find(key);
        if (it != merge_ranks_.end()) {
            if (best_rank < 0 || it->second < best_rank) {
                best_rank = it->second;
                best_idx = j;
            }
        }
    }
    if (best_rank < 0) break;
    tokens[best_idx] = tokens[best_idx] + tokens[best_idx + 1];
    tokens.erase(tokens.begin() + best_idx + 1);
}

解码是反向操作:token ID 查词表得到 token 字符串,拼接后把 Ġ(UTF-8: 0xC4 0xA0)替换回空格。

Qwen2.5 的词表有 151,936 个 token——这意味着最后的 logits 输出也是一个 151,936 维的向量。

Step 2: Transformer 前向传播 — 推理引擎的核心

这是整个引擎最重要的部分。一个完整的 forward pass 做了什么?

token_ids: [t0, t1, t2, ..., tn]
                |
        Embedding Lookup
                |
          x: [seq_len, 896]
                |
     +----------+----------+
     |     Transformer      |
     |      Layer × 24       |
     |                      |
     |  RMS Norm            |
     |  Q, K, V Projection  |
     |  RoPE                |
     |  Causal Attention    |
     |  Output Projection   |
     |  + Residual          |
     |  RMS Norm            |
     |  MLP (SwiGLU)        |
     |  + Residual          |
     +----------+-----------+
                |
          Final RMS Norm
                |
        Logits Projection
                |
        logits: [151936]

1. Embedding Lookup

最简单的一步:把 token ID 映射到 896 维的向量。embedding table 是一个 [151936, 896] 的矩阵,查找就是 memcpy 一行:

void embedding(float* out, const float* table,
               std::size_t token_id, std::size_t hidden_size) {
    std::memcpy(out, table + token_id * hidden_size,
                hidden_size * sizeof(float));
}

2. RMS Norm

每层的输入先做 RMS normalization——和 LayerNorm 类似,但不减均值,只除以 RMS(Root Mean Square):

void rms_norm(float* out, const float* x, const float* weight,
              std::size_t n, float eps) {
    float sum = std::accumulate(x, x + n, 0.0f,
        [](float acc, float x) { return acc + x * x; });
    float rms = sqrt(sum / n + eps);
    for (size_t i = 0; i < n; i++) {
        out[i] = x[i] / rms * weight[i];
    }
}

公式:out[i] = (x[i] / √(mean(x²) + ε)) × weight[i]。RMS Norm 比 Layer Norm 少算一次均值和一次减法,在推理时更快,而且实验表明效果相当。

3. Q, K, V 线性投影

注意力机制的第一步:把 896 维的隐藏状态投影到 Q, K, V 空间。这就是矩阵乘法——LLM 推理中最耗时的计算。

// 对每个 token,计算 Q, K, V
for (std::size_t s = 0; s < seq_len; s++) {
    const float* inp = norm_out.data() + s * hidden;
    linear(q.data() + s * num_heads * head_dim,
           inp, attn.q_proj_weight, q_bias(attn),
           num_heads * head_dim, hidden, backend);   // Q: [14×64, 896]
    linear(k_new.data() + s * kv_dim,
           inp, attn.k_proj_weight, k_bias(attn),
           kv_dim, hidden, backend);                  // K: [2×64, 896]
    linear(v_new.data() + s * kv_dim,
           inp, attn.v_proj_weight, v_bias(attn),
           kv_dim, hidden, backend);                  // V: [2×64, 896]
}

注意一个关键细节:Qwen2 用了 Grouped Query Attention (GQA)——14 个 query heads 共享 2 个 KV heads。也就是说 K 和 V 的维度是 2×64=128,远小于 Q 的 14×64=896。这大幅减少了 KV cache 的内存占用,同时对模型质量影响很小。

4. RoPE(旋转位置编码)

Transformer 本身是位置不变的——打乱输入顺序,输出不变。RoPE 通过对 Q 和 K 施加位置相关的旋转,让模型知道每个 token 在序列中的位置。

核心思想:对 head_dim 维向量的每一对元素 (xi, xi+half),按照位置 pos 和频率 θ 做二维旋转:

for (std::size_t i = 0; i < half; i++) {
    float freq = 1.0f / std::pow(theta, (float)(2 * i) / head_dim);
    float angle = position * freq;
    float x0 = head[i];
    float x1 = head[i + half];
    head[i]        = x0 * cos(angle) - x1 * sin(angle);
    head[i + half] = x0 * sin(angle) + x1 * cos(angle);
}

频率从低到高,低维度编码高频(局部位置关系),高维度编码低频(长距离依赖)。这和傅里叶变换的思想类似。

5. Causal Attention(因果注意力)

这是 Transformer 的核心。对每个 query position,计算它与所有之前 token 的注意力分数,然后加权求和 value:

void causal_attention(float* out, const float* q, std::size_t q_idx,
                      const float* k, const float* v,
                      std::size_t attend_len, std::size_t num_heads,
                      std::size_t num_kv_heads, std::size_t head_dim,
                      float scale, Backend& backend) {
    std::size_t groups = num_heads / num_kv_heads;  // 14/2 = 7

    for (std::size_t h = 0; h < num_heads; h++) {
        std::size_t kv_h = h / groups;  // GQA: 7 个 Q head 共享 1 个 KV head

        // 1. scores = Q · K^T / sqrt(head_dim)
        for (std::size_t t = 0; t < attend_len; t++) {
            float dot = 0.0f;
            for (std::size_t d = 0; d < head_dim; d++) {
                dot += q_head[d] * k_head[d];
            }
            scores[t] = dot * scale;  // scale = 1/√64 = 0.125
        }

        // 2. attention weights = softmax(scores)
        backend.softmax(scores.data(), scores.data(), attend_len);

        // 3. output = weighted sum of V
        for (std::size_t t = 0; t < attend_len; t++) {
            for (std::size_t d = 0; d < head_dim; d++) {
                head_out[d] += scores[t] * v_head[d];
            }
        }
    }
}

"Causal" 的含义是:每个 token 只能看到它自己和它之前的 token,不能看到未来。这是通过 attend_len = pos + s + 1 来实现的——第 s 个 token 只 attend 到 s+1 个位置。

6. MLP(SwiGLU)

注意力之后是一个两层 MLP,Qwen2 使用 SwiGLU 激活函数:

// gate = linear(x, gate_proj)    // [896] → [4864]
// up   = linear(x, up_proj)      // [896] → [4864]
// gate = silu(gate) * up          // element-wise
// out  = linear(gate, down_proj)  // [4864] → [896]

backend.silu(gate.data(), gate.data(), seq_len * intermediate);
backend.mul(gate.data(), gate.data(), up.data(), seq_len * intermediate);

SiLU(也叫 Swish): silu(x) = x × σ(x) = x / (1 + e-x)。SwiGLU 的 "gating" 机制让网络可以选择性地保留或过滤信息——gate 路径控制开关,up 路径提供内容。

7. 残差连接

每个 attention block 和 MLP block 之后都有残差连接:x = x + block_output。它解决了深层网络的梯度消失问题——梯度可以沿着 skip connection 直接回传,不会在 24 层中逐渐衰减到零。

8. 最终预测

24 层 Transformer 处理完毕后,取最后一个 token 的隐藏状态,做一次 RMS Norm,然后投影到 vocab 维度得到 logits:

// Final norm (only last token)
backend.rms_norm(final_out.data(),
                 x.data() + (seq_len - 1) * hidden,
                 final_norm_weight, hidden, eps);

// Logits = final_out @ embed_tokens^T  (tied weights)
linear(logits.data(), final_out.data(), embed_tokens, nullptr,
       config.vocab_size, hidden, backend);

注意 tied weights:logits 投影矩阵和 embedding table 是同一个矩阵。这是一种常见的参数共享技巧——输入(token→向量)和输出(向量→token 概率)使用同一组参数,减少了 151936×896 ≈ 1.36 亿个参数。

Step 3: 修 Bug — RoPE 旋转位置编码的配对惯例

写完 forward pass,运行模型——输出全是乱码。debug 了很久才发现问题出在 RoPE 的配对惯例上。

RoPE 需要把 head_dim 维的向量两两配对进行旋转。但配对方式有两种:

head_dim = 8 的向量: [x0, x1, x2, x3, x4, x5, x6, x7]

Interleaved (GPT-NeoX 惯例):
  配对 (x0,x1), (x2,x3), (x4,x5), (x6,x7)
  相邻元素配对

Split-half (Qwen2 / Llama 惯例):
  配对 (x0,x4), (x1,x5), (x2,x6), (x3,x7)
  前半和后半配对

我一开始用了 interleaved 惯例(因为很多早期 GPT 实现是这么做的),但 Qwen2 用的是 split-half。两种惯例的数学本质完全一样——都是二维旋转——但应用到不同的元素对上,得到的结果天差地别。

debug 过程:

  1. 先用 Python 跑同样的输入,对比 HuggingFace 的 logits 输出
  2. 逐层对比发现,第一层的 Q projection 结果正确,但 RoPE 之后就不对了
  3. 写 numpy 脚本分别用两种惯例计算 RoPE,和 HuggingFace 对比
  4. 确认 split-half 和 HuggingFace 结果一致,修改 C++ 实现

教训:LLM 推理中最难 debug 的不是算法错误,而是惯例差异。RoPE 配对方式、attention mask 的实现、weight 的 layout([N,K] 还是 [K,N])——这些细节在论文里不会写,只有看源码才知道。

Step 4: KV Cache — 56 倍 Decode 加速

修完 RoPE,模型能正确生成文本了。但速度惨不忍睹——每生成一个 token 要重新跑一遍完整的 forward pass。

问题出在自回归生成的本质:生成第 N 个 token 时,需要所有前 N-1 个 token 的 attention context。naive 做法是每次都把所有 token 重新送入模型——但前 N-1 个 token 的 K, V 在之前的步骤中已经算过了!

Naive (无 cache):
  Step 1: forward([t0, t1, t2, t3, t4])        ← 5 个 token, 计算 5 组 K,V
  Step 2: forward([t0, t1, t2, t3, t4, t5])     ← 6 个 token, 重新计算 6 组 K,V
  Step 3: forward([t0, t1, t2, t3, t4, t5, t6]) ← 7 个 token, 重新计算 7 组 K,V
  → 复杂度 O(n²):每步重算所有历史

KV Cache:
  Prefill:  forward([t0, t1, t2, t3, t4])  ← 缓存 5 组 K,V
  Decode 1: forward([t5], pos=5)             ← 只算 t5 的 Q, 用缓存的 K,V
  Decode 2: forward([t6], pos=6)             ← 只算 t6 的 Q, 用缓存的 K,V
  → 复杂度 O(n):每步只计算新 token

KV Cache 的实现很直观——每层维护一个不断增长的 K 和 V 向量:

struct KVCache {
    std::vector<std::vector<float>> k_cache;  // [num_layers][seq_len * kv_dim]
    std::vector<std::vector<float>> v_cache;
    std::size_t seq_len = 0;

    void append(std::size_t layer_idx, const float* new_k,
                const float* new_v, std::size_t num_new_tokens,
                std::size_t kv_dim) {
        std::size_t count = num_new_tokens * kv_dim;
        auto& kc = k_cache[layer_idx];
        kc.insert(kc.end(), new_k, new_k + count);
        auto& vc = v_cache[layer_idx];
        vc.insert(vc.end(), new_v, new_v + count);
    }
};

在 forward pass 中,新 token 的 K, V 被 append 到 cache,attention 使用完整的 cache(而不只是当前 token):

// Append new K/V to cache
kv_cache.append(layer_idx, k_new.data(), v_new.data(), seq_len, kv_dim);

// Attention 使用 cache 中的全部 K/V
causal_attention(attn_out, q.data(), s,
                 kv_cache.k_data(layer_idx),   // 所有历史 K
                 kv_cache.v_data(layer_idx),   // 所有历史 V
                 pos + s + 1,                   // attend 长度
                 num_heads, num_kv_heads, head_dim, scale, backend);

生成循环变成了两个阶段:

// Prefill: 一次性处理整个 prompt, 填充 KV Cache
auto logits = forward(model, token_ids, 0, kv_cache);

// Decode: 每次只送入新 token
for (int step = 0; step < max_tokens; step++) {
    int next_token = sampler.sample(logits);
    std::size_t pos = token_ids.size() - 1;
    logits = forward(model, {next_token}, pos, kv_cache);  // 只送 1 个 token
}

Benchmark 结果(Google Benchmark, synthetic tiny model):

方式 Context=64 对比
DecodeNoCache(每步全量重算) ~20 ms
DecodeWithCache(只算新 token) ~362 μs 56×

56 倍加速。而且随着上下文长度增加,差距会越来越大——无 cache 时每步都重算全部 token 的 K/V 投影,复杂度是 O(n);有 cache 时只算 1 个 token,复杂度是 O(1)(attention 本身还是 O(n),但投影部分的节省是巨大的)。

GQA 在这里也发挥了作用:每层只需要缓存 2 个 KV head(128 维),而不是 14 个(896 维)。对于 24 层模型,KV cache 的内存占用减少了 7 倍。

Step 5: OpenBLAS — 8.6 倍加速线性投影

Profile 发现,inference 的绝大部分时间花在 linear() 函数上——这不意外,一个 Transformer 层有 7 个线性投影(Q, K, V, O, gate, up, down),每个都是矩阵乘法。对于 Qwen2.5-0.5B,最大的投影是 [896, 4864],即 896×4864 ≈ 436 万次乘加运算。24 层就是一亿次。

naive 后端的三重循环 matmul 效率极低——和 上一篇 matmul 优化文章里的 baseline 一样,没有 SIMD、没有 tiling、没有 cache 优化。

最直接的解决方案:把 linear 操作换成 OpenBLAS 的 cblas_sgemm

void BackendBLAS::linear(float* out, const float* inp, const float* weight,
                         const float* bias, std::size_t M, std::size_t N,
                         std::size_t K) {
    // out[M, N] = inp[M, K] @ weight^T[K, N]
    // weight is [N, K] row-major, so we use CblasTrans on B.
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                M, N, K, 1.0f, inp, K, weight, K, 0.0f, out, N);

    if (bias) {
        for (std::size_t i = 0; i < M; i++)
            for (std::size_t j = 0; j < N; j++)
                out[i * N + j] += bias[j];
    }
}

注意 weight 的存储格式:PyTorch 的惯例是 weight shape 为 [out_features, in_features],即 [N, K] row-major。所以 inp @ weight^T 需要对 weight 转置,对应 CblasTrans 参数。

这一行 cblas_sgemm 背后是 OpenBLAS 多年优化的结晶——SIMD 向量化、cache-friendly 的 tile 布局、寄存器级别的 micro-kernel——正是上一篇文章里花了 7 步才达到的那些优化。

Benchmark 结果(aarch64, -c opt):

Dimension CPU (naive) BLAS (OpenBLAS) 加速
64 2,227 ns 884 ns 2.5×
128 12,148 ns 3,804 ns 3.2×
256 58,317 ns 16,672 ns 3.5×
576 366,660 ns 42,761 ns 8.6×

维度越大,加速比越高。小矩阵时(dim=64),OpenBLAS 的线程调度和 packing 开销相对较大,收益有限。但在实际模型的维度(hidden_size=896, intermediate_size=4864)下,加速效果非常显著。

切换后端只需要在启动时加一个参数:--backend blas。Backend 抽象让这个切换零侵入——forward pass 代码完全不需要改动。

Step 6: Sampler — 从 Logits 到文字

forward pass 输出的是一个 151,936 维的 logits 向量——每个值代表一个 token 的"可能性得分"。怎么从中选出下一个 token?

最简单的方法:Greedy(贪心)

直接取 logits 最大值对应的 token。确定性,但容易陷入重复。

Temperature Scaling

把 logits 除以 temperature 参数,然后过 softmax 转成概率分布。temperature < 1 让分布更尖锐(更确定性),temperature > 1 让分布更平坦(更随机)。

Top-p(Nucleus Sampling)

先把 token 按概率从高到低排序,累加概率直到超过 top_p 阈值,只从这些 token 中采样。这既保留了随机性(不像 greedy 那么死板),又排除了低概率的"噪音"token(不像纯随机那么离谱)。

int Sampler::sample(const std::vector<float>& logits) {
    if (temperature_ == 0.0f) {
        return argmax(logits);  // greedy
    }

    // 1. Temperature scaling + softmax
    for (std::size_t i = 0; i < n; i++)
        scaled[i] = logits[i] / temperature_;
    softmax(scaled);

    // 2. Top-p: sort by probability, keep cumsum <= top_p
    std::sort(indices.begin(), indices.end(),
              [&](int a, int b) { return scaled[a] > scaled[b]; });

    float cumsum = 0.0f;
    for (std::size_t i = 0; i < n; i++) {
        cumsum += scaled[indices[i]];
        if (cumsum >= top_p_) { cutoff = i + 1; break; }
    }

    // 3. Sample from the truncated distribution
    std::uniform_real_distribution<float> dist(0.0f, kept_sum);
    float r = dist(rng_);
    // ... weighted random selection
}

默认参数是 temperature=0.7, top_p=0.9——在多样性和连贯性之间的平衡点。

把一切串起来

最终的推理流程:

// 1. 加载配置、权重、tokenizer
auto config = load_config("config.json");
auto weights = SafeTensorsLoader("model.safetensors").load_all(backend);
auto model = create_model(config, backend);
load_weights(model, weights);

// 2. Tokenize prompt
auto token_ids = tokenizer.encode("The capital of France is");  // [5 tokens]

// 3. Prefill: 处理整个 prompt
KVCache kv_cache(config.num_hidden_layers);
auto logits = forward(model, token_ids, 0, kv_cache);

// 4. Decode: 逐 token 生成
for (int step = 0; step < max_tokens; step++) {
    int next_token = sampler.sample(logits);
    if (next_token == config.eos_token_id) break;

    std::cout << tokenizer.decode({next_token}) << std::flush;

    std::size_t pos = token_ids.size() - 1;
    logits = forward(model, {next_token}, pos, kv_cache);
    token_ids.push_back(next_token);
}

输出:

Model: qwen2 (896d, 24 layers)
Loading weights...
Loaded 290 tensors
Model initialized
Prompt: "The capital of France is" (5 tokens)

The capital of France is Paris, where the most important cultural
institutions are located. The city is famous for its museums,
including the Lou

它活了。

总结与思考

从零构建这个推理引擎,几个深刻的体会:

1. LLM 推理 = 大量矩阵乘法。 一次 forward pass 有 24 层 × 7 个线性投影 = 168 次 matmul。这就是为什么 GPU 在 LLM 推理中如此重要——它的核心优势就是大规模并行 matmul。也是为什么 OpenBLAS 一行代码就能带来 8.6 倍加速。

2. KV Cache 是推理优化的第一优先级。 56 倍加速,而且实现简单——本质上就是一个 append-only 的数组。但它改变了生成的复杂度类别:从 O(n²) 到 O(n)。没有 KV Cache 的 LLM 推理在实践中根本不可用。

3. 惯例差异比算法 bug 更难发现。 RoPE 的 split-half vs interleaved、weight 的 [N,K] vs [K,N]、BPE 的空格前缀——这些在论文里一笔带过,但实现时每一个都可能让你 debug 半天。对比 reference 实现(HuggingFace)的中间结果是唯一可靠的 debug 手段。

4. 抽象层设计很重要。 Backend 接口让我可以先用 naive 循环确保正确性,再换 OpenBLAS 提升性能,forward pass 代码一行不改。这个简单的策略避免了同时 debug 正确性和性能两个问题。

5. 最大的收获是理解。 用完这么多 LLM 框架,直到自己手写一遍才真正理解:为什么 KV Cache 对长上下文如此关键、为什么 GQA 能大幅降低内存、为什么 prefill 和 decode 的性能特征完全不同、为什么 tied embedding 既省内存又不损失质量。

整个项目大约 1500 行 C++(不含测试和 benchmark),跑一个真实的 0.5B 模型。

Next Plan

Update: 后续的 CUDA 优化已完成,详见 Part 2: 从 CPU 到 GPU — CUDA 优化之路。从 33 tok/s 做到了 504 tok/s。

当前的引擎能跑,但离"好用"还有距离。接下来计划按优先级推进:

  1. Instruct 模型支持 — 目前只支持 base model 的续写。Instruct 模型需要 chat template 格式化(<|im_start|>user\n...<|im_end|>),以及多个 stop token 的处理。这是让模型真正"可对话"的前提。
  2. 多线程并行 — 当前完全单线程。最直接的优化是并行化 linear projection(多个 head 的 QKV 投影可以同时算),以及 attention 中多个 head 的独立计算。预计在 CPU 上还能再提 3-5 倍。
  3. CUDA GPU 后端 — CPU 推理的瓶颈是 matmul 吞吐。GPU 的 Tensor Core 天然适合大规模矩阵乘法。目标是写一个最小的 CUDA backend,用 cuBLAS 替换 OpenBLAS,把 decode 延迟从毫秒级降到微秒级。
  4. 量化推理 (INT8/INT4) — 当前所有权重都是 float32,一个 0.5B 模型就占 1.9GB。INT8 量化可以把内存减半,INT4 再减半,同时 SIMD/GPU 的整数运算吞吐更高。关键挑战是量化后的精度损失控制。
  5. Continuous Batching — 目前只能处理单条序列。生产级推理引擎(vLLM, TensorRT-LLM)的核心能力是同时处理多条请求,动态调度 KV cache 内存。这涉及 PagedAttention 等更复杂的内存管理策略。

"What I cannot create, I do not understand." — Richard Feynman