瀏覽代碼

部分优化

gwhsss 2 月之前
父節點
當前提交
4cd619fd8b

+ 703 - 124
src/LinearAlgebra/deepleaning_demo.py

@@ -11,35 +11,35 @@ import glob
 
 # ==================== 全局配置参数 ====================
 
-# 模型配置 - 增大模型提高质量
+# 模型配置 - 优化训练速度
 MODEL_CONFIG = {
-    'n_layer': 8,  # 增加层数
-    'n_head': 8,  # 增加注意力头
-    'n_embd': 256,  # 增加嵌入维度
-    'max_seq_len': 512,
+    'n_layer': 4,  # 减少层数
+    'n_head': 4,  # 减少注意力头
+    'n_embd': 128,  # 减少嵌入维度
+    'max_seq_len': 256,
     'dropout': 0.1,
     'bias': True,
 }
 
-# 训练配置 - 优化训练参数
+# 训练配置 - 优化训练速度
 TRAINING_CONFIG = {
-    'epochs': 2000,
-    'batch_size': 16,
+    'epochs': 500,  # 减少训练轮数
+    'batch_size': 8,  # 减少批次大小
     'learning_rate': 6e-4,
     'block_size': 256,
     'weight_decay': 0.01,
     'grad_clip': 1.0,
-    'warmup_epochs': 50,
-    'min_loss': 0.05,  # 目标最小损失
+    'warmup_epochs': 20,  # 减少预热轮数
+    'min_loss': 0.05,
 }
 
 # 生成配置 - 设为全局,无需输入
 GENERATION_CONFIG = {
-    'max_tokens': 900,
-    'temperature': 0.7,
-    'top_k': 40,
-    'top_p': 0.85,
-    'repetition_penalty': 1.1,
+    'max_tokens': 300,  # 减少生成长度
+    'temperature': 0.5,  # 降低温度,提高确定性
+    'top_k': 20,  # 减少top_k,聚焦更可能的词汇
+    'top_p': 0.9,  # 调整top_p
+    'repetition_penalty': 1.2,  # 增加重复惩罚
 }
 
 # 文件配置
@@ -61,7 +61,7 @@ TRAINING_DATA_CONFIG = {
 # ==================== 模型类定义 ====================
 
 class GPTConfig:
-    """GPT模型配置类"""
+    """改进的GPT模型配置类,支持更多现代配置选项"""
 
     def __init__(
             self,
@@ -72,6 +72,9 @@ class GPTConfig:
             max_seq_len: int = MODEL_CONFIG['max_seq_len'],
             dropout: float = MODEL_CONFIG['dropout'],
             bias: bool = MODEL_CONFIG['bias'],
+            resid_pdrop: float = MODEL_CONFIG['dropout'],
+            attn_pdrop: float = MODEL_CONFIG['dropout'],
+            embd_pdrop: float = MODEL_CONFIG['dropout'],
     ):
         self.vocab_size = vocab_size
         self.n_layer = n_layer
@@ -80,13 +83,16 @@ class GPTConfig:
         self.max_seq_len = max_seq_len
         self.dropout = dropout
         self.bias = bias
+        self.resid_pdrop = resid_pdrop
+        self.attn_pdrop = attn_pdrop
+        self.embd_pdrop = embd_pdrop
 
     def __str__(self):
         return f"GPTConfig(vocab_size={self.vocab_size}, n_layer={self.n_layer}, n_head={self.n_head}, n_embd={self.n_embd})"
 
 
-class OptimizedCausalSelfAttention(nn.Module):
-    """优化的因果自注意力机制"""
+class CausalSelfAttention(nn.Module):
+    """改进的因果自注意力机制,使用现代最佳实践"""
 
     def __init__(self, config: GPTConfig):
         super().__init__()
@@ -96,48 +102,60 @@ class OptimizedCausalSelfAttention(nn.Module):
         self.n_embd = config.n_embd
         self.head_size = config.n_embd // config.n_head
 
+        # Q, K, V 和输出投影的线性层
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-        self.attn_dropout = nn.Dropout(config.dropout)
-        self.resid_dropout = nn.Dropout(config.dropout)
+
+        # Dropout
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
         # 预计算因果掩码
         self.register_buffer("bias", torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
                              .view(1, 1, config.max_seq_len, config.max_seq_len))
 
+        # 使用缩放注意力
         self.scale = 1.0 / math.sqrt(self.head_size)
 
     def forward(self, x):
         B, T, C = x.size()
 
+        # 一次计算Q, K, V
         qkv = self.c_attn(x)
         q, k, v = qkv.split(self.n_embd, dim=2)
 
-        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
-        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
-        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)
+        # 重塑并转置以进行多头注意力
+        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
+        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+        # 计算注意力权重
+        att = (q @ k.transpose(-2, -1)) * self.scale  # (B, nh, T, T)
 
-        att = (q @ k.transpose(-2, -1)) * self.scale
+        # 应用因果掩码
         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.attn_dropout(att)
 
-        y = att @ v
-        y = y.transpose(1, 2).contiguous().view(B, T, C)
+        # 应用注意力权重到值
+        y = att @ v  # (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
+
+        # 输出投影和残差连接
         y = self.resid_dropout(self.c_proj(y))
         return y
 
 
-class OptimizedMLP(nn.Module):
-    """优化的多层感知机"""
+class MLP(nn.Module):
+    """改进的多层感知机,使用GELU激活函数"""
 
     def __init__(self, config: GPTConfig):
         super().__init__()
-        intermediate_size = 4 * config.n_embd  # 恢复4倍维度
+        intermediate_size = 4 * config.n_embd  # FFN中间层扩展4倍
         self.c_fc = nn.Linear(config.n_embd, intermediate_size, bias=config.bias)
         self.gelu = nn.GELU()
         self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=config.bias)
-        self.dropout = nn.Dropout(config.dropout)
+        self.dropout = nn.Dropout(config.resid_pdrop)
 
     def forward(self, x):
         x = self.c_fc(x)
