|
|
@@ -0,0 +1,1283 @@
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.nn.functional as F
|
|
|
+import math
|
|
|
+import json
|
|
|
+import os
|
|
|
+import time
|
|
|
+from typing import List, Optional, Tuple, Dict
|
|
|
+from datetime import datetime
|
|
|
+import glob
|
|
|
+
|
|
|
# ==================== Global configuration ====================

# Model configuration -- a larger model for better output quality
MODEL_CONFIG = {
    'n_layer': 8,        # number of transformer layers
    'n_head': 8,         # number of attention heads
    'n_embd': 256,       # embedding / hidden dimension
    'max_seq_len': 512,  # maximum context length (positions)
    'dropout': 0.1,
    'bias': True,        # use bias terms in Linear layers
}

# Training configuration -- tuned training hyper-parameters
TRAINING_CONFIG = {
    'epochs': 2000,
    'batch_size': 16,
    'learning_rate': 6e-4,
    'block_size': 256,     # training sequence length (tokens per sample)
    'weight_decay': 0.01,
    'grad_clip': 1.0,      # gradient-norm clipping threshold
    'warmup_epochs': 50,
    'min_loss': 0.05,  # target minimum loss -- training may stop early once reached
}

# Generation configuration -- global defaults, no user input required
GENERATION_CONFIG = {
    'max_tokens': 900,
    'temperature': 0.7,
    'top_k': 40,
    'top_p': 0.85,
    'repetition_penalty': 1.1,
}

# File locations
FILE_CONFIG = {
    'save_dir': "my_gptmodel",
    'training_data_file': "training_data.txt",
    'programming_data_file': "programming_data.txt",
    'model_prefix': "gpt_model",
    'tokenizer_prefix': "tokenizer",
}

# Training-data configuration
TRAINING_DATA_CONFIG = {
    'data_repetition': 5,    # repeat the corpus this many times to enlarge the set
    'min_text_length': 500,  # minimum characters for an external data file to be used
}
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 模型类定义 ====================
|
|
|
+
|
|
|
class GPTConfig:
    """Hyper-parameter bundle for the GPT model.

    Defaults mirror the module-level MODEL_CONFIG; ``vocab_size`` defaults
    to the GPT-2 vocabulary size and is normally overridden with the
    tokenizer's actual vocabulary size.
    """

    def __init__(
        self,
        vocab_size: int = 50257,
        n_layer: int = MODEL_CONFIG['n_layer'],
        n_head: int = MODEL_CONFIG['n_head'],
        n_embd: int = MODEL_CONFIG['n_embd'],
        max_seq_len: int = MODEL_CONFIG['max_seq_len'],
        dropout: float = MODEL_CONFIG['dropout'],
        bias: bool = MODEL_CONFIG['bias'],
    ):
        self.vocab_size = vocab_size    # tokenizer vocabulary size
        self.n_layer = n_layer          # number of transformer blocks
        self.n_head = n_head            # attention heads per block
        self.n_embd = n_embd            # embedding / hidden width
        self.max_seq_len = max_seq_len  # maximum context length
        self.dropout = dropout          # dropout probability
        self.bias = bias                # whether Linear layers carry a bias

    def __str__(self):
        return (
            f"GPTConfig(vocab_size={self.vocab_size}, n_layer={self.n_layer}, "
            f"n_head={self.n_head}, n_embd={self.n_embd})"
        )
