880 字
4 分钟
小样本文本生成训练的实践:从 Transformer 架构到生成优化
1. 问题背景
小样本文本生成面临几个典型问题:
- 数据不足:模型容易记住训练集,泛化能力差。
- 显存限制:Transformer 模型的多头注意力计算复杂,显存消耗随序列长度平方增加。
- 生成重复:小模型在生成时容易重复相同 token,降低可读性。
- 训练不稳定:梯度爆炸、损失震荡常见,尤其在深层模型或长序列训练时。
针对这些问题,我们设计了以下解决方案。
2. 小型 Transformer 架构
本文采用 自定义小型 Transformer 模型:
- Embedding + 学习型位置编码:使用
nn.Embedding+LearnedPositionalEncoding,提升序列位置信息表达能力。 - Transformer Block:每个 block 包含:
- 多头自注意力 (
nn.MultiheadAttention) - 前馈网络 (
nn.Linear → ReLU → Linear) - LayerNorm 与残差连接
- 模型输出:线性层映射到词表大小,直接预测下一个 token。
- 多头自注意力 (
class SmallTransformer(nn.Module): def __init__(self, vocab_size, d_model=512, n_head=8, dim_ff=512, n_layers=12, max_len=64): super().__init__() self.embedding = nn.Embedding(vocab_size, d_model) self.pos_enc = LearnedPositionalEncoding(d_model, max_len=max_len) self.layers = nn.ModuleList([TransformerBlock(d_model, n_head, dim_ff) for _ in range(n_layers)]) self.norm = nn.LayerNorm(d_model) self.fc_out = nn.Linear(d_model, vocab_size)
def forward(self, x): x = self.embedding(x) x = self.pos_enc(x) for layer in self.layers: x = layer(x) x = self.norm(x) return self.fc_out(x)3. 数据处理与小样本训练策略
为了适应小样本场景,我们采用以下策略:
- 限制最大序列长度:
MAX_SEQ_LEN=64,显著减少自注意力计算量,降低显存占用。 - 梯度累积:
ACCUM_STEPS=4,模拟大批量训练,同时显存占用低。 - 混合精度训练:使用
torch.amp.autocast和GradScaler,进一步节省显存,提高训练速度。 - 数据填充与标签对齐:每条序列右填充 PAD token,并将下一个 token 作为预测目标。
x = seq + [PAD_ID] * pad_leny = seq[1:] + [PAD_ID] * (pad_len + 1)- 断点训练:保存训练状态,包括模型、优化器、调度器和 GradScaler,可随时恢复训练。
4. 优化训练稳定性
- 梯度裁剪:
clip_grad_norm_防止梯度爆炸。 - 学习率调度:线性 Warmup + 线性衰减 (
LambdaLR),稳定训练收敛。 - 早停机制:验证集损失不下降超过
PATIENCE轮则停止训练,避免过拟合。
5. 生成策略优化
在小样本和小模型场景下,生成文本容易重复。我们结合 Nucleus Sampling 与 重复惩罚:
- 温度调节:控制生成的随机性,避免过于保守。
- Top-p 采样:只在累计概率达到
TOP_P的候选 token 中采样,保证多样性。 - 重复惩罚:对最近出现过的 token 降低概率,减少重复生成。
def nucleus_sampling(logits, p=0.9, repetition_penalty=2, past_tokens=None): probs = F.softmax(logits / temperature, dim=-1) for token_id in past_tokens[-100:]: probs[token_id] /= repetition_penalty # Top-p 采样 sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) mask = cumulative_probs <= p top_probs = sorted_probs[mask] top_indices = sorted_indices[mask] next_id = top_indices[torch.multinomial(top_probs / top_probs.sum(), 1).item()].item() return next_id生成效果示例:
输入: Hello i am -> Hello i am happy to meet you输入: Today is -> Today is a beautiful day输入: I am -> I am feeling great today6. 实践经验总结
- 小模型 + 小样本:合理限制模型规模,结合梯度累积和混合精度,可在单卡显存有限的情况下训练。
- 位置编码:学习型位置编码比固定正弦编码在短文本生成上更灵活。
- 生成优化:Top-p + 重复惩罚是抑制重复文本的有效方法。
- 断点训练:在小样本和长训练轮数场景中非常实用,防止训练中断导致损失。
小样本文本生成不是无法解决,只要结合 训练策略 + 模型设计 + 生成优化,即可实现高质量文本生成。