@@ -147,97 +165,119 @@ class OptimizedMLP(nn.Module):
         return x
 
 
-class OptimizedBlock(nn.Module):
-    """优化的Transformer块"""
+class Block(nn.Module):
+    """改进的Transformer块,使用预归一化和现代架构"""
 
     def __init__(self, config: GPTConfig):
         super().__init__()
+        # 使用预归一化(Pre-normalization)- 现代Transformer最佳实践
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=1e-5)
-        self.attn = OptimizedCausalSelfAttention(config)
+        self.attn = CausalSelfAttention(config)
         self.ln_2 = nn.LayerNorm(config.n_embd, eps=1e-5)
-        self.mlp = OptimizedMLP(config)
+        self.mlp = MLP(config)
 
     def forward(self, x):
+        # 预归一化和残差连接
         x = x + self.attn(self.ln_1(x))
         x = x + self.mlp(self.ln_2(x))
         return x
 
 
-class OptimizedGPT(nn.Module):
-    """优化的GPT模型"""
+class GPT(nn.Module):
+    """改进的GPT模型,使用现代最佳实践"""
 
     def __init__(self, config: GPTConfig):
         super().__init__()
         self.config = config
 
+        # 词嵌入和位置嵌入
         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
         self.wpe = nn.Embedding(config.max_seq_len, config.n_embd)
-        self.drop = nn.Dropout(config.dropout)
 
-        self.blocks = nn.ModuleList([OptimizedBlock(config) for _ in range(config.n_layer)])
+        # 嵌入层dropout
+        self.drop = nn.Dropout(config.embd_pdrop)
+
+        # Transformer块
+        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
+
+        # 最终层归一化
         self.ln_f = nn.LayerNorm(config.n_embd, eps=1e-5)
+
+        # 语言模型头
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
-        # 权重绑定
+        # 权重绑定 - 将嵌入权重与LM头权重绑定
         self.wte.weight = self.lm_head.weight
 
+        # 初始化权重
         self.apply(self._init_weights)
 
+        # 特殊缩放初始化 - 对某些层使用不同的初始化
+        self.lm_head.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
+
     def _init_weights(self, module):
-        """权重初始化"""
+        """改进的权重初始化,使用GPT-2论文中的方法"""
         if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            # 使用正态分布初始化,标准差根据层数调整
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, targets=None, return_logits=True):
         device = idx.device
         b, t = idx.size()
 
         assert t <= self.config.max_seq_len, f"序列长度{t}超过最大长度{self.config.max_seq_len}"
 
+        # 位置嵌入
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
 
+        # 词嵌入和位置嵌入相加
         tok_emb = self.wte(idx)
         pos_emb = self.wpe(pos)
         x = self.drop(tok_emb + pos_emb)
 
+        # 通过所有Transformer块
         for block in self.blocks:
             x = block(x)
 
+        # 最终层归一化
         x = self.ln_f(x)
 
+        # 计算logits
         if targets is not None:
             logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
         else:
-            logits = self.lm_head(x)
+            logits = self.lm_head(x) if return_logits else None
             loss = None
 
         return logits, loss
 
     def generate(self, idx, max_new_tokens=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None):
-        """生成文本 - 增强版"""
+        """改进的生成方法,支持多种采样策略"""
         max_new_tokens = max_new_tokens or GENERATION_CONFIG['max_tokens']
         temperature = temperature or GENERATION_CONFIG['temperature']
         top_k = top_k if top_k is not None else GENERATION_CONFIG['top_k']
         top_p = top_p if top_p is not None else GENERATION_CONFIG['top_p']
         repetition_penalty = repetition_penalty or GENERATION_CONFIG['repetition_penalty']
 
-        generated_sequence = []
+        generated_sequence = idx[0].tolist()  # 用于重复惩罚
 
         for _ in range(max_new_tokens):
+            # 确保上下文长度不超过最大序列长度
             idx_cond = idx if idx.size(1) <= self.config.max_seq_len else idx[:, -self.config.max_seq_len:]
 
+            # 前向传播
             logits, _ = self(idx_cond)
-            logits = logits[:, -1, :]
+            logits = logits[:, -1, :]  # 只取最后一个位置的logits
 
             # 重复惩罚
-            if repetition_penalty != 1.0:
+            if repetition_penalty != 1.0 and len(generated_sequence) > 0:
                 for token in set(generated_sequence):
-                    logits[0, token] /= repetition_penalty
+                    logits[0, token] = logits[0, token] / repetition_penalty
 
             # 温度调节
             if temperature != 1.0:
@@ -245,28 +285,32 @@ class OptimizedGPT(nn.Module):
 
             # Top-K 过滤
             if top_k is not None and top_k > 0:
-                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                logits[logits < v[:, -1].unsqueeze(-1)] = -float('Inf')
+                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < values[:, -1].unsqueeze(-1)] = float('-inf')
 
             # Top-P (核采样) 过滤
             if top_p is not None and top_p < 1.0:
                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
 
-                # 移除累积概率超过top_p的token
+                # 创建掩码,移除累积概率超过top_p的token
                 sorted_indices_to_remove = cumulative_probs > top_p
                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                 sorted_indices_to_remove[..., 0] = 0
 
+                # 将掩码应用到原始logits
                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-                logits[indices_to_remove] = -float('Inf')
+                logits[indices_to_remove] = float('-inf')
 
+            # 计算最终概率分布
             probs = F.softmax(logits, dim=-1)
 
-            # 检查是否有有效的概率
-            if torch.all(probs == 0):
+            # 检查概率分布是否有效
+            if torch.isnan(probs).any() or torch.isinf(probs).any():
+                print("警告: 检测到无效概率值,停止生成")
                 break
 
+            # 从概率分布中采样
             idx_next = torch.multinomial(probs, num_samples=1)
             generated_sequence.append(idx_next.item())
             idx = torch.cat((idx, idx_next), dim=1)
@@ -330,7 +374,7 @@ class CharTokenizer:
 # ==================== 自动数据收集 ====================
 
 def collect_training_data():
