6 小时,从零完成游戏开发、PPO 算法实现到训练调优。
4 轮迭代、3 次失败、1 个零和陷阱 — AI Coding 时代的强化学习实战记录。
1项目概览
目标是搭建一个多游戏 RL 实验室:用 PPO(Proximal Policy Optimization)训练贪吃蛇 AI。不依赖 Stable-Baselines3 等框架,所有算法从零实现,全程由 Claude Code 辅助完成 — 从游戏逻辑、环境封装、CNN 网络设计到训练循环,约 1100 行代码(8 个文件),6 小时内走完从零到可用的全过程。
最终结果:1.3M 参数的 Actor-Critic CNN,5M 步训练(35 分钟,GTX 1660 Ti),best_score=44.20。但到达这个结果经历了 4 轮训练、3 次失败,每次失败都暴露了一个深层问题 — 这个调试过程本身比最终代码更有价值。
2分层架构:为训练速度服务的设计
核心原则:逻辑层不 import pygame,渲染层不修改游戏状态,PPO 算法不知道是什么游戏。
这不是多余的解耦。SnakeGame.step() 是纯 Python + NumPy,单步耗时约 10-50 µs;如果 import Pygame,SDL 子系统初始化就会拖慢整个训练进程。8 个并行环境每轮 rollout(128 步 × 8 envs = 1024 transitions),CPU 侧约 10-50 ms 完成,远低于 GPU 计算时间。
multiprocessing 使用 spawn 模式,进程创建开销大。而 SnakeGame 的 step 太轻量(微秒级),多进程的 IPC 开销反而比计算本身更高。所以直接用 Python 列表持有 8 个 env 实例,同步顺序 step。
3游戏逻辑实现细节
约 170 行 Python,核心数据结构和碰撞检测:
# 方向向量表 — 索引即方向码 UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3 DIRS = [(-1,0), (0,1), (1,0), (0,-1)] # (row, col) 偏移 # 蛇身用 deque,appendleft/pop 均 O(1) self._snake: deque[tuple[int, int]] = deque() # 空位集合 — 保证食物生成 O(1) 均匀采样 self._empty_cells: set[tuple[int, int]] = set()
坐标系设计
使用 (row, col) 而非 (x, y),与 NumPy ndarray[row, col] 索引一致。观测编码时零转换 — obs[channel, row, col] 直接对应游戏坐标。
碰撞检测的一个细节
# 自碰撞:排除即将离开的尾巴格 body_without_tail = set(self._snake) - {self._snake[-1]} if new_head in body_without_tail: return ..., REWARD_DEATH, True, ...
蛇头移到下一格时,尾巴也会同时离开。如果不排除尾巴,蛇追着自己尾巴转圈会被误判为死亡。
食物生成:O(1) 均匀采样
维护 _empty_cells 集合,每步更新 add/remove。食物生成时从集合中均匀随机取一个,保证公平性和确定性(使用注入的 np.random.Generator)。备选方案 while True: sample until not on snake 在蛇很长时命中率骤降,最坏 O(∞)。
反向输入的处理
# 反向动作(差 2)→ 静默保持当前方向 if abs(action - self._direction) != 2: self._direction = action
这对 RL 很重要:PPO 动作空间是 Discrete(4),网络偶尔输出反向动作是正常探索。如果反向 = 死亡,entropy 探索噪声会制造大量假阳性死亡信号,严重干扰梯度。
4观测编码:为什么是 4 通道
环境封装层(env.py)将游戏状态转换为 (4, H, W) 的 float32 张量:
为什么不用单通道 int 编码(0=空, 1=身, 2=头, 3=食物)?
- 单通道编码下,卷积核必须同时学习 "2 旁边有 1 是身体" 和 "2 旁边有 3 是食物" 两种完全不同的语义关系。4 通道直接分离语义,每个 filter 专注一种特征
- 方向信息用全图填充(而非拼接到 FC 层),让卷积层在任意感受野位置都能感知方向。否则方向只能在 Flatten 后的全连接层拼接,空间关联性丢失
- 归一化到 [0,1] 有助于训练稳定性,避免 int 值的量级差异
def _get_obs(self) -> np.ndarray: obs = np.zeros((4, self.height, self.width), dtype=np.float32) hr, hc = self._game.snake[0] obs[0, hr, hc] = 1.0 # Ch 0: 头 for r, c in self._game.snake[1:]: obs[1, r, c] = 1.0 # Ch 1: 身体(不含头) if self._game.food: obs[2, food[0], food[1]] = 1.0 # Ch 2: 食物 obs[3, :, :] = self._game.direction / 3.0 # Ch 3: 方向(全图) return obs
Truncation vs Termination
区分terminated(真死亡)和 truncated(超时截断)。截断条件:steps > 100 × len(snake)。
动态阈值的原因:蛇长 30 的合法路径可能超过 2000 步,固定 max_steps 会频繁截断长蛇,引入虚假终止信号。动态 100 × length 给出足够余量,只截断真正的无意义绕圈。截断时 reward=0 — 截断 ≠ 死亡,不应等同惩罚。
5网络架构:Actor-Critic CNN
共享骨干 + 双头输出。在 15×15 棋盘上约 130 万参数,VRAM < 50MB。
为什么 Flatten 而不是 Global Average Pooling
贪吃蛇是位置敏感任务 — "蛇头在 (3,7) 且食物在 (3,9)" 和 "蛇头在 (3,7) 且食物在 (12,1)" 对策略的影响完全不同。GAP 会丢弃空间信息。代价是 Flatten 后的全连接层参数较多(64×15×15×256 = 3.7M 参数),但在 GTX 1660 Ti 6GB VRAM 下完全不是问题。
正交初始化
def _orthogonal_init(layer, gain=np.sqrt(2)): nn.init.orthogonal_(layer.weight, gain=gain) nn.init.zeros_(layer.bias) return layer
所有 Conv/Linear 层用正交初始化(gain=√2 配合 ReLU)。关键区别在输出层:
- Actor gain=0.01:使初始 logits 接近 0,Softmax 后接近均匀分布 [0.25, 0.25, 0.25, 0.25]。保证早期充分随机探索
- Critic gain=1.0:标准初始化,使初始 V(s) 有合理方差,不被 advantage normalization 压平
6PPO 算法实现
PPO 的核心思想:限制每次策略更新的幅度,防止灾难性遗忘。以下是三个关键公式及其实现。
6.1 Clipped Surrogate Objective
LCLIP = -E[ min( rt · At, clip(rt, 1−ε, 1+ε) · At ) ]
# ratio = exp(new_log_prob - old_log_prob) ratio = torch.exp(new_log_probs - old_log_probs[mb_idx]) # clipped surrogate pg_loss1 = -mb_adv * ratio pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - 0.2, 1 + 0.2) policy_loss = torch.max(pg_loss1, pg_loss2).mean()
clip_range=0.2 意味着新策略的概率比不能偏离旧策略超过 ±20%。这防止了一次更新就把策略带偏。
6.2 GAE(广义优势估计)
At = δt + γλ · (1 − donet) · At+1
# 从后往前递推 for t in reversed(range(n_steps)): delta = rewards[t] + gamma * next_val * (1.0 - dones[t]) - values[t] last_gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * last_gae advantages[t] = last_gae returns = advantages + values # V_target = A + V_old
γ=0.99 控制折扣深度,λ=0.95 平衡偏差和方差。λ 越大方差越高但偏差越小;0.95 是经典默认值。
6.3 总损失函数
loss = policy_loss \
+ 0.5 * vf_loss \ # Critic MSE,带 clip
- ent_coef * entropy # 熵正则,鼓励探索
Value loss 也做了 clipping:V_clipped = V_old + clip(V_new - V_old, -0.2, +0.2),取 clipped 和 unclipped MSE 的最大值,防止 Critic 跳跃式更新。
Advantage Normalization
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
每次 PPO 更新前标准化 advantage。这是 PPO 稳定训练的关键 — 否则 advantage 的绝对值会随训练波动,导致梯度忽大忽小。但这也是零和陷阱被遮蔽的原因(见第 8 节)。
7训练循环:数据流详解
= 1024 transitions
buf: (128, 8, 4, 15, 15)
advantages: (128, 8)
returns: (128, 8)
shuffle indices
split minibatch=256
= 16 gradient steps
per rollout
10 episodes
argmax actions
# Rollout buffer 形状 buf_obs = np.zeros((n_steps, n_envs, 4, H, W), dtype=np.float32) # 128×8×4×15×15 buf_actions = np.zeros((n_steps, n_envs), dtype=np.int64) # 128×8 buf_rewards = np.zeros((n_steps, n_envs), dtype=np.float32) # 128×8 buf_dones = np.zeros((n_steps, n_envs), dtype=np.float32) # 128×8 buf_values = np.zeros((n_steps, n_envs), dtype=np.float32) # 128×8 buf_log_probs= np.zeros((n_steps, n_envs), dtype=np.float32) # 128×8 # 每步:CPU 侧 8 个 env 依次 step,GPU 侧批量前向推理 obs_t = torch.from_numpy(np.stack(obs_list)).to(device) # (8, 4, 15, 15) → GPU actions, log_probs, _, values = net.get_action_and_value(obs_t) # 批量采样
每个 rollout 产生 1024 个 transition(128 步 × 8 环境),PPO 用 4 个 epoch、每 epoch 4 个 minibatch 消费这些数据 — 即每条 transition 被训练 4 次。这是 PPO on-policy 的数据效率极限,也是 clip 机制存在的原因:同一批数据反复训练时,clip 防止策略偏离太远。
整个 5M 步训练需要 ~4883 次 rollout,每次 rollout 做 16 次梯度更新。
8四轮迭代:从失败到成功的完整故事
这是本文最核心的部分。最终的 44.20 分不是一次跑出来的,而是经历了 4 轮训练和 3 次失败,每次失败都暴露了一个深层问题。
初始奖励设计:吃食物 +1.0,死亡 -5.0,每步 -0.01,靠近食物 +0.1,远离食物 -0.1(曼哈顿距离 shaping)。
训练曲线看似健康:entropy 从 1.386 降到 0.46~0.65,explained variance 0.5~0.8,Critic 学得不错。但 eval 大量出现 length=300/310/320,恰好等于 max_steps = 100 × 初始蛇长 3。
删除距离 shaping 后改为纯稀疏奖励,棋盘从 20×20 缩到 15×15(状态空间 -44%,食物密度 +79%)。结果更差了:best 从 2.80 跌到 1.50,entropy 坍塌到 0.19。
这时发现了一个根本性的数学 bug:
蛇长 +1 → max_steps += 100 → 多走 100 步 × (-0.01) = -1.0
净收益 = +1.0 − 1.0 = 0
为什么第一轮没发现? 因为距离 shaping ±0.1 提供了唯一真实的梯度信号。删除 shaping 后零和真相暴露。
数学验证(假设均为截断结束,无死亡):
| 吃食物数 | 蛇长 | max_steps | 步数奖励 | 食物奖励 | 总奖励 |
|---|---|---|---|---|---|
| 0 | 3 | 300 | -3.0 | 0 | -3.0 |
| 1 | 4 | 400 | -4.0 | +1.0 | -3.0 |
| 5 | 8 | 800 | -8.0 | +5.0 | -3.0 |
| 10 | 13 | 1300 | -13.0 | +10.0 | -3.0 |
REWARD_FOOD × 1 + REWARD_STEP × max_steps_factor = 0 这个零和关系,在两个设计决策各自的文档里都看不出来 — 只有把它们乘在一起算才能发现。
修复零和:REWARD_FOOD 从 +1.0 提高到 +3.0,每个食物净收益变为 +3 − 1 = +2。ent_coef 从 0.01 提到 0.03 防坍塌。
这次 agent 终于学会了吃食物,avg_score 从 0 涨到 5.78。但分析增长曲线发现增长在 800K 步被 LR 退火截断:
| 区间 | avg_score 增量 | LR 剩余 | 分析 |
|---|---|---|---|
| 500K-600K | +0.62 | 44% | |
| 600K-700K | +0.83 | 33% | 加速 |
| 700K-800K | +1.10 | 23% | 全程最快 |
| 800K-900K | +0.64 | 13% | 骤降 |
| 900K-1M | +0.61 | →0% | 停滞 |
三项关键调整:
- 训练步数 1M → 5M:给 agent 5 倍更多的学习时间
- LR floor = 10%:
new_lr = initial_lr × max(0.1, 1 − step/total),后期保留 2.5e-5 的微调能力 - Entropy 退火:
ent_coef = 0.005 + (0.03 - 0.005) × (1 − step/total),前期高探索(0.03),后期收敛(0.005)
训练关键指标变化
| 指标 | 初始 (0 步) | 中期 (2.5M) | 终期 (5M) | 含义 |
|---|---|---|---|---|
| Entropy | 1.386 | ~0.45 | ~0.33 | 从均匀(ln4=1.386)收敛到确定性策略 |
| Explained Var | -0.01 | ~0.68 | ~0.63 | Critic 预测准确度(越接近 1 越好) |
| Mean Score | 0 | ~30 | ~34 | 最近 100 局平均得分 |
| Policy Loss | ~0 | ~-0.005 | ~-0.004 | 负值正常(maximize objective) |
| FPS | ~2200 | ~2200 | ~2100 | 全程稳定 |
9超参数退火:为什么是关键
第四轮成功的核心不是 "训练更久",而是两个退火策略的精确配合。
Learning Rate 退火
def linear_schedule(initial_lr, current_step, total_steps): fraction = max(0.1, 1.0 - current_step / total_steps) return initial_lr * fraction
关键是 max(0.1, ...) — LR 不降到 0,保留初始值 10% 的下限(2.5e-5)。这来自第三轮的教训:退到 0 意味着后期完全丧失学习能力,而贪吃蛇的难度随蛇长非线性增长,后期恰恰最需要微调。
Entropy Coefficient 退火
# 前期高探索(0.03),后期收敛(0.005) ent_fraction = max(0.0, 1.0 - global_step / total_steps) updater.ent_coef = 0.005 + (0.03 - 0.005) * ent_fraction
这需要 PPOUpdater 支持运行时修改 ent_coef(初始实现是构造时固定的,第四轮改为可动态赋值)。
前期(0-1M 步):高 LR (2.5e-4) + 高 entropy (0.03) = 大步探索,快速覆盖状态空间
中期(1M-3M):LR 渐降 + entropy 渐降 = 在有价值的区域精细搜索
后期(3M-5M):LR=2.5e-5 + entropy=0.005 = 稳定微调,策略接近确定性
10奖励工程:最终参数及其推导
| 事件 | 奖励 | 定义位置 | 推导 |
|---|---|---|---|
| 吃到食物 | +3.0 | snake_game.py | 净收益 = 3.0 − 0.01×100 = +2.0(明确正激励) |
| 死亡 | -5.0 | env.py 覆盖 | 死亡:食物 = 5:3 ≈ 1.67:1(吃 2 个食物可覆盖 1 次死亡) |
| 每步存活 | -0.01 | snake_game.py | 鼓励效率,避免无谓绕圈 |
| 步数截断 | 0.0 | env.py | 截断 ≠ 死亡,不额外惩罚 |
为什么死亡在 env.py 覆盖(-5.0)而非修改 snake_game.py(-1.0)? game/ 目录是纯游戏逻辑,不应耦合 RL 超参数。奖励缩放属于训练侧决策,放在 rl/env.py 方便日后调参而不动游戏代码。
不变式约束(第二轮的教训产出):
3.0 > 0.01 × 100 = 1.0 ✓
11Claude Code 的工作方式
整个项目所有代码由 Claude Code 辅助完成。但 "辅助" 不是 "让 AI 随便写" — 我们通过 OpenSpec 工作流严格管控每个改动。
实际的迭代过程
以第二轮失败后发现零和陷阱为例,完整流程是:
- Explore:分析 1M 步训练日志,推算 mean_ep_reward ≈ -3.18 和 ep_length ≈ 410,发现 "无论吃几个食物总奖励都约 -3.0" 的零和模式
- Propose:生成
snake-fix-reward-zero-sum提案,明确修改 REWARD_FOOD 1.0→3.0,并新增数学不变式 spec - 人类确认:审核提案中的数学推导和修改范围
- Apply:修改 1 行常量 + 1 行 argparse default + 备份旧 checkpoint
- 回写 Design Log:在 phase-2 的 design.md 追加 Decision Log,记录零和 bug 的完整分析和教训
最关键的是 Design Log 的回写。每次训练失败的根因分析、数学推导、教训都被持久化到 OpenSpec 文档中。下次对话只能看到文档,看不到历史对话 — 如果信息只存在于聊天记录里,下一轮迭代就会丢失上下文。
Claude Code 在不同阶段的角色
| 阶段 | 角色 | 具体贡献 |
|---|---|---|
| 架构设计 | 架构师 | 提出三层分离方案,设计 (row,col) 坐标系与 NumPy 对齐 |
| 观测编码 | RL 工程师 | 设计 4 通道方案,论证方向通道用全图填充优于 FC 拼接 |
| 训练失败诊断 | 数据分析师 | 从日志中反推零和陷阱,做数学证明 |
| 超参调优 | 调参顾问 | 根据增长曲线建议 LR floor=10%、entropy 退火 schedule |
| 代码实现 | 程序员 | 所有 .py 文件的编写,含正交初始化、GAE、PPO update |
| 文档维护 | 记录员 | OpenSpec 提案、设计文档、失败分析全程维护 |
12完整超参数速查
| 类别 | 参数 | 值 | 备注 |
|---|---|---|---|
| 环境 | 棋盘 | 15×15 | 状态空间 225 格 |
| n_envs | 8 | CPU 并行 | |
| max_steps_factor | 100 | max_steps = 100×蛇长 | |
| 奖励 | REWARD_FOOD | +3.0 | 在 snake_game.py 定义 |
| DEATH_REWARD | -5.0 | 在 env.py 覆盖 | |
| REWARD_STEP | -0.01 | 在 snake_game.py 定义 | |
| 网络 | Conv layers | 3×(3×3, pad=1) | 感受野 7×7 |
| FC hidden | 256 | ||
| 参数量 | ~1.3M | 15×15 棋盘 | |
| PPO | total_timesteps | 5,000,000 | |
| n_steps | 128 | 每 env 每 rollout | |
| batch_size | 256 | 4 minibatch/epoch | |
| n_epochs | 4 | ||
| learning_rate | 2.5e-4 → 2.5e-5 | 线性退火, floor=10% | |
| ent_coef | 0.03 → 0.005 | 线性退火 | |
| clip_range | 0.2 | 固定 | |
| γ / λ | 0.99 / 0.95 | 标准 GAE |
13运行项目
# 安装 pip install -r requirements.txt pip install torch --index-url https://download.pytorch.org/whl/cu121 # 人类玩 cd snake && set PYTHONPATH=.. python -m game.play_human --width 15 --height 15 # 从零训练(GTX 1660 Ti 约 35 分钟) python -m rl.train --seed 42 --total-timesteps 5000000 # 观看 AI 游玩 python -m rl.evaluate --checkpoint checkpoints/best.pt --episodes 20 --fps 120 # 从 checkpoint 续训 python -m rl.train --resume checkpoints/best.pt
Checkpoint 保存了 model_state_dict、optimizer_state_dict(含 Adam 动量)、timestep、best_score、next_eval_step,支持无损恢复训练。评估时使用 argmax(取最大概率动作),不做采样。
评估结果(20 局,argmax 策略,120 FPS)
使用训练最优 checkpoint(timestep=4,750,336,best_score=44.20)运行 20 局评估:
| 局次 | 得分 | 步数 |
|---|---|---|
| 1 | 41 | 538 |
| 2 | 27 | 379 |
| 3 | 47 | 694 |
| 4 | 40 | 651 |
| 5 | 45 | 743 |
| 6 | 30 | 378 |
| 7 | 33 | 394 |
| 8 | 48 | 892 |
| 9 | 38 | 604 |
| 10 | 33 | 396 |
| 局次 | 得分 | 步数 |
|---|---|---|
| 11 | 41 | 670 |
| 12 | 49 | 857 |
| 13 | 43 | 680 |
| 14 | 40 | 555 |
| 15 | 52 | 942 |
| 16 | 45 | 794 |
| 17 | 33 | 486 |
| 18 | 34 | 455 |
| 19 | 36 | 547 |
| 20 | 33 | 488 |
评估均分 39.40 低于训练 best_score 44.20,这是正常现象 — best_score 是训练过程中 100 局滑动窗口的最高平均值,而非单次峰值。20 局评估中最高达到 52 分(蛇长覆盖 15×15 棋盘的 23.1%),最低 27 分,标准差约 6.7,表明策略稳定但仍存在局间波动。
下一步计划:将 shared/ppo.py 复用到新游戏(Pong 或格斗 PVP),验证跨游戏泛化能力。