Training code for text generation on a small dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
from datasets import load_dataset
from transformers import BertTokenizerFast
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
print("Init")
# ===== Configuration =====
CONFIG = {
    "BATCH_SIZE": 128,        # Samples per training batch; memory use and compute grow with it; larger batches train faster but stress GPU memory
    "MAX_SEQ_LEN": 64,        # Maximum sequence length; self-attention memory and compute grow roughly quadratically with it
    "D_MODEL": 512,           # Transformer hidden size; memory and compute scale roughly with D_MODEL squared; sets model capacity. D_MODEL must be divisible by N_HEAD
    "DIM_FF": 512,            # Feed-forward (FFN) width; memory and compute grow with DIM_FF
    "N_HEAD": 8,              # Number of attention heads; memory and compute grow with the head count
    "N_LAYERS": 12,           # Number of Transformer layers; memory and compute grow linearly with depth
    "ACCUM_STEPS": 4,         # Gradient-accumulation steps; simulates a larger batch with low memory use at a slight cost in time
    "EPOCHS": 45,             # Total training epochs; training time grows linearly, memory is unaffected
    "LEARNING_RATE": 0.001,   # Learning rate; affects convergence speed, not memory or compute
    "PATIENCE": 5,            # Early-stopping patience (epochs without validation improvement); does not affect memory
    "TOP_P": 0.9,             # Nucleus-sampling cumulative-probability threshold; affects generation diversity only
    "TEMPERATURE": 1.2,       # Sampling temperature; controls generation randomness only
    "REPETITION_PENALTY": 2   # Repetition penalty applied at generation time to suppress repeated tokens
}
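# Rough parameter-count estimate implied by the comments above (an illustrative sketch added here,
# not part of the original script; vocab=30522 assumes the bert-base-uncased vocabulary used below):
def approx_param_count(cfg, vocab=30522):
    # embeddings + output projection ~ 2 * vocab * D_MODEL; per layer ~ 4*D_MODEL^2 for attention
    # plus 2*D_MODEL*DIM_FF for the feed-forward block, ignoring biases and LayerNorm weights
    per_layer = 4 * cfg["D_MODEL"] ** 2 + 2 * cfg["D_MODEL"] * cfg["DIM_FF"]
    return 2 * vocab * cfg["D_MODEL"] + cfg["N_LAYERS"] * per_layer
print("Approx. parameter count:", approx_param_count(CONFIG))  # ~5.0e7 with the settings above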
# ===== 1. Load the dataset =====
data_files = {
"train": "./datas/train-wiki.parquet",
"validation": "./datas/valid-wiki.parquet",
"test": "./datas/test-wiki.parquet"
}
dataset = load_dataset("parquet", data_files=data_files)
print(dataset["train"][0]) # 查看第一条数据
# 数据量:训练集5000条,验证集500条
num_train_samples = min(5000, len(dataset["train"]))
num_val_samples = min(500, len(dataset["validation"]))
train_data = dataset["train"] #.select(range(num_train_samples))
val_data = dataset["validation"] #.select(range(num_val_samples))
print(f"训练集条数: {len(train_data)}, 验证集条数: {len(val_data)}")
# ===== 2. Use the BERT tokenizer =====
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir="./my_cache")  # or "bert-base-chinese" for Chinese text
token2id = tokenizer.get_vocab()
id2token = {id_: token for token, id_ in token2id.items()}
vocab_size = len(token2id)
PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 #token2id.get("[PAD]", 0)
UNK_ID = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 100 #token2id.get("[UNK]", 100)
# Make sure id2token contains UNK_ID
if UNK_ID not in id2token:
print(f"Warning: UNK_ID {UNK_ID} not in id2token, adding [UNK]")
id2token[UNK_ID] = "[UNK]"
print("词表大小:", vocab_size)
print("前20个 token:", list(token2id.keys())[:20])
print(f"PAD_ID: {PAD_ID}, UNK_ID: {UNK_ID}")
# ===== 3. Learned positional encoding =====
class LearnedPositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=512):
super().__init__()
self.pos_embedding = nn.Embedding(max_len, d_model)
def forward(self, x):
positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
return x + self.pos_embedding(positions)
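# Shape note (illustrative): the learned position table broadcasts over the batch dimension, so a
# (batch, seq_len, d_model) input comes back with the same shape, e.g.
# LearnedPositionalEncoding(CONFIG["D_MODEL"])(torch.zeros(2, 10, CONFIG["D_MODEL"])).shape == (2, 10, CONFIG["D_MODEL"])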
# ===== 4. Transformer Block =====
class TransformerBlock(nn.Module):
def __init__(self, d_model=CONFIG["D_MODEL"], n_head=CONFIG["N_HEAD"], dim_ff=CONFIG["DIM_FF"], dropout=0.5):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, dim_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(dim_ff, d_model)
)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
    def forward(self, x, attn_mask=None):
        # Causal self-attention: attn_mask (True = blocked) keeps each position from attending to future tokens
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x
# ===== 5. Small Transformer model =====
class SmallTransformer(nn.Module):
def __init__(self, vocab_size, d_model=CONFIG["D_MODEL"], n_head=CONFIG["N_HEAD"], dim_ff=CONFIG["DIM_FF"],
n_layers=CONFIG["N_LAYERS"], max_len=CONFIG["MAX_SEQ_LEN"]):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_enc = LearnedPositionalEncoding(d_model, max_len=max_len)
self.layers = nn.ModuleList([TransformerBlock(d_model, n_head, dim_ff, dropout=0.3) for _ in range(n_layers)])
self.norm = nn.LayerNorm(d_model)
self.fc_out = nn.Linear(d_model, vocab_size)
self.dropout = nn.Dropout(0.3)
self.max_len = max_len
    def forward(self, x):
        if x.size(1) > self.max_len:
            x = x[:, :self.max_len]
        # Causal mask (True = blocked) so the language model never attends to future positions
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        x = self.embedding(x)
        x = self.pos_enc(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, attn_mask=causal_mask)
        x = self.norm(x)
        return self.fc_out(x)
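# Shape check (illustrative): a (batch, seq_len) tensor of token ids maps to
# (batch, seq_len, vocab_size) next-token logits, e.g.
# SmallTransformer(vocab_size)(torch.randint(0, vocab_size, (2, 16))).shape == (2, 16, vocab_size)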
# ===== 6. Data preprocessing =====
class NewsDataset(Dataset):
def __init__(self, sentences, tokenizer, max_seq_len=CONFIG["MAX_SEQ_LEN"]):
self.data = []
for s in sentences:
tokens = tokenizer.encode(s, add_special_tokens=True, max_length=max_seq_len, truncation=True)
if len(tokens) > 1:
self.data.append(tokens)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
seq = self.data[idx]
if len(seq) > CONFIG["MAX_SEQ_LEN"]:
seq = seq[:CONFIG["MAX_SEQ_LEN"]]
pad_len = CONFIG["MAX_SEQ_LEN"] - len(seq)
x = seq + [PAD_ID] * pad_len
y = seq[1:] + [PAD_ID] * (pad_len + 1)
return torch.tensor(x), torch.tensor(y)
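# Example of the input/target shift above (illustrative), with MAX_SEQ_LEN = 6:
#   tokens: [CLS, A, B, SEP]
#   x     : [CLS, A, B, SEP, PAD, PAD]
#   y     : [A,   B, SEP, PAD, PAD, PAD]   # next-token targets; PAD positions are ignored by the loss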
train_dataset = NewsDataset(train_data["text"], tokenizer)
val_dataset = NewsDataset(val_data["text"], tokenizer)
print(f"训练数据条数: {len(train_dataset)}")
print("前10条示例:", train_data["text"][:10])
# ===== 7. Data loaders =====
train_loader = DataLoader(train_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=0)
# ===== 8. Train or resume the model (checkpoint support) =====
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SmallTransformer(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["LEARNING_RATE"], weight_decay=1e-4)
# Warmup scheduler
from torch.optim.lr_scheduler import LambdaLR
def evaluate(model, val_loader, loss_fn, device):
model.eval()
total_loss = 0
num_batches = 0
with torch.no_grad():
for x_batch, y_batch in val_loader:
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
with autocast('cuda'):
logits = model(x_batch)
logits = logits.view(-1, vocab_size)
targets = y_batch.view(-1)
loss = loss_fn(logits, targets)
total_loss += loss.item()
num_batches += 1
return total_loss / num_batches if num_batches > 0 else float('inf')
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda)
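# LR multiplier produced by the schedule above (illustrative): it ramps linearly from 0 to 1 over
# num_warmup_steps (e.g. 0.5 at step 500 when num_warmup_steps = 1000), then decays linearly back
# to 0 at num_training_steps.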
num_warmup_steps = 1000
# One scheduler step is taken per optimizer step (see the training loop), so count optimizer steps rather than batches
num_training_steps = (len(train_loader) // CONFIG["ACCUM_STEPS"]) * CONFIG["EPOCHS"]
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=0.1)
scaler = GradScaler('cuda')
# Resume from a checkpoint if one exists
checkpoint_path = "checkpoint.pt"
start_epoch = 0
best_val_loss = float('inf')
counter = 0
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
scaler.load_state_dict(checkpoint["scaler_state_dict"])
start_epoch = checkpoint["epoch"] + 1
best_val_loss = checkpoint["best_val_loss"]
counter = checkpoint.get("counter", 0)
print(f"Resuming training from epoch {start_epoch}")
# Training loop
for epoch in range(start_epoch, CONFIG["EPOCHS"]):
model.train()
total_loss = 0
num_batches = 0
with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']}") as pbar:
for i, (x_batch, y_batch) in enumerate(train_loader):
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
with autocast('cuda'):
logits = model(x_batch)
logits = logits.view(-1, vocab_size)
targets = y_batch.view(-1)
loss = loss_fn(logits, targets) / CONFIG["ACCUM_STEPS"]
scaler.scale(loss).backward()
            if (i + 1) % CONFIG["ACCUM_STEPS"] == 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()  # advance the warmup/decay schedule once per optimizer step
                # Only update the progress-bar postfix here
                pbar.set_postfix_str(f"\033[32mGradient norm: {grad_norm:.4f}\033[0m")
            total_loss += loss.item() * CONFIG["ACCUM_STEPS"]
            num_batches += 1
            pbar.update(1)
avg_train_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    # Validation
val_loss = evaluate(model, val_loader, loss_fn, device)
print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")
    # Save the best model
if val_loss < best_val_loss:
best_val_loss = val_loss
counter = 0
torch.save(model.state_dict(), "best_model.pt")
print(f"Best model saved at epoch {epoch+1}")
else:
counter += 1
if counter >= CONFIG["PATIENCE"]:
print(f"Early stopping at epoch {epoch+1}")
break
    # Save a resumable checkpoint
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"best_val_loss": best_val_loss,
"counter": counter
}, checkpoint_path)
# ===== 9. Improved generation logic (with repetition penalty) =====
def nucleus_sampling(logits, p=CONFIG["TOP_P"], repetition_penalty=CONFIG["REPETITION_PENALTY"], past_tokens=None):
if past_tokens is None:
past_tokens = []
    # Temperature scaling is applied by the caller (generate), so only convert logits to probabilities here
    probs = F.softmax(logits, dim=-1)
    # Penalize tokens that already appeared among the most recent 100 generated tokens
    token_counts = {}
    for token_id in past_tokens[-100:]:
        token_counts[token_id] = token_counts.get(token_id, 0) + 1
    for token_id, count in token_counts.items():
        probs[token_id] /= repetition_penalty ** count
    probs = probs / probs.sum()
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs <= p
top_p_probs = sorted_probs[mask]
top_p_indices = sorted_indices[mask]
if top_p_probs.sum() == 0:
top_p_probs = sorted_probs[:1]
top_p_indices = sorted_indices[:1]
top_p_probs = top_p_probs / top_p_probs.sum()
next_id = top_p_indices[torch.multinomial(top_p_probs, 1).item()].item()
return next_id
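# Worked example of the top-p filter above (illustrative): with sorted probs [0.5, 0.3, 0.15, 0.05]
# and p = 0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0], so only the first two tokens satisfy
# cumulative_probs <= p and sampling is restricted to them after renormalisation.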
def generate(model, start_text, max_len=20, temperature=CONFIG["TEMPERATURE"], top_p=CONFIG["TOP_P"]):
model.eval()
input_ids = tokenizer.encode(start_text, add_special_tokens=False, max_length=CONFIG["MAX_SEQ_LEN"],
truncation=True)
    past_tokens = input_ids.copy()
for _ in range(max_len):
if len(input_ids) > CONFIG["MAX_SEQ_LEN"]:
input_ids = input_ids[-CONFIG["MAX_SEQ_LEN"]:]
x_tensor = torch.tensor([input_ids]).to(device)
with torch.no_grad():
with autocast('cuda'):
logits = model(x_tensor)
logits_last = logits[0, -1] / temperature
next_id = nucleus_sampling(logits_last, p=top_p, repetition_penalty=CONFIG["REPETITION_PENALTY"],
past_tokens=past_tokens)
        input_ids.append(next_id)
        past_tokens.append(next_id)
return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids))
# ===== 10. Test generation =====
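# If early stopping fired, the in-memory weights are not necessarily the best ones seen; as a
# precaution (an addition to the original script), reload the best checkpoint when it exists:
if os.path.exists("best_model.pt"):
    model.load_state_dict(torch.load("best_model.pt", map_location=device))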
print("\n测试生成(Nucleus Sampling):")
print("输入: Hello i am ->", generate(model, "Hello i am ", max_len=10))
print("输入: Today is ->", generate(model, "Today is ", max_len=10))
print("输入: I am ->", generate(model, "I am ", max_len=10))