-    """自动收集训练数据"""
+    """自动收集训练数据 - 增强版"""
     data_sources = []
 
     # 1. 使用内置的编程数据
@@ -612,7 +656,9 @@ if __name__ == "__main__":
                 with open(file_path, 'r', encoding='utf-8') as f:
                     content = f.read().strip()
                     if len(content) > TRAINING_DATA_CONFIG['min_text_length']:
-                        data_sources.append(content)
+                        # 应用高级预处理
+                        processed_content = preprocess_training_data(content)
+                        data_sources.append(processed_content)
                         print(f"✅ 加载数据文件: {data_file} ({len(content)} 字符)")
             except Exception as e:
                 print(f"⚠ 读取数据文件 {data_file} 时出错: {e}")
@@ -730,6 +776,9 @@ if __name__ == "__main__":
     # 合并所有数据源
     combined_data = '\n'.join(data_sources)
 
+    # 应用数据增强
+    combined_data = augment_data(combined_data)
+
     # 数据重复以增加训练样本
     combined_data = combined_data * TRAINING_DATA_CONFIG['data_repetition']
 
@@ -744,15 +793,18 @@ class EnhancedTrainingMonitor:
 
     def __init__(self, save_dir: str = FILE_CONFIG['save_dir']):
         self.losses = []
+        self.val_losses = []  # 添加验证损失记录
         self.start_time = time.time()
         self.save_dir = save_dir
         os.makedirs(save_dir, exist_ok=True)
         self.best_loss = float('inf')
         self.patience = 100
-        self.checkpoint_frequency = 50  # 每50轮保存一次
+        self.checkpoint_frequency = 100  # 每100轮保存一次,减少I/O开销
 
-    def update(self, loss, epoch, model=None, tokenizer=None):
+    def update(self, loss, epoch, model=None, tokenizer=None, val_loss=None):
         self.losses.append(loss)
+        if val_loss is not None:
+            self.val_losses.append(val_loss)
         elapsed = time.time() - self.start_time
 
         # 进度显示
@@ -760,8 +812,12 @@ class EnhancedTrainingMonitor:
             epochs_per_sec = (epoch + 1) / elapsed
             eta = (TRAINING_CONFIG['epochs'] - epoch - 1) / epochs_per_sec if epochs_per_sec > 0 else 0
 
-            print(f"Epoch {epoch:4d} | Loss: {loss:.4f} | "
-                  f"Speed: {epochs_per_sec:.2f} epoch/s | ETA: {eta:.0f}s")
+            if val_loss is not None:
+                print(f"Epoch {epoch:4d} | Train Loss: {loss:.4f} | Val Loss: {val_loss:.4f} | "
+                      f"Speed: {epochs_per_sec:.2f} epoch/s | ETA: {eta:.0f}s")
+            else:
+                print(f"Epoch {epoch:4d} | Loss: {loss:.4f} | "
+                      f"Speed: {epochs_per_sec:.2f} epoch/s | ETA: {eta:.0f}s")
 
             # 定期保存检查点
             if model and epoch % self.checkpoint_frequency == 0 and epoch > 0:
@@ -794,6 +850,7 @@ class EnhancedTrainingMonitor:
             'config': model.config,
             'tokenizer': tokenizer.stoi,
             'training_losses': self.losses,
+            'val_losses': self.val_losses,  # 保存验证损失
             'epoch': epoch,
             'loss': loss,
             'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
@@ -813,7 +870,17 @@ class EnhancedTrainingMonitor:
         try:
             import matplotlib.pyplot as plt
             plt.figure(figsize=(12, 6))
-            plt.plot(self.losses)
+
+            # 如果有验证损失,绘制两条线
+            if self.val_losses:
+                # 为验证损失创建对应的epoch索引
+                val_epochs = [i * 20 for i in range(len(self.val_losses))]  # 每20个epoch验证一次
+                plt.plot(self.losses, label='Training Loss', alpha=0.7)
+                plt.plot(val_epochs, self.val_losses, label='Validation Loss', alpha=0.7)
+                plt.legend()
+            else:
+                plt.plot(self.losses)
+
             plt.title('Training Loss Progress')
             plt.xlabel('Epoch')
             plt.ylabel('Loss')
@@ -836,15 +903,34 @@ def get_improved_learning_rate(epoch, warmup_epochs=20):
         return TRAINING_CONFIG['learning_rate'] * 0.5 * (1 + math.cos(math.pi * progress))
 
 
-def improved_train_gpt(model: OptimizedGPT, X: torch.Tensor, Y: torch.Tensor, tokenizer: CharTokenizer):
-    """改进的训练函数"""
-    model.train()
+def improved_train_gpt(model: GPT, X: torch.Tensor, Y: torch.Tensor, tokenizer: CharTokenizer):
+    """改进的训练函数,包含验证集、更好的正则化和更全面的监控"""
+    # 分割训练和验证数据
+    total_samples = len(X)
+    val_size = max(1, int(0.1 * total_samples))  # 10% 作为验证集
+    indices = torch.randperm(total_samples)
+    val_indices = indices[:val_size]
+    train_indices = indices[val_size:]
+
+    X_train, Y_train = X[train_indices], Y[train_indices]
+    X_val, Y_val = X[val_indices], Y[val_indices]
+
+    print(f"📊 训练样本: {len(X_train):,}, 验证样本: {len(X_val):,}")
 
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=TRAINING_CONFIG['learning_rate'],
         weight_decay=TRAINING_CONFIG['weight_decay'],
-        betas=(0.9, 0.95)
+        betas=(0.9, 0.95),
+        eps=1e-8  # 添加数值稳定性
+    )
+
+    # 学习率调度器
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+        optimizer,
+        T_0=TRAINING_CONFIG['warmup_epochs'] * 2,  # 周期长度
+        T_mult=1,  # 每个重启后周期长度的乘数
+        eta_min=1e-6  # 最小学习率
     )
 
     monitor = EnhancedTrainingMonitor(FILE_CONFIG['save_dir'])