|
|
|
+
|
|
|
+
|
|
|
class OptimizedCausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit lower-triangular mask."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_size = config.n_embd // config.n_head

        # Fused query/key/value projection and the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Precomputed causal mask, shaped (1, 1, T, T). The buffer keeps the
        # original name "bias" so existing state_dicts stay loadable.
        mask = torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
        self.register_buffer(
            "bias", mask.view(1, 1, config.max_seq_len, config.max_seq_len)
        )

        self.scale = 1.0 / math.sqrt(self.head_size)

    def forward(self, x):
        batch, seq, width = x.size()

        # Project once, then split into query / key / value.
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        per_head = (batch, seq, self.n_head, self.head_size)
        query = query.view(*per_head).transpose(1, 2)
        key = key.view(*per_head).transpose(1, 2)
        value = value.view(*per_head).transpose(1, 2)

        # Scaled dot-product scores, masked so each position only attends
        # to itself and earlier positions.
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        scores = scores.masked_fill(self.bias[:, :, :seq, :seq] == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        out = torch.matmul(weights, value)
        out = out.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.resid_dropout(self.c_proj(out))
|
|
|
+
|
|
|
+
|
|
|
class OptimizedMLP(nn.Module):
    """Position-wise feed-forward network: 4x expansion, GELU, projection back."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd  # standard 4x inner width
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # expand -> nonlinearity -> project -> dropout
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
|
|
|
+
|
|
|
+
|
|
|
class OptimizedBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.attn = OptimizedCausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.mlp = OptimizedMLP(config)

    def forward(self, x):
        # Residual connections around the normalized sub-layers.
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
|
|
+
|
|
|
+
|
|
|
class OptimizedGPT(nn.Module):
    """GPT language model.

    Token + learned positional embeddings feed a stack of pre-norm
    transformer blocks; the LM head shares its weight matrix with the
    token embedding (weight tying).
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embeddings
        self.wpe = nn.Embedding(config.max_seq_len, config.n_embd)  # positional embeddings
        self.drop = nn.Dropout(config.dropout)

        self.blocks = nn.ModuleList([OptimizedBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=1e-5)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: embedding matrix and output projection share parameters.
        self.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights: N(0, 0.02) for Linear/Embedding, zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model.

        Args:
            idx: (B, T) token ids, T <= config.max_seq_len.
            targets: optional (B, T) next-token ids used to compute the loss.

        Returns:
            (logits, loss): logits is (B, T, vocab_size); loss is the mean
            cross-entropy when ``targets`` is given, otherwise None.
        """
        device = idx.device
        b, t = idx.size()

        assert t <= self.config.max_seq_len, f"序列长度{t}超过最大长度{self.config.max_seq_len}"

        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)

        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def generate(self, idx, max_new_tokens=None, temperature=None, top_k=None, top_p=None, repetition_penalty=None):
        """Autoregressively sample tokens continuing the prompt ``idx``.

        Any parameter left as None falls back to GENERATION_CONFIG. Sampling
        applies, in order: repetition penalty, temperature scaling, top-k
        filtering, then nucleus (top-p) filtering.

        Args:
            idx: (1, T) prompt token ids (batch size 1 is assumed by the
                repetition-penalty indexing).

        Returns:
            (1, T + n) tensor of prompt plus up to ``max_new_tokens`` ids.
        """
        max_new_tokens = max_new_tokens or GENERATION_CONFIG['max_tokens']
        temperature = temperature or GENERATION_CONFIG['temperature']
        top_k = top_k if top_k is not None else GENERATION_CONFIG['top_k']
        top_p = top_p if top_p is not None else GENERATION_CONFIG['top_p']
        repetition_penalty = repetition_penalty or GENERATION_CONFIG['repetition_penalty']

        generated_sequence = []

        for _ in range(max_new_tokens):
            # Crop the context window to the model's maximum sequence length.
            idx_cond = idx if idx.size(1) <= self.config.max_seq_len else idx[:, -self.config.max_seq_len:]

            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            # Repetition penalty (CTRL-style). BUGFIX: the previous code always
            # DIVIDED by the penalty, which made already-generated tokens with
            # NEGATIVE logits *more* likely. Positive logits are divided and
            # negative logits multiplied, so the penalty always discourages.
            if repetition_penalty != 1.0:
                for token in set(generated_sequence):
                    if logits[0, token] > 0:
                        logits[0, token] /= repetition_penalty
                    else:
                        logits[0, token] *= repetition_penalty

            # Temperature scaling.
            if temperature != 1.0:
                logits = logits / temperature

            # Top-K filtering: mask everything below the k-th best logit.
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, -1].unsqueeze(-1)] = -float('Inf')

            # Top-P (nucleus) filtering: keep the smallest prefix of tokens
            # whose cumulative probability exceeds top_p.
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Shift the removal mask right so the first token crossing the
                # threshold is always kept.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')

            probs = F.softmax(logits, dim=-1)

            # Stop if filtering masked every token out.
            if torch.all(probs == 0):
                break

            idx_next = torch.multinomial(probs, num_samples=1)
            generated_sequence.append(idx_next.item())
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
|
|
|
+
|
|
|
+
|
|
|
class CharTokenizer:
    """Character-level tokenizer with JSON save/load.

    Maps each character to an integer id (``stoi``) and back (``itos``).
    Unknown characters encode to id 0 and decode to ``self.unknown_token``.
    """

    def __init__(self, text: str = None, stoi: Dict = None):
        """Build from raw text (vocabulary = sorted unique characters) or an
        existing char->id mapping.

        Raises:
            ValueError: if neither ``text`` nor ``stoi`` is provided.
        """
        if stoi is not None:
            self.stoi = self._normalize_stoi(stoi)
        elif text is not None:
            chars = sorted(set(text))
            self.stoi = {ch: i for i, ch in enumerate(chars)}
        else:
            raise ValueError("必须提供text或stoi参数")

        self.itos = {v: k for k, v in self.stoi.items()}  # reverse mapping
        self.vocab_size = len(self.stoi)
        self.unknown_token = '?'  # shown when decoding an out-of-vocab id

    def _normalize_stoi(self, stoi_dict: Dict) -> Dict:
        """Normalize a loaded mapping to ``{char: int}``.

        Supports a legacy format whose keys are character ordinals, either as
        ints or as multi-digit strings (e.g. "20320" -> chr(20320)).

        BUGFIX: single-character keys are now kept verbatim. Previously a
        literal digit character such as "5" was converted via chr(int("5")),
        corrupting any vocabulary containing digits after a save()/load()
        round trip (save() writes literal character keys).
        """
        normalized = {}
        for k, v in stoi_dict.items():
            if isinstance(k, int):
                # Legacy: int key is a character ordinal.
                normalized[chr(k)] = int(v)
            elif isinstance(k, str) and len(k) > 1 and k.isdigit():
                # Legacy: multi-digit string key is a character ordinal.
                normalized[chr(int(k))] = int(v)
            else:
                # Normal case: the key is the character itself.
                normalized[k] = int(v)
        return normalized

    def encode(self, text: str) -> List[int]:
        """Map text to ids; unknown characters fall back to id 0."""
        return [self.stoi.get(ch, 0) for ch in text]

    def decode(self, indices: List[int]) -> str:
        """Map ids back to text; unknown ids become ``self.unknown_token``."""
        return ''.join([self.itos.get(i, self.unknown_token) for i in indices])

    def save(self, filepath: str):
        """Write the char->id mapping to ``filepath`` as UTF-8 JSON."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.stoi, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, filepath: str):
        """Reconstruct a tokenizer from a JSON file written by :meth:`save`."""
        with open(filepath, 'r', encoding='utf-8') as f:
            stoi = json.load(f)
        tokenizer = cls(stoi=stoi)
        return tokenizer

    def __str__(self):
        return f"CharTokenizer(vocab_size={self.vocab_size})"
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 自动数据收集 ====================
|
|
|
+
|
|
|
def collect_training_data():
    """Assemble the training corpus.

    Combines, in order: (1) a built-in Python programming sample, (2) any
    external data files found under FILE_CONFIG['save_dir'] that exceed
    TRAINING_DATA_CONFIG['min_text_length'], and (3) a built-in extended
    sample used only when the total is below 10k characters. The combined
    text is then repeated TRAINING_DATA_CONFIG['data_repetition'] times.

    Returns:
        str: the full (repeated) training corpus.
    """
    data_sources = []

    # 1. Built-in programming data (a literal corpus string -- its contents,
    # including the '#' lines inside it, are training data, not comments).
    programming_data = """
# Python完整知识库
def calculate_factorial(n):
    if n == 0 or n == 1:
        return 1
    else:
        return n * calculate_factorial(n-1)

class Student:
    def __init__(self, name, age, grade):
        self.name = name
        self.age = age
        self.grade = grade
        self.subjects = []

    def add_subject(self, subject):
        self.subjects.append(subject)

    def get_average(self, scores):
        if not scores:
            return 0
        return sum(scores) / len(scores)

def read_file_safely(filename):
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        return "文件不存在"

# 数据结构和算法
def binary_search(arr, target):
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1

# 面向对象编程示例
class Animal:
    def __init__(self, name, species):
        self.name = name
        self.species = species

    def speak(self):
        return "动物发出声音"

class Dog(Animal):
    def __init__(self, name, breed):
        super().__init__(name, "犬科")
        self.breed = breed

    def speak(self):
        return "汪汪!"

# 文件操作类
class FileProcessor:
    def __init__(self, filename):
        self.filename = filename

    def read_content(self):
        try:
            with open(self.filename, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            return f"错误: {e}"

    def write_content(self, content):
        try:
            with open(self.filename, 'w', encoding='utf-8') as f:
                f.write(content)
            return True
        except Exception as e:
            print(f"写入错误: {e}")
            return False

# 数学计算函数
import math
def quadratic_equation(a, b, c):
    discriminant = b**2 - 4*a*c
    if discriminant < 0:
        return "无实数解"
    elif discriminant == 0:
        x = -b / (2*a)
        return f"唯一解: x = {x}"
    else:
        x1 = (-b + math.sqrt(discriminant)) / (2*a)
        x2 = (-b - math.sqrt(discriminant)) / (2*a)
        return f"两个解: x1 = {x1}, x2 = {x2}"

# 字符串处理工具
def process_text(text):
    lines = text.split('\\n')
    processed_lines = []
    for line in lines:
        line = line.strip()
        if line and not line.startswith('#'):
            processed_lines.append(line)
    return '\\n'.join(processed_lines)

# 列表操作示例
def list_operations():
    numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    squares = [x**2 for x in numbers]
    even_squares = [x**2 for x in numbers if x % 2 == 0]
    return squares, even_squares

# 字典操作示例
def dict_operations():
    student = {
        "name": "张三",
        "age": 20,
        "major": "计算机科学",
        "grades": {"数学": 90, "英语": 85, "编程": 95}
    }
    return student

# 异常处理示例
def safe_division(a, b):
    try:
        result = a / b
        return result
    except ZeroDivisionError:
        return "除数不能为零"
    except TypeError:
        return "输入必须是数字"

# 装饰器示例
def timer(func):
    def wrapper(*args, **kwargs):
        import time
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"函数 {func.__name__} 执行时间: {end-start:.2f}秒")
        return result
    return wrapper

@timer
def expensive_operation(n):
    import time
    time.sleep(0.1)
    return sum(range(n))

# 生成器示例
def fibonacci_generator(n):
    a, b = 0, 1
    for _ in range(n):
        yield a
        a, b = b, a + b

# 上下文管理器
class DatabaseConnection:
    def __init__(self, db_name):
        self.db_name = db_name

    def __enter__(self):
        print(f"连接数据库: {self.db_name}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("关闭数据库连接")

    def query(self, sql):
        print(f"执行查询: {sql}")
        return [{"id": 1, "name": "示例数据"}]

# 排序算法
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr

def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr)//2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)

# 数据结构
class LinkedList:
    class Node:
        def __init__(self, data):
            self.data = data
            self.next = None

    def __init__(self):
        self.head = None

    def append(self, data):
        new_node = self.Node(data)
        if not self.head:
            self.head = new_node
            return
        current = self.head
        while current.next:
            current = current.next
        current.next = new_node

    def display(self):
        elements = []
        current = self.head
        while current:
            elements.append(current.data)
            current = current.next
        return elements

class Stack:
    def __init__(self):
        self.items = []

    def push(self, item):
        self.items.append(item)

    def pop(self):
        if not self.is_empty():
            return self.items.pop()
        return None

    def is_empty(self):
        return len(self.items) == 0

    def peek(self):
        if not self.is_empty():
            return self.items[-1]
        return None

# 主程序入口
if __name__ == "__main__":
    # 测试各种功能
    print("测试开始...")

    # 数学函数测试
    result = calculate_factorial(5)
    print(f"5的阶乘: {result}")

    # 学生类测试
    student = Student("李四", 20, "计算机科学")
    student.add_subject("Python编程")
    student.add_subject("数据结构")
    print(f"学生: {student.name}, 科目: {student.subjects}")

    # 排序测试
    test_arr = [64, 34, 25, 12, 22, 11, 90]
    sorted_arr = quick_sort(test_arr.copy())
    print(f"排序前: {test_arr}")
    print(f"排序后: {sorted_arr}")

    print("所有测试完成!")
"""
    data_sources.append(programming_data)

    # 2. Try to read external data files from the save directory.
    data_files = [
        "training_data.txt",
        "programming_data.txt",
        "code_data.txt",
        "python_code.txt"
    ]

    for data_file in data_files:
        file_path = os.path.join(FILE_CONFIG['save_dir'], data_file)
        if os.path.exists(file_path):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                    # Only accept files with a meaningful amount of text.
                    if len(content) > TRAINING_DATA_CONFIG['min_text_length']:
                        data_sources.append(content)
                        print(f"✅ 加载数据文件: {data_file} ({len(content)} 字符)")
            except Exception as e:
                print(f"⚠ 读取数据文件 {data_file} 时出错: {e}")

    # 3. Fall back to the extended built-in sample when there is not enough data.
    if len(''.join(data_sources)) < 10000:  # total data below 10k characters
        extended_data = """
# 更多Python编程示例

# 网络请求示例
import requests
def fetch_url(url):
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return response.text
    except requests.RequestException as e:
        return f"请求失败: {e}"

# 数据处理示例
import json
def process_json_data(json_string):
    try:
        data = json.loads(json_string)
        return data
    except json.JSONDecodeError as e:
        return f"JSON解析错误: {e}"

def save_to_json(data, filename):
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return True
    except Exception as e:
        print(f"保存JSON失败: {e}")
        return False

# 日期时间处理
from datetime import datetime, timedelta
def date_operations():
    now = datetime.now()
    tomorrow = now + timedelta(days=1)
    last_week = now - timedelta(weeks=1)

    return {
        "now": now.strftime("%Y-%m-%d %H:%M:%S"),
        "tomorrow": tomorrow.strftime("%Y-%m-%d"),
        "last_week": last_week.strftime("%Y-%m-%d")
    }

# 正则表达式示例
import re
def extract_emails(text):
    pattern = r'\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b'
    return re.findall(pattern, text)

def validate_phone(phone):
    pattern = r'^1[3-9]\\d{9}$'
    return bool(re.match(pattern, phone))

# 多线程示例
import threading
import time

class Counter:
    def __init__(self):
        self.value = 0
        self.lock = threading.Lock()

    def increment(self):
        with self.lock:
            self.value += 1

def worker(counter, iterations):
    for _ in range(iterations):
        counter.increment()

# 单元测试示例
import unittest
class TestMathFunctions(unittest.TestCase):
    def test_factorial(self):
        self.assertEqual(calculate_factorial(5), 120)
        self.assertEqual(calculate_factorial(0), 1)

    def test_binary_search(self):
        arr = [1, 3, 5, 7, 9]
        self.assertEqual(binary_search(arr, 5), 2)
        self.assertEqual(binary_search(arr, 2), -1)

# 主程序入口
if __name__ == "__main__":
    # 测试各种功能
    print("测试开始...")

    # 数学函数测试
    result = calculate_factorial(5)
    print(f"5的阶乘: {result}")

    # 学生类测试
    student = Student("李四", 20, "计算机科学")
    student.add_subject("Python编程")
    student.add_subject("数据结构")
    print(f"学生: {student.name}, 科目: {student.subjects}")

    # 排序测试
    test_arr = [64, 34, 25, 12, 22, 11, 90]
    sorted_arr = quick_sort(test_arr.copy())
    print(f"排序前: {test_arr}")
    print(f"排序后: {sorted_arr}")

    print("所有测试完成!")
"""
        data_sources.append(extended_data)

    # Merge all data sources into one corpus.
    combined_data = '\n'.join(data_sources)

    # Repeat the data to enlarge the number of training samples.
    # NOTE(review): create_improved_sample_dataset repeats again by the same
    # factor, so the effective repetition is data_repetition squared -- confirm
    # this is intended.
    combined_data = combined_data * TRAINING_DATA_CONFIG['data_repetition']

    print(f"📊 总训练数据: {len(combined_data):,} 字符")
    return combined_data
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 优化的训练函数 ====================
|
|
|
+
|
|
|
class EnhancedTrainingMonitor:
    """Tracks per-epoch losses, prints progress, and persists model snapshots."""

    def __init__(self, save_dir: str = FILE_CONFIG['save_dir']):
        self.losses = []
        self.start_time = time.time()
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.best_loss = float('inf')
        self.patience = 100              # epochs without improvement before early stop
        self.checkpoint_frequency = 50   # save a checkpoint every 50 epochs

    def update(self, loss, epoch, model=None, tokenizer=None):
        """Record one epoch's loss; print progress and save snapshots as needed."""
        self.losses.append(loss)
        elapsed = time.time() - self.start_time

        # Progress line: the first 10 epochs, every 20th epoch, and the last one.
        final_epoch = TRAINING_CONFIG['epochs'] - 1
        if epoch < 10 or epoch % 20 == 0 or epoch == final_epoch:
            speed = (epoch + 1) / elapsed
            remaining = (TRAINING_CONFIG['epochs'] - epoch - 1) / speed if speed > 0 else 0
            print(f"Epoch {epoch:4d} | Loss: {loss:.4f} | "
                  f"Speed: {speed:.2f} epoch/s | ETA: {remaining:.0f}s")

        # Periodic checkpoints (skipping epoch 0).
        if model and epoch % self.checkpoint_frequency == 0 and epoch > 0:
            self.save_checkpoint(model, tokenizer, epoch, loss)

        # Track the best loss; persist the best model only after 100 epochs.
        if loss < self.best_loss:
            self.best_loss = loss
            if model and epoch > 100:
                self.save_best_model(model, tokenizer, epoch, loss)

    def save_checkpoint(self, model, tokenizer, epoch, loss):
        """Write a resumable training checkpoint to disk."""
        path = os.path.join(self.save_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': loss,
                'config': model.config,
                'tokenizer': tokenizer.stoi,
                'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            },
            path,
        )
        print(f"💾 检查点已保存: {path}")

    def save_best_model(self, model, tokenizer, epoch, loss):
        """Persist the best-so-far model together with its full training context."""
        payload = {
            'model_state_dict': model.state_dict(),
            'config': model.config,
            'tokenizer': tokenizer.stoi,
            'training_losses': self.losses,
            'epoch': epoch,
            'loss': loss,
            'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'global_configs': {
                'MODEL_CONFIG': MODEL_CONFIG,
                'TRAINING_CONFIG': TRAINING_CONFIG,
                'GENERATION_CONFIG': GENERATION_CONFIG,
            },
        }
        model_path = os.path.join(self.save_dir, f"{FILE_CONFIG['model_prefix']}_best.pth")
        torch.save(payload, model_path)
        print(f"🏆 最佳模型已保存: {model_path} (loss: {loss:.4f})")

    def plot_loss(self):
        """Save a PNG of the loss curve; skipped when matplotlib is unavailable."""
        try:
            import matplotlib.pyplot as plt

            plt.figure(figsize=(12, 6))
            plt.plot(self.losses)
            plt.title('Training Loss Progress')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.grid(True, alpha=0.3)
            loss_path = os.path.join(self.save_dir, 'training_loss.png')
            plt.savefig(loss_path, dpi=150, bbox_inches='tight')
            print(f"✓ 损失曲线已保存: {loss_path}")
        except ImportError:
            print("⚠ 未安装matplotlib,无法绘制损失曲线")
|
|
|
+
|
|
|
def get_improved_learning_rate(epoch, warmup_epochs=20):
    """Learning-rate schedule: linear warmup followed by cosine annealing.

    Args:
        epoch: current epoch index (0-based).
        warmup_epochs: number of linear-warmup epochs.

    Returns:
        float: learning rate for this epoch, scaled from
        TRAINING_CONFIG['learning_rate'] down toward 0 at the final epoch.
    """
    base_lr = TRAINING_CONFIG['learning_rate']
    if epoch < warmup_epochs:
        # Linear warmup from base_lr/warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs

    decay_epochs = TRAINING_CONFIG['epochs'] - warmup_epochs
    if decay_epochs <= 0:
        # BUGFIX: guard the ZeroDivisionError that occurred whenever
        # warmup_epochs >= TRAINING_CONFIG['epochs'].
        return base_lr

    # Cosine annealing from base_lr to 0 over the remaining epochs.
    progress = (epoch - warmup_epochs) / decay_epochs
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
|
|
|
+
|
|
|
+
|
|
|
def improved_train_gpt(model: OptimizedGPT, X: torch.Tensor, Y: torch.Tensor, tokenizer: CharTokenizer):
    """Train the model on (input, next-token) pairs with AdamW.

    Uses warmup + cosine learning-rate scheduling, gradient clipping,
    random-batch sampling, checkpointing via EnhancedTrainingMonitor,
    patience-based early stopping (after epoch 300), and a target-loss
    stop (after epoch 200).

    Args:
        model: the GPT model to train (moved to CUDA when available).
        X: (N, block_size) input token windows.
        Y: (N, block_size) shifted target windows.
        tokenizer: tokenizer saved alongside checkpoints.

    Returns:
        list[float]: per-epoch losses recorded by the monitor.
    """
    model.train()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=TRAINING_CONFIG['learning_rate'],
        weight_decay=TRAINING_CONFIG['weight_decay'],
        betas=(0.9, 0.95)
    )

    monitor = EnhancedTrainingMonitor(FILE_CONFIG['save_dir'])

    print(f"🚀 开始训练GPT模型")
    print(f"📊 总轮数: {TRAINING_CONFIG['epochs']}")
    print(f"🔢 模型参数: {sum(p.numel() for p in model.parameters()):,}")
    print(f"📚 训练样本: {len(X):,}")

    # Device setup: prefer CUDA; the full dataset is moved to the device up
    # front (assumes X/Y fit in device memory -- TODO confirm for large sets).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"💻 使用设备: {device}")
    model.to(device)
    X, Y = X.to(device), Y.to(device)

    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(TRAINING_CONFIG['epochs']):
        # Dynamic learning rate (warmup + cosine decay).
        lr = get_improved_learning_rate(epoch, TRAINING_CONFIG['warmup_epochs'])
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # One optimization step on a randomly sampled batch.
        optimizer.zero_grad()

        indices = torch.randint(0, len(X), (TRAINING_CONFIG['batch_size'],))
        x_batch = X[indices]
        y_batch = Y[indices]

        logits, loss = model(x_batch, y_batch)
        loss.backward()

        # Gradient clipping to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=TRAINING_CONFIG['grad_clip'])
        optimizer.step()

        monitor.update(loss.item(), epoch, model, tokenizer)

        # Early-stop bookkeeping: reset patience on any improvement.
        if loss.item() < best_loss:
            best_loss = loss.item()
            patience_counter = 0
        else:
            patience_counter += 1

        # Early-stop condition (only after 300 epochs).
        if patience_counter >= monitor.patience and epoch > 300:
            print(f"🛑 早停触发,第{epoch}轮")
            break

        # Stop early once the loss reaches the configured target.
        if loss.item() < TRAINING_CONFIG['min_loss'] and epoch > 200:
            print(f"✅ 训练完成,损失已达目标值 {loss.item():.4f}")
            break

    print("🎉 训练完成!")
    monitor.plot_loss()
    return monitor.losses
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 工具函数 ====================
|
|
|
+
|
|
|
def create_improved_sample_dataset(text: str, block_size: int = None) -> Tuple[
    torch.Tensor, torch.Tensor, 'CharTokenizer']:
    """Turn raw text into (input, target) training tensors plus a tokenizer.

    Blank lines and '#' comment lines are dropped, the text is repeated for
    augmentation, and stride-1 sliding windows of ``block_size`` tokens are
    stacked into X with their one-step-shifted counterparts in Y.
    """
    block_size = block_size or TRAINING_CONFIG['block_size']

    # Pre-process: strip every line, then drop blanks and comment lines.
    stripped = [ln.strip() for ln in text.split('\n')]
    text = '\n'.join(ln for ln in stripped if ln and not ln.startswith('#'))

    # Simple augmentation by repetition.
    text = text * TRAINING_DATA_CONFIG['data_repetition']

    tokenizer = CharTokenizer(text)

    if len(text) < block_size + 1:
        # Too little text for even one window: tile it until overlapped
        # sampling has something to work with.
        print("⚠ 文本较短,使用重叠采样")
        data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
        while len(data) < block_size + 1000:
            data = torch.cat([data, data])
        data = data[:block_size + 2000]
    else:
        data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

    n = len(data) - block_size
    if n <= 0:
        raise ValueError("无法创建训练样本")

    # Stride-1 windows: every offset yields one (input, shifted-target) pair.
    X = torch.stack([data[start:start + block_size] for start in range(n)])
    Y = torch.stack([data[start + 1:start + block_size + 1] for start in range(n)])

    print(f"✅ 创建了 {len(X):,} 个训练样本")
    print(f"🔤 词汇表大小: {tokenizer.vocab_size}")
    return X, Y, tokenizer
