检查点#

本文档描述 AReaL 的检查点系统,用于在分布式 RL 训练期间保存模型以进行评估和容错恢复。

概述#

AReaL 提供两种互补的检查点机制:

机制

用途

格式

包含优化器/数据加载器状态

Saver

导出模型用于评估或发布

HuggingFace

RecoverHandler

故障后恢复训练

DCP(分布式检查点)

两种机制在训练期间自动调用,可分别通过 config.saverconfig.recover 进行配置。

检查点格式#

HuggingFace 格式#

Saver 用于模型导出:

  • 标准 HuggingFace 模型格式(safetensors + config.json)

  • 兼容 transformers.AutoModel.from_pretrained()

  • 可上传到 HuggingFace Hub

  • 不包含优化器状态

DCP 格式(分布式检查点)#

RecoverHandler 用于容错:

  • 后端原生分布式检查点格式(torch.distributed.checkpoint 或 Megatron 分布式检查点)

  • 跨所有 rank 分片以实现高效的并行 I/O

  • 包含模型权重、优化器状态、RNG 状态等

  • 后端特定:检查点仅与相同并行配置兼容

  • 覆盖之前的检查点以节省磁盘空间

架构#

PPOTrainer.train()
│
├── 训练循环
│   ├── Rollout、计算值、PPO 更新...
│   │
│   ├── _save_hf()                          # HuggingFace 导出
│   │   └── Saver.save()
│   │       └── engine.save(weight_format="hf")
│   │
│   └── _save_recover_checkpoint()          # 容错
│       └── RecoverHandler.dump()
│           └── engine.save(weight_format="dcp", with_optim=True)
│
└── 重启时
    └── RecoverHandler.load()
        ├── 恢复 dataloader、saver、evaluator 状态
        └── engine.load(weight_format="dcp", with_optim=True)

Saver:HuggingFace 模型导出#

Saver 定期以 HuggingFace 格式导出模型权重,用于评估或部署。

保存模式#

mode 参数控制检查点的写入方式:

模式

行为

auto

对 Archon 引擎使用异步,其他使用同步(默认)。对所有引擎都是零配置最优选择。

sync

始终同步 dcp.save()

async

始终使用进程内异步,带有固定内存暂存。仅适用于 Archon 引擎;其他引擎回退到同步并发出警告。额外的 CPU 固定内存与每 rank 模型分片大小成正比(例如,70B 模型在 8 GPU 上每 rank 约 17.5 GB)。

使用默认的 auto 模式,Archon 引擎用户自动获得异步检查点保存 —— 训练循环仅在检查点暂存到固定 CPU 内存时阻塞,实际磁盘 I/O 在后台进程中发生。

配置#

通过 config.saver 配置:

参数

类型

默认值

描述

mode

str

"auto"

保存模式(见上文)。

freq_epochs

int | None

None

每 N 个 epoch 保存一次。None 禁用。

freq_steps

int | None

None

每 N 步保存一次。None 禁用。

freq_secs

int | None

None

每 N 秒保存一次。None 禁用。

配置示例:

saver:
  freq_epochs: 1      # 每个 epoch 结束时保存
  freq_steps: null    # 禁用
  freq_secs: null      # 禁用
  # mode 默认为 "auto" - Archon 用户自动获得异步

当满足 epoch/步数/时间任一条件时触发保存。

输出位置#

检查点保存到:

{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/default/
└── epoch{E}epochstep{S}globalstep{G}/
    ├── config.json
    ├── model.safetensors(或 model-00001-of-00002.safetensors 等)
    ├── tokenizer.json
    └── ...

使用方式#

使用标准 HuggingFace API 加载保存的检查点:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
tokenizer = AutoTokenizer.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)

RecoverHandler:容错#

RecoverHandler 通过保存完整训练状态实现故障后恢复训练。

配置#

通过 config.recover 配置:

参数

类型

默认值

描述

mode

str

“disabled”

恢复模式:”on”/”auto” 或 “off”/”disabled”

freq_epochs

int | None

None

每 N 个 epoch 检查点

freq_steps

int | None

None

每 N 步检查点

freq_secs

int | None

None

每 N 秒检查点

retries

int

3

启用恢复时的恢复重试次数

恢复模式#

模式

行为

onauto

如果存在有效检查点则自动恢复

offdisabled

不进行检检查点或恢复

启用恢复(on/auto)时,系统将:

  1. 定期保存恢复检查点(模型权重、优化器状态、dataloader 位置)

  2. 重启时自动从最后一个有效检查点恢复

  3. 失败时重试最多 retries

配置示例:

recover:
  mode: on            # 或 "auto"(向后兼容)
  freq_steps: 100     # 每 100 步检查点
  retries: 3

保存的内容#

RecoverHandler 保存完整训练状态:

组件

内容

模型权重

DCP 格式,跨 rank 分片

优化器状态

动量、方差(Adam)、学习率调度器

RNG 状态

Python、NumPy、PyTorch、CUDA 随机状态

Dataloader 状态

数据集中的当前位置

训练进度

Epoch、step、global_step 计数器

辅助状态

Saver、Evaluator、StatsLogger 状态

输出位置#

恢复检查点保存到:

{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/
├── default/
│   └── recover_checkpoint/     # 模型 + 优化器(DCP 格式)
│       ├── __0_0.distcp
│       ├── __1_0.distcp
│       └── ...
├── critic/                     # 如果使用 critic
│   └── recover_checkpoint/
└── recover_info/               # 元数据
    ├── step_info.json
    ├── saver_info.json
    ├── evaluator_info.json
    ├── stats_logger_info.json
    ├── checkpoint_info.json
    └── dataloader_info.pkl

恢复过程#

恢复训练时:

  1. RecoverHandler.load() 恢复所有保存的状态(如有)

  2. 训练从 last_step_info.next().global_step 继续

  3. 推理引擎权重同步以匹配恢复的状态

最佳实践#

频率指南#

场景

推荐设置

长时间训练

freq_epochs: 1freq_steps: 1000

不可预测的时间

freq_secs: 7200

不稳定的集群

freq_steps: 100 配合 recover.mode: on

磁盘空间有限

较低频率,依赖最终检查点

调试

freq_steps: 1 以快速迭代

磁盘空间注意事项#

  • Saver:每次保存创建新目录。高频率会消耗大量空间。

  • RecoverHandler:覆盖之前的检查点。同时只存在一份副本。

恢复技巧#

  1. 验证检查点有效性:检查 recover_info/step_info.json 获取最后保存的步数

  2. 需要相同配置:DCP 检查点需要相同的并行配置、实验名称和试次名称

  3. 干净重启:删除 recover_info/ 目录以全新开始