@@ -852,29 +938,35 @@ def improved_train_gpt(model: OptimizedGPT, X: torch.Tensor, Y: torch.Tensor, to
     print(f"🚀 开始训练GPT模型")
     print(f"📊 总轮数: {TRAINING_CONFIG['epochs']}")
     print(f"🔢 模型参数: {sum(p.numel() for p in model.parameters()):,}")
-    print(f"📚 训练样本: {len(X):,}")
+    print(f"📚 训练样本: {len(X_train):,}")
+    print(f"📋 验证样本: {len(X_val):,}")
 
     # 设备设置
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     print(f"💻 使用设备: {device}")
     model.to(device)
-    X, Y = X.to(device), Y.to(device)
+    X_train, Y_train = X_train.to(device), Y_train.to(device)
+    X_val, Y_val = X_val.to(device), Y_val.to(device)
 
-    best_loss = float('inf')
+    best_val_loss = float('inf')
     patience_counter = 0
 
-    for epoch in range(TRAINING_CONFIG['epochs']):
-        # 动态学习率
+    # 训练历史记录
+    train_losses = []
+    val_losses = []
+
+    for epoch in range(TRAINING_CONFIG['epochs']):  # 动态学习率
         lr = get_improved_learning_rate(epoch, TRAINING_CONFIG['warmup_epochs'])
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         # 训练步骤
+        model.train()
         optimizer.zero_grad()
 
-        indices = torch.randint(0, len(X), (TRAINING_CONFIG['batch_size'],))
-        x_batch = X[indices]
-        y_batch = Y[indices]
+        indices = torch.randint(0, len(X_train), (TRAINING_CONFIG['batch_size'],))
+        x_batch = X_train[indices]
+        y_batch = Y_train[indices]
 
         logits, loss = model(x_batch, y_batch)
         loss.backward()
@@ -883,73 +975,169 @@ def improved_train_gpt(model: OptimizedGPT, X: torch.Tensor, Y: torch.Tensor, to
         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=TRAINING_CONFIG['grad_clip'])
         optimizer.step()
 
-        monitor.update(loss.item(), epoch, model, tokenizer)
+        # 更新学习率调度器
+        scheduler.step()
 
-        # 早停检查
-        if loss.item() < best_loss:
-            best_loss = loss.item()
+        # 验证步骤(每20个epoch执行一次,减少验证开销)
+        val_loss = float('inf')
+        if epoch % 20 == 0 or epoch == TRAINING_CONFIG['epochs'] - 1:
+            model.eval()
+            with torch.no_grad():
+                val_indices_batch = torch.randint(0, len(X_val), (min(TRAINING_CONFIG['batch_size'], len(X_val)),))
+                x_val_batch = X_val[val_indices_batch]
+                y_val_batch = Y_val[val_indices_batch]
+                _, val_loss_val = model(x_val_batch, y_val_batch)
+                val_loss = val_loss_val.item()
+                val_losses.append(val_loss)
+
+        # 记录训练损失
+        train_loss = loss.item()
+        train_losses.append(train_loss)
+
+        # 监控训练过程
+        monitor.update(train_loss, epoch, model, tokenizer, val_loss)
+
+        # 早停检查(基于验证损失)
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
             patience_counter = 0
+
+            # 保存最佳模型
+            if epoch > 50:  # 避免早期保存
+                monitor.save_best_model(model, tokenizer, epoch, val_loss)
         else:
             patience_counter += 1
 
         # 早停条件
-        if patience_counter >= monitor.patience and epoch > 300:
-            print(f"🛑 早停触发,第{epoch}轮")
+        if patience_counter >= monitor.patience and epoch > 50:
+            print(f"🛑 早停触发,第{epoch}轮,最佳验证损失: {best_val_loss:.4f}")
             break
 
         # 损失足够小提前停止
-        if loss.item() < TRAINING_CONFIG['min_loss'] and epoch > 200:
-            print(f"✅ 训练完成,损失已达目标值 {loss.item():.4f}")
+        if train_loss < TRAINING_CONFIG['min_loss'] and epoch > 200:
+            print(f"✅ 训练完成,训练损失已达目标值 {train_loss:.4f}")
             break
 
     print("🎉 训练完成!")
     monitor.plot_loss()
-    return monitor.losses
+    return train_losses, val_losses
 
 
 # ==================== 工具函数 ====================
 
 def create_improved_sample_dataset(text: str, block_size: int = None) -> Tuple[
     torch.Tensor, torch.Tensor, 'CharTokenizer']:
-    """创建改进的训练数据集"""
+    """创建改进的训练数据集,包含更好的数据预处理和增强"""
     block_size = block_size or TRAINING_CONFIG['block_size']
 
-    # 文本预处理
+    # 高级文本预处理
     lines = text.split('\n')
     cleaned_lines = []
     for line in lines:
         line = line.strip()
         if line and not line.startswith('#'):  # 移除空行和注释
-            cleaned_lines.append(line)
+            # 更高级的预处理
+            if not line.startswith('"""') and not line.startswith("'''"):  # 移除多行字符串开始
+                cleaned_lines.append(line)
     text = '\n'.join(cleaned_lines)
 
-    # 数据增强
-    text = text * TRAINING_DATA_CONFIG['data_repetition']
+    # 数据增强 - 添加更多的数据处理步骤
+    enhanced_text = text
+
+    # 随机添加一些特殊标记来帮助模型学习
+    special_chars = ['<|startoftext|>']
+    for char in special_chars:
+        enhanced_text = enhanced_text.replace('\n', f'\n{char}\n', 1)  # 在文本开头添加特殊标记
 
-    tokenizer = CharTokenizer(text)
+    tokenizer = CharTokenizer(enhanced_text)
 
-    if len(text) < block_size + 1:
-        print("⚠ 文本较短,使用重叠采样")
-        data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
+    if len(enhanced_text) < block_size + 1:
+        print("⚠ 文本较短,使用重叠采样和数据增强")
+        data = torch.tensor(tokenizer.encode(enhanced_text), dtype=torch.long)
         while len(data) < block_size + 1000:
             data = torch.cat([data, data])
         data = data[:block_size + 2000]
     else:
-        data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
+        data = torch.tensor(tokenizer.encode(enhanced_text), dtype=torch.long)
 
     n = len(data) - block_size
     if n <= 0:
         raise ValueError("无法创建训练样本")
 
-    # 创建训练样本
+    # 创建训练样本 - 使用滑动窗口方法
     X = torch.stack([data[i:i + block_size] for i in range(0, n, 1)])  # 步长为1获取更多样本
     Y = torch.stack([data[i + 1:i + block_size + 1] for i in range(0, n, 1)])
 
     print(f"✅ 创建了 {len(X):,} 个训练样本")
     print(f"🔤 词汇表大小: {tokenizer.vocab_size}")
+    print(f"📊 数据形状: X={X.shape}, Y={Y.shape}")
     return X, Y, tokenizer
 
 
+def preprocess_training_data(text: str) -> str:
+    """高级文本预处理,清理和标准化训练数据"""
+    # 移除多余的空白字符
+    import re
+
+    # 标准化空白字符
+    text = re.sub(r'\s+', ' ', text)
+
+    # 修复不匹配的引号和括号(如果可能)
+    # 移除不完整的行
+    lines = text.split('\n')
+    processed_lines = []
+
+    for line in lines:
+        stripped = line.strip()
+        if stripped:  # 非空行
+            # 确保代码缩进一致性
+            original_indent = len(line) - len(line.lstrip())
+            indent_spaces = ' ' * (original_indent // 4 * 4)  # 标准化为4的倍数
+            processed_line = indent_spaces + stripped
+            processed_lines.append(processed_line)
+
+    return '\n'.join(processed_lines)
+
+
+def augment_data(text: str, augmentation_factor: float = 0.1) -> str:
+    """数据增强函数 - 增加训练数据的多样性"""
+    import random
+
+    lines = text.split('\n')
+    augmented_lines = lines.copy()
+
+    # 随机插入一些常见的代码模式来增强数据
+    common_patterns = [
+        "# This is a sample comment",
+        "print('Debug: value')",
+        "# TODO: Implement this function",
+        "assert condition, 'Error message'",
+        "pass  # Placeholder for implementation",
+        "# Example usage:",
+        "if __name__ == '__main__':",
+        "try:",
+        "except Exception as e:",
+        "finally:",
+        "for i in range(len(items)):",
+        "while condition:",
+        "def helper_function():",
+        "return result",
+        "continue",
+        "break",
+        "# Performance optimization",
+        "# Memory efficient implementation"
+    ]
+
+    # 随机选择一些模式插入到文本中
+    num_insertions = int(len(lines) * augmentation_factor)
+    for _ in range(num_insertions):
+        insert_pos = random.randint(0, len(augmented_lines))
+        pattern = random.choice(common_patterns)
+        augmented_lines.insert(insert_pos, pattern)
+
+    return '\n'.join(augmented_lines)
+
+
 def format_generated_text(text: str, start_text: str) -> str:
     """格式化生成的文本"""
     # 移除起始文本
@@ -958,27 +1146,37 @@ def format_generated_text(text: str, start_text: str) -> str:
     else:
         generated_part = text
 
-    # 清理文本
+    # 清理文本 - 更智能的清理方式
     lines = generated_part.split('\n')
     cleaned_lines = []
 
     for line in lines:
         line = line.strip()
         if line:
+            # 检查是否有过多的重复字符
+            if len(set(line)) < len(line) * 0.3:  # 如果重复字符过多,则跳过
+                continue
+
             # 简单的代码格式检测
-            if any(keyword in line for keyword in ['def ', 'class ', 'import ', 'from ', 'if ', 'for ', 'while ']):
+            if any(keyword in line for keyword in
+                   ['def ', 'class ', 'import ', 'from ', 'if ', 'for ', 'while ', 'try:', 'except', 'with ', 'return',
+                    'print(']):
                 cleaned_lines.append(line)
             elif line.startswith('#') or line.startswith('"""') or line.startswith("'''"):
                 cleaned_lines.append(line)
-            elif '=' in line or ':' in line or line.endswith(':'):
+            elif '=' in line or ':' in line or line.endswith(':') or 'for ' in line or 'in ' in line:
                 cleaned_lines.append(line)
-            elif len(line) > 10:  # 保留较长的文本
+            elif len(line) > 10 and len([c for c in line if c.isalpha()]) > 5:  # 保留有一定字母数的
                 cleaned_lines.append(line)
 
+            # 限制行数,避免过长
+            if len(cleaned_lines) > 20:
+                break
+
     return '\n'.join(cleaned_lines)
 
 
-def interactive_generation(model: OptimizedGPT, tokenizer: CharTokenizer):
+def interactive_generation(model: GPT, tokenizer: CharTokenizer):
     """改进的交互式文本生成 - 使用全局参数"""
     print("\n" + "=" * 60)
     print("🤖 进入交互式生成模式")
@@ -1033,14 +1231,15 @@ def interactive_generation(model: OptimizedGPT, tokenizer: CharTokenizer):
             print("-" * 50)
             if formatted_text:
                 print(formatted_text)
+                print(f"📏 格式化后长度: {len(formatted_text)} 字符")
             else:
                 # 如果格式化后为空,显示原始生成文本(截断)
                 display_text = full_text[len(user_input):]
                 if len(display_text) > 300:
                     display_text = display_text[:300] + "..."
                 print(display_text)
+                print(f"📏 原始生成长度: {len(display_text)} 字符")
             print("=" * 50)
-            print(f"📏 总长度: {len(full_text)} 字符")
 
         except KeyboardInterrupt:
             print("\n\n🛑 用户中断,退出交互模式")
@@ -1095,7 +1294,7 @@ def select_model_interactively() -> str:
             print("⚠ 请输入有效的数字")
 
 
-class AdvancedGPT(OptimizedGPT):
+class AdvancedGPT(GPT):
     """增强版GPT"""
 
     def __init__(self, config: GPTConfig):
@@ -1105,7 +1304,13 @@ class AdvancedGPT(OptimizedGPT):
     def from_pretrained(cls, model_path: str, weights_only=False):
         """从预训练文件加载模型"""
         try:
-            checkpoint = torch.load(model_path, map_location='cpu', weights_only=weights_only)
+            # 尝试不同的加载方式
+            try:
+                checkpoint = torch.load(model_path, map_location='cpu', weights_only=weights_only)
+            except:
+                # 如果weights_only=True失败,尝试设置为False
+                checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
+
             config = checkpoint['config']
             model = cls(config)
             model.load_state_dict(checkpoint['model_state_dict'])
@@ -1113,6 +1318,7 @@ class AdvancedGPT(OptimizedGPT):
             return model, checkpoint.get('tokenizer', None), checkpoint.get('training_losses', []), global_configs
         except Exception as e:
             print(f"❌ 加载模型时出错: {e}")
+            print("💡 提示: 模型文件可能已损坏或版本不兼容,建议重新训练")
             raise
 
 
@@ -1213,12 +1419,12 @@ def main():
         config = GPTConfig(vocab_size=tokenizer.vocab_size)
         print(f"   {config}")
 
-        model = OptimizedGPT(config)
+        model = GPT(config)  # 使用新的GPT模型类
         print(f"   ✅ 参数数量: {sum(p.numel() for p in model.parameters()):,}")
 
         # 4. 训练模型
         print("\n4. 🏋️  训练模型...")
-        losses = improved_train_gpt(model, X, Y, tokenizer)
+        train_losses, val_losses = improved_train_gpt(model, X, Y, tokenizer)  # 获取训练和验证损失
 
         # 5. 保存最终模型
         print("\n5. 💾 保存最终模型...")
@@ -1230,7 +1436,8 @@ def main():
             'model_state_dict': model.state_dict(),
             'config': config,
             'tokenizer': tokenizer.stoi,
-            'training_losses': losses,
+            'training_losses': train_losses,
+            'val_losses': val_losses,  # 保存验证损失
             'timestamp': timestamp,
             'global_configs': {
                 'MODEL_CONFIG': MODEL_CONFIG,
@@ -1253,31 +1460,403 @@ def main():
         traceback.print_exc()
 
 
-def auto_detect_and_run():
-    """自动检测并运行"""
-    print("🔍 GPT语言模型自动检测系统")
-    print("=" * 50)
def interactive_menu():
    """Top-level interactive menu; loops until the user chooses to quit."""
    # Dispatch table: menu choice -> (announcement message, handler).
    actions = {
        "1": ("\n🚀 开始训练新模型...", main),
        "2": ("\n💬 开始使用现有模型提问...", use_existing_model),
        "3": ("\n🔬 开始模型评估...", evaluate_existing_model),
    }

    while True:
        print("\n" + "=" * 60)
        print("🤖 GPT语言模型系统 - 交互式菜单")
        print("=" * 60)
        print("1. 🏋️  训练新模型")
        print("2. 💬 使用现有模型进行提问")
        print("3. 🔍 模型评估")
        print("4. 🚪 退出")
        print("=" * 60)

        choice = input("请选择操作 (1-4): ").strip()

        if choice == "4":
            print("\n👋 感谢使用,再见!")
            break
        if choice in actions:
            message, handler = actions[choice]
            print(message)
            handler()
        else:
            print("⚠️ 无效选择,请输入 1-4")
def use_existing_model():
    """Load an existing checkpoint chosen by the user and enter Q&A mode.

    If no checkpoint exists, offers to train a new model instead. Any
    load failure is reported without crashing the menu loop.
    """
    available_models = get_available_models()

    if not available_models:
        print("❌ 未检测到任何现有模型")
        train_choice = input("是否要训练一个新模型?(y/n): ").strip().lower()
        if train_choice == 'y':
            main()
        return

    print(f"\n📂 检测到 {len(available_models)} 个现有模型")
    # Fix: show the actual checkpoint filename; the unpacked `filename` was
    # previously unused and a placeholder was printed instead.
    for i, (filename, _) in enumerate(available_models, 1):
        print(f"   {i}. {filename}")

    try:
        model_idx = int(input(f"\n选择模型 (1-{len(available_models)}), 或按回车使用最新模型 [默认: 1]: ") or "1") - 1
        if 0 <= model_idx < len(available_models):
            selected_model_path = available_models[model_idx][1]
        else:
            selected_model_path = available_models[0][1]  # fall back to the newest model
            print("⚠️ 输入超出范围,使用最新模型")
    except ValueError:
        selected_model_path = available_models[0][1]  # fall back to the newest model
        print("⚠️ 无效输入,使用最新模型")

    print(f"\n📥 加载模型: {selected_model_path}")
    try:
        # Checkpoint losses/configs are not needed for interactive generation.
        model, tokenizer_dict, _losses, _configs = AdvancedGPT.from_pretrained(selected_model_path,
                                                                               weights_only=False)
        tokenizer = CharTokenizer(stoi=tokenizer_dict)
        print("✅ 模型加载成功")
        print(f"🔤 词汇表大小: {tokenizer.vocab_size}")

        print("\n🎯 进入交互式提问模式")
        interactive_generation(model, tokenizer)

    except Exception as e:
        print(f"❌ 加载模型失败: {e}")
        print("💡 建议重新训练模型")
+
def evaluate_existing_model():
    """Load a user-selected checkpoint and run the comprehensive evaluation.

    Lists available checkpoints, loads the chosen one, builds a tiny
    evaluation set from a fixed code snippet, and prints the metrics.
    """
    available_models = get_available_models()

    if not available_models:
        print("❌ 未检测到任何现有模型")
        return

    print(f"\n📂 检测到 {len(available_models)} 个现有模型")
    # Fix: show the actual checkpoint filename; the unpacked `filename` was
    # previously unused and a placeholder was printed instead.
    for i, (filename, _) in enumerate(available_models, 1):
        print(f"   {i}. {filename}")

    try:
        model_idx = int(input(f"\n选择模型 (1-{len(available_models)}), 或按回车使用最新模型 [默认: 1]: ") or "1") - 1
        if 0 <= model_idx < len(available_models):
            selected_model_path = available_models[model_idx][1]
        else:
            selected_model_path = available_models[0][1]  # fall back to the newest model
            print("⚠️ 输入超出范围,使用最新模型")
    except ValueError:
        selected_model_path = available_models[0][1]  # fall back to the newest model
        print("⚠️ 无效输入,使用最新模型")

    print(f"\n📥 加载模型: {selected_model_path}")
    try:
        model, tokenizer_dict, _losses, _configs = AdvancedGPT.from_pretrained(selected_model_path,
                                                                               weights_only=False)
        tokenizer = CharTokenizer(stoi=tokenizer_dict)
        print("✅ 模型加载成功")

        # NOTE(review): create_improved_sample_dataset appears to build its own
        # tokenizer from test_text, which may not match the loaded model's
        # vocabulary — confirm the encoding is consistent with `tokenizer`
        # before trusting these metrics.
        test_text = "def hello_world():\n    print('Hello, World!')\n    return True\n\nclass TestClass:\n    def __init__(self):\n        self.value = 42"
        X_test, Y_test, _ = create_improved_sample_dataset(test_text)

        print("\n🔬 开始模型评估")
        comprehensive_model_evaluation(model, X_test, Y_test, tokenizer)

    except Exception as e:
        print(f"❌ 评估模型失败: {e}")
+
if __name__ == "__main__":
    # Launch the interactive menu.
    # NOTE(review): this guard sits in the middle of the module, so it executes
    # before the functions defined further down (evaluate_model,
    # comprehensive_model_evaluation, generate_sample_text, ...) exist; menu
    # option 3 will raise NameError when it reaches them. A second __main__
    # guard appears at the end of the file — these should be consolidated into
    # a single entry point at the bottom of the module.
    interactive_menu()
+
+
def evaluate_model(model: GPT, X_val: torch.Tensor, Y_val: torch.Tensor, tokenizer: CharTokenizer,
                   device: torch.device = None):
    """Compute validation loss, token accuracy, perplexity and top-k accuracy.

    Runs a single forward pass over the whole validation batch in eval mode
    and returns the metrics as a plain dict.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    model.to(device)

    X_val = X_val.to(device)
    Y_val = Y_val.to(device)

    with torch.no_grad():
        # Single forward pass gives both the logits and the validation loss.
        logits, val_loss = model(X_val, Y_val)

        # Top-1 accuracy over every predicted position.
        top1 = logits.argmax(dim=-1)
        acc = (top1 == Y_val).float().mean()

        # Perplexity is exp of the cross-entropy loss.
        ppl = torch.exp(val_loss)

        # Cap k by the vocabulary size so topk never over-asks.
        k = min(5, logits.size(-1))
        topk_acc = calculate_top_k_accuracy(logits, Y_val, k)

    return {
        'loss': val_loss.item(),
        'accuracy': acc.item(),
        'perplexity': ppl.item(),
        'top_k_accuracy': topk_acc,
        'num_samples': len(X_val)
    }
+
+
def calculate_top_k_accuracy(logits: torch.Tensor, targets: torch.Tensor, k: int = 5):
    """Return the fraction of positions whose target id is among the k
    highest-scoring logits."""
    with torch.no_grad():
        # Indices of the k most likely tokens at every position.
        candidate_ids = logits.topk(k, dim=-1).indices
        # Broadcast-compare the target against each candidate, then reduce.
        hits = (candidate_ids == targets.unsqueeze(-1)).any(dim=-1)
        return hits.float().mean().item()
+
+
def visualize_attention(model: GPT, tokenizer: CharTokenizer, text: str, device: torch.device = None):
    """Placeholder for attention-weight visualization (not implemented).

    The model's forward pass does not currently return per-head attention
    weights, so there is nothing to plot yet.
    """
    print("🔍 注意力可视化功能待实现")
+
+
def plot_training_curves(train_losses: list, val_losses: list = None, save_path: str = None):
    """Plot the training (and optionally validation) loss curves.

    Args:
        train_losses: per-epoch training losses.
        val_losses: optional validation losses, recorded every 20 epochs.
        save_path: if given, save the figure there; otherwise show it.
    """
    # Keep the try body minimal: only the import can raise ImportError.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("⚠ 未安装matplotlib,无法绘制训练曲线")
        return

    plt.figure(figsize=(12, 5))

    # Training loss.
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss', alpha=0.7)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True, alpha=0.3)
    plt.legend()

    # Validation loss (if any).
    if val_losses:
        plt.subplot(1, 2, 2)
        # Validation runs every 20 epochs, so map each point to its epoch.
        # Fix: val_epochs was computed but never used; the curve was plotted
        # against the raw index, misaligning the x-axis.
        val_epochs = [i * 20 for i in range(len(val_losses))]
        plt.plot(val_epochs, val_losses, label='Validation Loss', alpha=0.7)
        plt.title('Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True, alpha=0.3)
        plt.legend()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"📊 训练曲线已保存: {save_path}")
    else:
        plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close()
+
+
def generate_sample_text(model: GPT, tokenizer: CharTokenizer, prompt: str = "", max_length: int = 100,
                         device: torch.device = None):
    """Generate a text continuation of `prompt` for qualitative evaluation."""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    model.to(device)

    # Encode the prompt as a single-sequence batch of token ids.
    seed = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(device)

    # Fixed sampling settings used for evaluation-time generation.
    with torch.no_grad():
        output_ids = model.generate(
            seed,
            max_new_tokens=max_length,
            temperature=0.8,
            top_k=50,
            top_p=0.9
        )

    return tokenizer.decode(output_ids[0].tolist())
+
+
def comprehensive_model_evaluation(model: GPT, X_val: torch.Tensor, Y_val: torch.Tensor, tokenizer: CharTokenizer):
    """Run the full evaluation: numeric metrics plus sample generations."""
    print("🔬 开始综合模型评估...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"💻 使用设备: {device}")

    # Quantitative metrics on the validation batch.
    results = evaluate_model(model, X_val, Y_val, tokenizer, device)

    print("\n📋 评估结果:")
    print(f"   损失: {results['loss']:.4f}")
    print(f"   准确率: {results['accuracy']:.4f}")
    print(f"   困惑度: {results['perplexity']:.4f}")
    print(f"   Top-5 准确率: {results['top_k_accuracy']:.4f}")

    # Qualitative check: short continuations from a few code-like prompts.
    print("\n📝 生成示例:")
    for p in ["def ", "class ", "import ", "for ", "if "]:
        completion = generate_sample_text(model, tokenizer, p, max_length=50, device=device)
        print(f"   提示: '{p}' -> 生成: {completion[len(p):60 + len(p)]}...")

    return results
+
+
def test_optimized_model():
    """Smoke-test the model end to end: data, forward/backward, generation,
    evaluation. Returns the model and its tokenizer for further use."""
    print("🧪 开始测试优化后的模型...")

    # 1. Build a small dataset from a fixed code snippet.
    print("\n1. 📊 准备测试数据...")
    test_text = """
def hello_world():
    print("Hello, World!")
    return True

class TestClass:
    def __init__(self):
        self.value = 42

    def get_value(self):
        return self.value

# 测试循环
for i in range(10):
    print(f"Number: {i}")

# 测试条件
if True:
    print("Condition is true")
else:
    print("Condition is false")
"""

    X_test, Y_test, tok = create_improved_sample_dataset(test_text)
    print(f"   ✅ 测试数据集: {len(X_test)} 样本")

    # 2. Instantiate the model with the dataset's vocabulary.
    print("\n2. 🧠 创建优化后的模型...")
    cfg = GPTConfig(vocab_size=tok.vocab_size)
    net = GPT(cfg)
    print(f"   ✅ 模型参数: {sum(p.numel() for p in net.parameters()):,}")

    # 3. Structure check: one forward and one backward pass on a tiny batch.
    print("\n3. 🏋️  运行简短训练测试...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    X_test = X_test[:10].to(device)
    Y_test = Y_test[:10].to(device)

    net.train()
    logits, loss = net(X_test, Y_test)
    print(f"   ✅ 前向传播成功,损失: {loss.item():.4f}")

    loss.backward()
    print(f"   ✅ 反向传播成功")

    # 4. Generation check from a short prompt.
    print("\n4. ✍️  测试文本生成...")
    net.eval()
    with torch.no_grad():
        seed = torch.tensor([tok.encode("def ")], dtype=torch.long).to(device)
        sample_ids = net.generate(seed, max_new_tokens=20)
        decoded = tok.decode(sample_ids[0].tolist())
        print(f"   提示: 'def '")
        print(f"   生成: {decoded}")

    # 5. Metric computation check.
    print("\n5. 📈 测试模型评估功能...")
    metrics = evaluate_model(net, X_test, Y_test, tok, device)
    print(f"   评估结果: {metrics}")

    # 6. Sample-generation helper check.
    print("\n6. 📝 测试生成示例...")
    demo = generate_sample_text(net, tok, "class ", max_length=30, device=device)
    print(f"   生成示例: {demo}")

    print("\n✅ 所有测试通过!优化后的模型功能正常。")

    return net, tok
+
+
def run_performance_comparison():
    """Time model construction, forward passes, and generation.

    Prints wall-clock timings for each phase and returns True on completion.
    """
    print("🚀 开始性能比较测试...")

    import time

    # 1. Model construction time (small vocab keeps this quick).
    print("\n1. ⏱️  测试模型初始化性能...")
    start_time = time.perf_counter()  # perf_counter: monotonic, high-resolution
    config = GPTConfig(vocab_size=1000)
    model = GPT(config)
    init_time = time.perf_counter() - start_time
    print(f"   模型初始化时间: {init_time:.4f}s")

    # 2. Forward-pass throughput.
    print("\n2. ⚡ 测试前向传播速度...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # Fix: benchmark inference in eval mode and without autograd so the
    # timing is not skewed by graph construction (which also leaked memory
    # across the 10 iterations).
    model.eval()

    # Small batch: 4 sequences of 100 tokens each.
    test_input = torch.randint(0, 1000, (4, 100)).to(device)
    test_target = torch.randint(0, 1000, (4, 100)).to(device)

    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(10):
            logits, loss = model(test_input, test_target)
    forward_time = time.perf_counter() - start_time
    print(f"   10次前向传播时间: {forward_time:.4f}s")
    print(f"   平均每次前向传播时间: {forward_time / 10:.4f}s")

    # 3. Autoregressive generation speed.
    print("\n3. 🚀 测试生成速度...")
    start_time = time.perf_counter()
    with torch.no_grad():
        generated = model.generate(test_input[:1], max_new_tokens=50)
    generation_time = time.perf_counter() - start_time
    print(f"   生成50个token时间: {generation_time:.4f}s")

    print(f"\n✅ 性能测试完成!")
    return True
 
 
if __name__ == "__main__":
    # Run the performance tests.
    # NOTE(review): this is the second __main__ guard in the module (an earlier
    # one launches interactive_menu mid-file); this block only runs after that
    # menu exits. Consider consolidating into a single entry point.
    print("🔬 运行优化后的模型测试...")
    print("=" * 60)

    # Performance comparison: init / forward / generation timings.
    run_performance_comparison()

    # End-to-end smoke test of the optimized model.
    test_optimized_model()

    print("\n" + "=" * 60)
    print("🎉 所有测试完成!模型优化成功。")

    # The old auto-detect entry point was replaced by the interactive menu;
    # to train, run the main program and use the menu directly.

二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_100.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_200.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_300.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_400.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_500.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_600.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_700.pth


二進制
src/LinearAlgebra/my_gptmodel/checkpoint_epoch_800.pth


二進制
src/LinearAlgebra/my_gptmodel/gpt_model_best.pth