|
|
|
+
|
|
|
+
|
|
|
def format_generated_text(text: str, start_text: str) -> str:
    """Trim the prompt off generated text and keep only plausible lines.

    A line survives when it looks like code (contains a Python keyword,
    starts a comment/docstring, or has '='/':' structure) or is a
    reasonably long piece of text (> 10 characters).
    """
    # Drop the prompt prefix when the output actually starts with it.
    generated = text[len(start_text):] if text.startswith(start_text) else text

    code_markers = ('def ', 'class ', 'import ', 'from ', 'if ', 'for ', 'while ')

    def _worth_keeping(line: str) -> bool:
        # Code keywords anywhere in the line.
        if any(marker in line for marker in code_markers):
            return True
        # Comments and docstring delimiters.
        if line.startswith(('#', '"""', "'''")):
            return True
        # Assignment or colon structure (covers headers ending in ':').
        if '=' in line or ':' in line:
            return True
        # Otherwise keep only reasonably long text lines.
        return len(line) > 10

    kept = [
        candidate
        for candidate in (raw.strip() for raw in generated.split('\n'))
        if candidate and _worth_keeping(candidate)
    ]
    return '\n'.join(kept)
|
|
|
+
|
|
|
+
|
|
|
def interactive_generation(model: OptimizedGPT, tokenizer: CharTokenizer):
    """Interactive generation REPL driven by the global GENERATION_CONFIG.

    Repeatedly prompts the user for a seed string, samples a continuation
    from the model, and prints a cleaned-up version of the result.
    Exits on 'quit' / 'exit' / '退出' or Ctrl-C.
    """
    print("\n" + "=" * 60)
    print("🤖 进入交互式生成模式")
    print("💡 提示: 输入Python代码片段或自然语言描述")
    print("⏹️ 退出: 输入 'quit', 'exit', 或 '退出'")
    print("🔧 使用全局生成参数:")
    print(f" 🌡️ 温度: {GENERATION_CONFIG['temperature']}")
    print(f" 🔝 Top-K: {GENERATION_CONFIG['top_k']}")
    print(f" 📏 生成长度: {GENERATION_CONFIG['max_tokens']}")
    print("=" * 60)

    model.eval()

    while True:
        try:
            print("\n" + "-" * 40)
            user_input = input("🎯 请输入起始文本: ").strip()

            if user_input.lower() in ['quit', 'exit', '退出']:
                break

            if not user_input:
                print("⚠ 输入不能为空,请重新输入。")
                continue

            print(f"⚡ 生成中...", end='', flush=True)
            start_time = time.time()

            # Encode the prompt; all sampling knobs come from the globals.
            start_tokens = torch.tensor([tokenizer.encode(user_input)], dtype=torch.long)

            with torch.no_grad():
                generated_tokens = model.generate(
                    start_tokens,
                    max_new_tokens=GENERATION_CONFIG['max_tokens'],
                    temperature=GENERATION_CONFIG['temperature'],
                    top_k=GENERATION_CONFIG['top_k'],
                    top_p=GENERATION_CONFIG['top_p'],
                    repetition_penalty=GENERATION_CONFIG['repetition_penalty']
                )

            elapsed = time.time() - start_time
            print(f"完成! (耗时: {elapsed:.2f}s)")

            # Decode the full sequence and pretty-print the continuation.
            full_text = tokenizer.decode(generated_tokens[0].tolist())
            formatted_text = format_generated_text(full_text, user_input)

            print(f"\n📊 生成结果:")
            print("=" * 50)
            print(f"🎯 起始: {user_input}")
            print("-" * 50)
            if formatted_text:
                print(formatted_text)
            else:
                # Formatting removed everything: show the raw output, truncated.
                display_text = full_text[len(user_input):]
                if len(display_text) > 300:
                    display_text = display_text[:300] + "..."
                print(display_text)
            print("=" * 50)
            print(f"📏 总长度: {len(full_text)} 字符")

        except KeyboardInterrupt:
            print("\n\n🛑 用户中断,退出交互模式")
            break
        except Exception as e:
            print(f"❌ 生成时出错: {e}")
