数据集#

AReaL 直接集成 HuggingFace datasets 包中的 Dataset 类。这使您可以在训练前完全灵活地加载、处理和过滤数据。

所需的列名(键)和数据格式取决于 Agent Workflow(用于在线强化学习)或训练引擎(用于离线训练,例如用于监督微调 SFT 的 LMEngine)的具体实现。

以下是现有实现中的两个具体示例:

SFT(离线训练)#

在 SFT 示例中,加载的数据直接传递给 train_lm 方法:

# areal/trainer/sft_trainer.py
for global_step in range(start_step, max_steps):
    batch = self._load_bcast_from(data_generator)
    self.actor.train_lm(batch)

在这种情况下,train_lm 方法需要 “input_ids”、”attention_mask” 和 “loss_mask” 键才能工作。我们首先对数据集进行分词以提取 “input_ids” 和 “loss_mask”。然后,使用 pad_sequences_to_tensors 方法来批量处理多个序列并附加 “attention_mask”:

# areal/dataset/gsm8k.py
def get_gsm8k_sft_dataset(
    path: str,
    split: str,
    tokenizer,
    max_length: int | None = None,
):
    dataset = load_dataset(path=path, name="main", split=split)

    def process(sample):
        seq_token = tokenizer.encode(
            sample["question"] + sample["answer"] + tokenizer.eos_token
        )
        prompt_token = tokenizer.encode(sample["question"])
        loss_mask = [0] * len(prompt_token) + [1] * (len(seq_token) - len(prompt_token))
        return {"input_ids": seq_token, "loss_mask": loss_mask}

    dataset = dataset.map(process).remove_columns(["question", "answer"])

    if max_length is not None:
        # Filter out sequences longer than max_length
        dataset = dataset.filter(lambda x: len(x["input_ids"]) <= max_length)

    return dataset

GRPO(在线训练)#

在 GRPO 示例中,加载的数据首先用于推理而不是训练:

# areal/trainer/rl_trainer.py
self.train_dataloader = self._create_dataloader(
    train_dataset,
    dataset_config=self.config.train_dataset,
    rank=self.actor.data_parallel_rank,
    world_size=self.actor.data_parallel_world_size,
)
for global_step in range(start_step, max_steps):
    rollout_batch = self.actor.prepare_batch(
        self.train_dataloader,
        workflow=workflow,
        workflow_kwargs=workflow_kwargs,
        should_accept_fn=dynamic_filter_fn,
        group_size=config.gconfig.n_samples,
        dynamic_bs=self.config.dynamic_bs,
    )

请注意,这里的 collate_fn 是一个恒等函数,这意味着它只是将各个数据项的列表作为一个批次返回。在 prepare_batch 中,数据随后被分派到 Workflow 的多个并发执行中,其中每个分派的数据对应一个单独的 episode。

在以下部分中,我们以 RLVRWorkflow 为例。Agent Workflow 使用输入数据的模式相同。只要符合您的 Workflow 实现,您可以随意修改自定义数据集以包含任何键。

RLVRWorkflow 实现从数据字典中提取 “messages” 字段作为生成响应的提示。此外,此数据作为关键字参数传递给 reward_fn,这允许奖励函数利用数据集中的其他字段,如 “answers”。示例如下:

# areal/workflow/rlvr.py
class RLVRWorkflow(RolloutWorkflow):

    async def arun_episode(self, engine: InferenceEngine, data):
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )
        req = ModelRequest(
            input_ids=input_ids,
            ...
        )
        ...
        reward = self.reward_fn(
            prompt=prompt_str,
            completions=completions_str,
            prompt_ids=resp.input_tokens,
            completion_ids=resp.output_tokens,
            **data,
        )

因此,必须在加载数据集时构建 “messages” 字段,并且奖励函数应该被定义为处理数据集的特定字段。以下是您可以如何为此示例处理数据集:

from datasets import load_dataset

def get_gsm8k_rl_dataset(
    path: str,
    split: str,
    tokenizer,
    max_length: int | None = None,
):
    dataset = load_dataset(path=path, name="main", split=split)

    def process(sample):
        messages = [
            {
                "role": "user",
                "content": sample["question"]
                + "\nPlease put your final answer within \\boxed{}.",
            }
        ]
        return {"messages": messages}

    dataset = dataset.map(process).remove_columns(["question"])

    # Filter out sequences longer than max_length if tokenizer and max_length are provided
    if max_length is not None:

        def filter_length(sample):
            # Tokenize the user content to check length
            content = sample["messages"][0]["content"]
            tokens = tokenizer.encode(content)
            return len(tokens) <= max_length

        dataset = dataset.filter(filter_length)

    return dataset