Training Code for Small-Sample Text Generation
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
from datasets import load_dataset
from transformers import BertTokenizerFast
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
print("Init")
# ===== Configuration =====
CONFIG = {
    "BATCH_SIZE": 128,        # Samples per training batch; memory use and compute grow with it. Larger batches train faster but need more GPU memory
    "MAX_SEQ_LEN": 64,        # Maximum sequence length; memory and compute grow roughly quadratically with it (Transformer self-attention)
    "D_MODEL": 512,           # Transformer hidden size; memory/compute scale roughly with its square and it sets model capacity. D_MODEL must be divisible by N_HEAD
    "DIM_FF": 512,            # Feed-forward layer width; memory and compute increase with DIM_FF
    "N_HEAD": 8,              # Number of attention heads; memory and compute increase with the head count
    "N_LAYERS": 12,           # Number of Transformer layers; memory and compute grow linearly with depth
    "ACCUM_STEPS": 4,         # Gradient accumulation steps: small batches are accumulated into a larger effective batch, keeping memory low at a slight time cost
    "EPOCHS": 45,             # Total training epochs; training time grows linearly, no effect on memory
    "LEARNING_RATE": 0.001,   # Learning rate; affects convergence speed, not memory or compute
    "PATIENCE": 5,            # Early-stopping patience: epochs without validation-loss improvement before training stops
    "TOP_P": 0.9,             # Cumulative probability threshold for nucleus sampling; affects generation diversity only
    "TEMPERATURE": 1.2,       # Sampling temperature; controls randomness at generation time only
    "REPETITION_PENALTY": 2   # Repetition penalty applied at generation time to discourage repeated tokens
}
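
# Optional sanity checks on the configuration (these only restate constraints already noted in the comments above).
assert CONFIG["D_MODEL"] % CONFIG["N_HEAD"] == 0, "D_MODEL must be divisible by N_HEAD"
print("Effective batch size (BATCH_SIZE * ACCUM_STEPS):", CONFIG["BATCH_SIZE"] * CONFIG["ACCUM_STEPS"])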
# ===== 1. Load the dataset =====
data_files = {
    "train": "./datas/train-wiki.parquet",
    "validation": "./datas/valid-wiki.parquet",
    "test": "./datas/test-wiki.parquet"
}
dataset = load_dataset("parquet", data_files=data_files)
print(dataset["train"][0])  # Inspect the first example
# Data size caps: 5000 training / 500 validation examples (the .select() calls are commented out, so the full splits are used)
num_train_samples = min(5000, len(dataset["train"]))
num_val_samples = min(500, len(dataset["validation"]))
train_data = dataset["train"]        # .select(range(num_train_samples))
val_data = dataset["validation"]     # .select(range(num_val_samples))
print(f"训练集条数: {len(train_data)}, 验证集条数: {len(val_data)}")
# ===== 2. BERT tokenizer =====
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir="./my_cache")  # alternative: bert-base-chinese
token2id = tokenizer.get_vocab()
id2token = {id_: token for token, id_ in token2id.items()}
vocab_size = len(token2id)
PAD_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0    # token2id.get("[PAD]", 0)
UNK_ID = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 100  # token2id.get("[UNK]", 100)
# Make sure id2token covers UNK_ID
if UNK_ID not in id2token:
    print(f"Warning: UNK_ID {UNK_ID} not in id2token, adding [UNK]")
    id2token[UNK_ID] = "[UNK]"
print("词表大小:", vocab_size)print("前20个 token:", list(token2id.keys())[:20])print(f"PAD_ID: {PAD_ID}, UNK_ID: {UNK_ID}")
# ===== 3. Learned positional encoding =====
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
        return x + self.pos_embedding(positions)
# ===== 4. Transformer Block =====
class TransformerBlock(nn.Module):
    def __init__(self, d_model=CONFIG["D_MODEL"], n_head=CONFIG["N_HEAD"], dim_ff=CONFIG["DIM_FF"], dropout=0.5):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x
# ===== 5. Small Transformer model =====
class SmallTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=CONFIG["D_MODEL"], n_head=CONFIG["N_HEAD"], dim_ff=CONFIG["DIM_FF"],
                 n_layers=CONFIG["N_LAYERS"], max_len=CONFIG["MAX_SEQ_LEN"]):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = LearnedPositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList([TransformerBlock(d_model, n_head, dim_ff, dropout=0.3) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(0.3)
        self.max_len = max_len

    def forward(self, x):
        if x.size(1) > self.max_len:
            x = x[:, :self.max_len]
        x = self.embedding(x)
        x = self.pos_enc(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.fc_out(x)
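
# Note: the blocks above use unmasked self-attention, so every position can attend to future tokens
# even though training predicts the next token. A causal (upper-triangular) mask is the usual way to
# make the model strictly autoregressive. The helper below is a sketch only and is not wired into
# TransformerBlock/SmallTransformer.
def make_causal_mask(seq_len, device):
    # True marks positions that attention is NOT allowed to look at (strictly future positions).
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)

# Inside TransformerBlock.forward this could be used as:
#   attn_out, _ = self.attn(x, x, x, attn_mask=make_causal_mask(x.size(1), x.device))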
# ===== 6. Data preprocessing =====
class NewsDataset(Dataset):
    def __init__(self, sentences, tokenizer, max_seq_len=CONFIG["MAX_SEQ_LEN"]):
        self.data = []
        for s in sentences:
            tokens = tokenizer.encode(s, add_special_tokens=True, max_length=max_seq_len, truncation=True)
            if len(tokens) > 1:
                self.data.append(tokens)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        if len(seq) > CONFIG["MAX_SEQ_LEN"]:
            seq = seq[:CONFIG["MAX_SEQ_LEN"]]
        pad_len = CONFIG["MAX_SEQ_LEN"] - len(seq)
        x = seq + [PAD_ID] * pad_len
        y = seq[1:] + [PAD_ID] * (pad_len + 1)
        return torch.tensor(x), torch.tensor(y)
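
# Minimal illustration of the shifted-target construction above (toy ids, assuming MAX_SEQ_LEN were 6):
_demo_seq = [101, 7592, 2088, 102]                        # hypothetical token ids
_demo_pad = 6 - len(_demo_seq)
print("x:", _demo_seq + [PAD_ID] * _demo_pad)             # [101, 7592, 2088, 102, PAD, PAD]
print("y:", _demo_seq[1:] + [PAD_ID] * (_demo_pad + 1))   # [7592, 2088, 102, PAD, PAD, PAD]; position t predicts token t+1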
train_dataset = NewsDataset(train_data["text"], tokenizer)
val_dataset = NewsDataset(val_data["text"], tokenizer)
print(f"训练数据条数: {len(train_dataset)}")print("前10条示例:", train_data["text"][:10])
# ===== 7. Data loaders =====
train_loader = DataLoader(train_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=0)
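
# Optional shape check: one training batch should give (BATCH_SIZE, MAX_SEQ_LEN) tensors for both inputs and targets.
_x_demo, _y_demo = next(iter(train_loader))
print("Batch shapes:", _x_demo.shape, _y_demo.shape)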
# ===== 8. Train or load the model (supports resuming from a checkpoint) =====
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SmallTransformer(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["LEARNING_RATE"], weight_decay=1e-4)
# Warmup scheduler
from torch.optim.lr_scheduler import LambdaLR
def evaluate(model, val_loader, loss_fn, device):
    model.eval()
    total_loss = 0
    num_batches = 0
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            with autocast('cuda'):
                logits = model(x_batch)
                logits = logits.view(-1, vocab_size)
                targets = y_batch.view(-1)
                loss = loss_fn(logits, targets)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / num_batches if num_batches > 0 else float('inf')
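
# Note: evaluate() returns a mean token-level cross entropy, so math.exp(val_loss) can be read as an
# approximate validation perplexity (approximate because of label smoothing and the PAD ignore_index).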
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
    return LambdaLR(optimizer, lr_lambda)
num_warmup_steps = 1000
num_training_steps = (len(train_loader) // CONFIG["ACCUM_STEPS"]) * CONFIG["EPOCHS"]  # one scheduler step per optimizer update
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
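
# The schedule ramps the learning-rate multiplier linearly from 0 to 1 over the first
# num_warmup_steps optimizer updates, then decays it linearly back to 0 at num_training_steps.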
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=0.1)
scaler = GradScaler('cuda')
# Checkpoint for resumable training
checkpoint_path = "checkpoint.pt"
start_epoch = 0
best_val_loss = float('inf')
counter = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    scaler.load_state_dict(checkpoint["scaler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]
    counter = checkpoint.get("counter", 0)
    print(f"Resuming training from epoch {start_epoch}")
# Training loop
for epoch in range(start_epoch, CONFIG["EPOCHS"]):
    model.train()
    total_loss = 0
    num_batches = 0
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']}") as pbar:
        for i, (x_batch, y_batch) in enumerate(train_loader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            with autocast('cuda'):
                logits = model(x_batch)
                logits = logits.view(-1, vocab_size)
                targets = y_batch.view(-1)
                loss = loss_fn(logits, targets) / CONFIG["ACCUM_STEPS"]
            scaler.scale(loss).backward()
            if (i + 1) % CONFIG["ACCUM_STEPS"] == 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()  # advance the warmup/decay schedule once per optimizer update
                # Only update the progress-bar postfix here, where grad_norm is defined
                pbar.set_postfix_str(f"\033[32mGradient norm: {grad_norm:.4f}\033[0m")
            total_loss += loss.item() * CONFIG["ACCUM_STEPS"]
            num_batches += 1
            pbar.update(1)
    avg_train_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    # Validation
    val_loss = evaluate(model, val_loader, loss_fn, device)
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")
    # Save the best model so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), "best_model.pt")
        print(f"Best model saved at epoch {epoch+1}")
    else:
        counter += 1
        if counter >= CONFIG["PATIENCE"]:
            print(f"Early stopping at epoch {epoch+1}")
            break
    # Save a checkpoint for resuming
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_val_loss": best_val_loss,
        "counter": counter
    }, checkpoint_path)
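
# Optional (sketch): reload the best checkpoint saved during training before generating,
# so sampling uses the lowest-validation-loss weights rather than the final epoch's weights.
if os.path.exists("best_model.pt"):
    model.load_state_dict(torch.load("best_model.pt", map_location=device))
    model.eval()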
# ===== 9. Improved generation logic (with repetition penalty) =====
def nucleus_sampling(logits, p=CONFIG["TOP_P"], repetition_penalty=CONFIG["REPETITION_PENALTY"], past_tokens=None):
    if past_tokens is None:
        past_tokens = []
    # Temperature scaling is applied by the caller (generate), so the logits arrive already scaled.
    probs = F.softmax(logits, dim=-1)
    # Apply the repetition penalty over a wider window (the most recent 100 tokens)
    token_counts = {}
    for token_id in past_tokens[-100:]:
        token_counts[token_id] = token_counts.get(token_id, 0) + 1
    for token_id, count in token_counts.items():
        probs[token_id] /= repetition_penalty ** count
    probs = probs / probs.sum()
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    mask = cumulative_probs <= p
    top_p_probs = sorted_probs[mask]
    top_p_indices = sorted_indices[mask]
    if top_p_probs.sum() == 0:
        top_p_probs = sorted_probs[:1]
        top_p_indices = sorted_indices[:1]
    top_p_probs = top_p_probs / top_p_probs.sum()
    next_id = top_p_indices[torch.multinomial(top_p_probs, 1).item()].item()
    return next_id
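
# Optional toy check of the top-p cut-off (hand-made 4-token distribution, not model output):
_toy_logits = torch.log(torch.tensor([0.6, 0.25, 0.1, 0.05]))
print("Sampled toy id:", nucleus_sampling(_toy_logits, p=0.9))  # only ids 0 and 1 survive the 0.9 cut-off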

def generate(model, start_text, max_len=20, temperature=CONFIG["TEMPERATURE"], top_p=CONFIG["TOP_P"]):
    model.eval()
    input_ids = tokenizer.encode(start_text, add_special_tokens=False, max_length=CONFIG["MAX_SEQ_LEN"], truncation=True)
    output_text = start_text
    past_tokens = input_ids.copy()
    for _ in range(max_len):
        if len(input_ids) > CONFIG["MAX_SEQ_LEN"]:
            input_ids = input_ids[-CONFIG["MAX_SEQ_LEN"]:]
        x_tensor = torch.tensor([input_ids]).to(device)
        with torch.no_grad():
            with autocast('cuda'):
                logits = model(x_tensor)
            logits_last = logits[0, -1] / temperature
            next_id = nucleus_sampling(logits_last, p=top_p, repetition_penalty=CONFIG["REPETITION_PENALTY"], past_tokens=past_tokens)
        token = id2token.get(next_id, "[UNK]")
        output_text += token
        input_ids.append(next_id)
        past_tokens.append(next_id)
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids))
# ===== 10. Test generation =====
print("\nTest generation (nucleus sampling):")
print("Prompt: Hello i am ->", generate(model, "Hello i am ", max_len=10))
print("Prompt: Today is ->", generate(model, "Today is ", max_len=10))
print("Prompt: I am ->", generate(model, "I am ", max_len=10))