|
|
|
+
|
|
|
+
|
|
|
def get_available_models(save_dir: str = None) -> List[Tuple[str, str]]:
    """Return (filename, full_path) pairs for every saved model, newest first.

    Scans *save_dir* (defaulting to ``FILE_CONFIG['save_dir']``) for ``.pth``
    files whose name does not contain ``checkpoint``, ordered by modification
    time descending.  Returns an empty list when the directory does not exist
    or holds no matching files.
    """
    if not save_dir:
        save_dir = FILE_CONFIG['save_dir']
    if not os.path.exists(save_dir):
        return []

    candidates = []
    for name in os.listdir(save_dir):
        if name.endswith('.pth') and 'checkpoint' not in name:
            candidates.append(name)
    if not candidates:
        return []

    def _mtime(name: str) -> float:
        return os.path.getmtime(os.path.join(save_dir, name))

    ordered = sorted(candidates, key=_mtime, reverse=True)
    return [(name, os.path.join(save_dir, name)) for name in ordered]
|
|
|
+
|
|
|
+
|
|
|
def select_model_interactively() -> str:
    """Interactively pick a saved model and return its full path.

    Lists every available model (newest first) with its modification time
    and size, then prompts until the user enters a valid 1-based index.
    An empty input selects the newest model.

    Returns:
        Full path of the chosen model file, or None when no models exist.
    """
    available_models = get_available_models()

    if not available_models:
        print("❌ 在 my_gptmodel 目录中未找到任何模型文件")
        return None

    print("\n📂 可用的模型文件:")
    print("-" * 60)
    for i, (filename, full_path) in enumerate(available_models, 1):
        mtime = os.path.getmtime(full_path)
        mtime_str = datetime.fromtimestamp(mtime).strftime("%Y-%m-%d %H:%M:%S")
        size = os.path.getsize(full_path) / 1024 / 1024  # MB
        # BUG FIX: the filename was missing from the listing (a literal
        # "(unknown)" placeholder was printed instead).
        print(f"{i:2d}. {filename}")
        print(f" 修改时间: {mtime_str} | 大小: {size:.1f}MB")

    while True:
        try:
            choice = input(f"\n🎲 请选择模型文件 (1-{len(available_models)}): ").strip()
            if not choice:
                # Empty input defaults to the newest model.
                return available_models[0][1]

            index = int(choice) - 1
            if 0 <= index < len(available_models):
                return available_models[index][1]
            else:
                print(f"⚠ 请输入 1-{len(available_models)} 之间的数字")
        except ValueError:
            print("⚠ 请输入有效的数字")
|
|
|
+
|
|
|
+
|
|
|
class AdvancedGPT(OptimizedGPT):
    """Enhanced GPT that can be reconstructed from a saved checkpoint.

    Construction is inherited unchanged from ``OptimizedGPT``; the redundant
    ``__init__`` override that only forwarded to ``super().__init__`` has
    been removed.
    """

    @classmethod
    def from_pretrained(cls, model_path: str, weights_only=False):
        """Load a model (plus metadata) from a checkpoint file.

        Args:
            model_path: Path to the ``.pth`` checkpoint.
            weights_only: Forwarded to ``torch.load``.  The default of
                ``False`` is needed because the checkpoint stores a pickled
                config object — SECURITY NOTE: unpickling can execute
                arbitrary code, so only load checkpoints you trust.

        Returns:
            Tuple ``(model, tokenizer_stoi_or_None, training_losses,
            global_configs)`` as stored by the training pipeline.

        Raises:
            Exception: Re-raised (after logging) when loading fails.
        """
        try:
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=weights_only)
            config = checkpoint['config']
            model = cls(config)
            model.load_state_dict(checkpoint['model_state_dict'])
            global_configs = checkpoint.get('global_configs', {})
            return model, checkpoint.get('tokenizer', None), checkpoint.get('training_losses', []), global_configs
        except Exception as e:
            print(f"❌ 加载模型时出错: {e}")
            raise
|
|
|
+
|
|
|
+
|
|
|
def load_and_test_model(model_path: str = None):
    """Load a saved model and run a few smoke-test generations.

    When *model_path* is None the user is asked to pick one via
    ``select_model_interactively``.  The first three hard-coded prompts are
    generated (100 tokens each) and printed so the user can eyeball quality.

    Args:
        model_path: Optional path to a ``.pth`` checkpoint.

    Returns:
        Tuple ``(model, tokenizer)`` on success, ``(None, None)`` on failure
        or when no model was selected.
    """
    try:
        if model_path is None:
            model_path = select_model_interactively()
            if model_path is None:
                return None, None

        print(f"📥 加载模型: {model_path}")
        model, tokenizer_dict, losses, global_configs = AdvancedGPT.from_pretrained(model_path, weights_only=False)

        # Rebuild the tokenizer from the string-to-index mapping stored
        # in the checkpoint.
        tokenizer = CharTokenizer(stoi=tokenizer_dict)
        print(f"✅ 模型加载成功")
        print(f"🔤 词汇表大小: {tokenizer.vocab_size}")

        # Smoke-test generation prompts.
        test_prompts = [
            "def calculate",
            "class Student",
            "import pandas",
            "for i in range",
            "# 单元测试",
            "def read_file"
        ]

        print(f"\n🧪 模型测试生成:")
        print("-" * 40)

        for i, prompt in enumerate(test_prompts[:3], 1):  # only test the first 3
            print(f"\n测试 {i}: '{prompt}'")
            start_tokens = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)

            with torch.no_grad():
                generated = model.generate(start_tokens, max_new_tokens=100)

            result = tokenizer.decode(generated[0].tolist())
            formatted = format_generated_text(result, prompt)
            if formatted:
                # Truncate long output for display.
                print(formatted[:200] + "..." if len(formatted) > 200 else formatted)
            else:
                print("⚠ 生成结果为空")

        return model, tokenizer

    except Exception as e:
        print(f"❌ 加载模型失败: {e}")
        import traceback
        traceback.print_exc()
        return None, None
|
|
|
+
|
|
|
+
|
|
|
def create_programming_data_file():
    """Create the programming-data text file if it does not exist yet.

    The file is seeded with the output of ``collect_training_data`` so the
    user has a template to extend with their own Python code samples.
    Write failures are reported but not raised.
    """
    target_path = os.path.join(
        FILE_CONFIG['save_dir'], FILE_CONFIG['programming_data_file'])

    if os.path.exists(target_path):
        return

    print(f"📝 创建编程数据文件: {target_path}")
    corpus = collect_training_data()

    try:
        with open(target_path, 'w', encoding='utf-8') as fh:
            fh.write(corpus.strip())
        print(f"✅ 编程数据文件已创建: {target_path}")
        print("💡 您可以将自己的Python代码数据添加到这个文件中")
    except Exception as e:
        print(f"❌ 创建编程数据文件时出错: {e}")
|
|
|
+
|
|
|
+
|
|
|
def main():
    """End-to-end pipeline: collect data, build dataset, train, save, then chat.

    Steps (in order): create the save directory, collect training data,
    build the (X, Y, tokenizer) dataset, construct an ``OptimizedGPT``,
    train it, save a timestamped checkpoint (model weights + config +
    tokenizer mapping + loss history + the global config dicts), save the
    tokenizer JSON, and finally drop into interactive generation.
    Any exception is caught, printed with a traceback, and swallowed.
    """
    try:
        # Create the output directory for models/tokenizers.
        save_dir = FILE_CONFIG['save_dir']
        os.makedirs(save_dir, exist_ok=True)

        print("🤖 GPT语言模型训练与生成系统")
        print("=" * 60)
        print(f"📁 文件保存目录: {save_dir}")
        print(f"⚙️ 模型配置: {MODEL_CONFIG}")
        print(f"⚙️ 训练配置: {TRAINING_CONFIG}")
        print(f"⚙️ 生成配置: {GENERATION_CONFIG}")

        # 1. Automatically collect training data.
        print("\n1. 📚 收集训练数据...")
        training_data = collect_training_data()

        # 2. Build the dataset.
        print("\n2. 🗂️ 创建数据集...")
        X, Y, tokenizer = create_improved_sample_dataset(training_data)
        print(f" ✅ 数据集: {len(X):,} 样本")
        print(f" 🔤 词汇表: {tokenizer.vocab_size} 字符")

        # 3. Create the model.
        print("\n3. 🧠 创建模型...")
        config = GPTConfig(vocab_size=tokenizer.vocab_size)
        print(f" {config}")

        model = OptimizedGPT(config)
        print(f" ✅ 参数数量: {sum(p.numel() for p in model.parameters()):,}")

        # 4. Train the model.
        print("\n4. 🏋️ 训练模型...")
        losses = improved_train_gpt(model, X, Y, tokenizer)

        # 5. Save the final model (timestamped so older runs are kept).
        print("\n5. 💾 保存最终模型...")
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(save_dir, f"{FILE_CONFIG['model_prefix']}_final_{timestamp}.pth")
        tokenizer_path = os.path.join(save_dir, f"{FILE_CONFIG['tokenizer_prefix']}_{timestamp}.json")

        # Checkpoint layout consumed by AdvancedGPT.from_pretrained.
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': config,
            'tokenizer': tokenizer.stoi,
            'training_losses': losses,
            'timestamp': timestamp,
            'global_configs': {
                'MODEL_CONFIG': MODEL_CONFIG,
                'TRAINING_CONFIG': TRAINING_CONFIG,
                'GENERATION_CONFIG': GENERATION_CONFIG
            }
        }, model_path)

        tokenizer.save(tokenizer_path)
        print(f" ✅ 模型已保存: {model_path}")
        print(f" ✅ 分词器已保存: {tokenizer_path}")

        # 6. Interactive generation.
        print("\n6. 🎮 进入交互模式...")
        interactive_generation(model, tokenizer)

    except Exception as e:
        print(f"❌ 程序执行出错: {e}")
        import traceback
        traceback.print_exc()
|
|
|
+
|
|
|
+
|
|
|
def auto_detect_and_run():
    """Entry dispatcher: load the newest saved model, or train a fresh one.

    When at least one model file exists, the most recent one is loaded and
    the interactive generation loop starts; if loading fails — or no model
    exists at all — ``main()`` is invoked to train a new model.
    """
    print("🔍 GPT语言模型自动检测系统")
    print("=" * 50)

    available_models = get_available_models()

    if not available_models:
        print("❌ 未检测到现有模型,开始训练新模型...")
        main()
        return

    print(f"📂 检测到 {len(available_models)} 个现有模型")
    print("🔄 自动加载最新模型...")

    # Models are sorted newest-first, so index 0 is the latest.
    newest_path = available_models[0][1]
    model, tokenizer = load_and_test_model(newest_path)

    if model and tokenizer:
        print("\n✅ 模型加载成功,进入交互式生成模式")
        interactive_generation(model, tokenizer)
    else:
        print("❌ 模型加载失败,开始训练新模型...")
        main()
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Auto-detect and run: load an existing model if one exists,
    # otherwise train a new one.
    auto_detect_and_run()
|