""" 模型比较器模块 - 用于比较不同猫叫声意图分类模型的性能 该模块提供了比较DAG-HMM、深度学习、SVM和随机森林等不同分类方法的功能, 帮助用户选择最适合其数据集的模型。 """ import os import numpy as np import json import matplotlib.pyplot as plt from typing import Dict, Any, List, Optional, Tuple import time from datetime import datetime from src.cat_intent_classifier_v2 import CatIntentClassifier from src.dag_hmm_classifier import DAGHMMClassifier class ModelComparator: """模型比较器类,用于比较不同猫叫声意图分类模型的性能""" def __init__(self, results_dir: str = "./comparison_results"): """ 初始化模型比较器 参数: results_dir: 结果保存目录 """ self.results_dir = results_dir os.makedirs(results_dir, exist_ok=True) # 支持的模型类型 self.model_types = { "dag_hmm": { "name": "DAG-HMM", "class": DAGHMMClassifier, "params": {"n_states": 5, "n_mix": 3} }, "dl": { "name": "深度学习", "class": CatIntentClassifier, "params": {} } } def compare_models(self, features: List[np.ndarray], labels: List[str], model_types: List[str] = None, test_size: float = 0.2, cat_name: Optional[str] = None) -> Dict[str, Any]: """ 比较不同模型的性能 参数: features: 特征序列列表 labels: 标签列表 model_types: 要比较的模型类型列表,默认为所有支持的模型 test_size: 测试集比例 cat_name: 猫咪名称,默认为None(通用模型) 返回: results: 比较结果 """ if model_types is None: model_types = list(self.model_types.keys()) # 验证模型类型 for model_type in model_types: if model_type not in self.model_types: raise ValueError(f"不支持的模型类型: {model_type}") # 划分训练集和测试集 from sklearn.model_selection import train_test_split _, test_features, _, test_labels = train_test_split( features, labels, test_size=test_size, random_state=42, stratify=labels ) train_features, train_labels = features, labels print(f"训练集大小: {len(train_features)}, 测试集大小: {len(test_features)}") # 比较结果 results = { "models": {}, "best_model": None, "comparison_time": datetime.now().isoformat(), "dataset_info": { "total_samples": len(features), "train_samples": len(train_features), "test_samples": len(test_features), "classes": sorted(list(set(labels))), "class_distribution": {label: labels.count(label) for label in set(labels)} } } # 训练和评估每个模型 for model_type in model_types: model_info = self.model_types[model_type] model_name = model_info["name"] model_class = model_info["class"] model_params = model_info["params"] print(f"\n开始训练和评估 {model_name} 模型...") try: # 创建模型 model = model_class(**model_params) # 记录训练开始时间 train_start_time = time.time() # 训练模型 train_metrics = model.train(train_features, train_labels) # 记录训练结束时间 train_end_time = time.time() train_time = train_end_time - train_start_time # 记录评估开始时间 eval_start_time = time.time() # 评估模型 eval_metrics = model.evaluate(test_features, test_labels) # 记录评估结束时间 eval_end_time = time.time() eval_time = eval_end_time - eval_start_time # 保存模型 model_dir = os.path.join(self.results_dir, "models") os.makedirs(model_dir, exist_ok=True) model_paths = model.save_model(model_dir, cat_name) # 记录结果 results["models"][model_type] = { "name": model_name, "train_metrics": train_metrics, "eval_metrics": eval_metrics, "train_time": train_time, "eval_time": eval_time, "model_paths": model_paths } print(f"{model_name} 模型训练完成,评估指标: {eval_metrics}") except Exception as e: print(f"{model_name} 模型训练或评估失败: {e}") results["models"][model_type] = { "name": model_name, "error": str(e) } # 确定最佳模型 best_model = None best_accuracy = -1 for model_type, model_result in results["models"].items(): if "eval_metrics" in model_result and "accuracy" in model_result["eval_metrics"]: accuracy = model_result["eval_metrics"]["accuracy"] if accuracy > best_accuracy: best_accuracy = accuracy best_model = model_type results["best_model"] = best_model # 保存比较结果 result_path = os.path.join( self.results_dir, f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" ) with open(result_path, 'w') as f: # 将numpy值转换为Python原生类型 def convert_numpy(obj): if isinstance(obj, np.integer): return int(obj) elif isinstance(obj, np.floating): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() return obj json_results = {k: convert_numpy(v) for k, v in results.items()} json.dump(json_results, f, indent=2) print(f"\n比较结果已保存到: {result_path}") # 可视化比较结果 self.visualize_comparison(results) return results def visualize_comparison(self, results: Dict[str, Any]) -> str: """ 可视化比较结果 参数: results: 比较结果 返回: plot_path: 图表保存路径 """ # 准备数据 model_names = [] accuracies = [] precisions = [] recalls = [] f1_scores = [] train_times = [] for model_type, model_result in results["models"].items(): if "eval_metrics" in model_result: model_names.append(model_result["name"]) metrics = model_result["eval_metrics"] accuracies.append(metrics.get("accuracy", 0)) precisions.append(metrics.get("precision", 0)) recalls.append(metrics.get("recall", 0)) f1_scores.append(metrics.get("f1", 0)) train_times.append(model_result.get("train_time", 0)) # 创建图表 fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7)) # 性能指标图 x = np.arange(len(model_names)) width = 0.2 ax1.bar(x - width*1.5, accuracies, width, label='准确率') ax1.bar(x - width/2, precisions, width, label='精确率') ax1.bar(x + width/2, recalls, width, label='召回率') ax1.bar(x + width*1.5, f1_scores, width, label='F1分数') ax1.set_ylabel('得分') ax1.set_title('模型性能比较') ax1.set_xticks(x) ax1.set_xticklabels(model_names) ax1.legend() ax1.set_ylim(0, 1.1) # 为每个柱子添加数值标签 for i, v in enumerate(accuracies): ax1.text(i - width*1.5, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8) for i, v in enumerate(precisions): ax1.text(i - width/2, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8) for i, v in enumerate(recalls): ax1.text(i + width/2, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8) for i, v in enumerate(f1_scores): ax1.text(i + width*1.5, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8) # 训练时间图 ax2.bar(model_names, train_times, color='skyblue') ax2.set_ylabel('时间 (秒)') ax2.set_title('模型训练时间比较') # 为每个柱子添加数值标签 for i, v in enumerate(train_times): ax2.text(i, v + 0.1, f'{v:.1f}s', ha='center', va='bottom') # 标记最佳模型 best_model = results.get("best_model") if best_model and best_model in results["models"]: best_model_name = results["models"][best_model]["name"] best_index = model_names.index(best_model_name) ax1.get_xticklabels()[best_index].set_color('red') ax1.get_xticklabels()[best_index].set_weight('bold') ax2.get_xticklabels()[best_index].set_color('red') ax2.get_xticklabels()[best_index].set_weight('bold') # 添加总标题 plt.suptitle('猫叫声意图分类模型比较', fontsize=16) # 保存图表 plot_path = os.path.join( self.results_dir, f"comparison_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" ) plt.tight_layout() plt.subplots_adjust(top=0.9) plt.savefig(plot_path, dpi=300) plt.close() print(f"比较图表已保存到: {plot_path}") return plot_path def load_best_model(self, comparison_result_path: str, cat_name: Optional[str] = None) -> Any: """ 加载比较结果中的最佳模型 参数: comparison_result_path: 比较结果文件路径 cat_name: 猫咪名称,默认为None(通用模型) 返回: model: 加载的模型 """ # 加载比较结果 with open(comparison_result_path, 'r') as f: results = json.load(f) # 获取最佳模型类型 best_model_type = results.get("best_model") if not best_model_type: raise ValueError("比较结果中没有最佳模型") # 获取最佳模型信息 best_model_info = results["models"].get(best_model_type) if not best_model_info or "model_paths" not in best_model_info: raise ValueError(f"无法获取最佳模型 {best_model_type} 的路径信息") # 获取模型类 model_class = self.model_types[best_model_type]["class"] model_params = self.model_types[best_model_type]["params"] # 创建模型 model = model_class(**model_params) # 确定模型目录 model_dir = os.path.dirname(best_model_info["model_paths"]["model"]) # 加载模型 model.load_model(model_dir, cat_name) return model # 示例用法 if __name__ == "__main__": # 创建一些模拟数据 np.random.seed(42) n_samples = 50 n_features = 1024 n_timesteps = 10 # 生成特征序列 features = [] labels = [] for i in range(n_samples): # 生成一个随机特征序列 feature = np.random.randn(n_timesteps, n_features) features.append(feature) # 生成标签 if i < n_samples / 3: labels.append("快乐") elif i < 2 * n_samples / 3: labels.append("愤怒") else: labels.append("饥饿") # 创建比较器 comparator = ModelComparator() # 比较模型 results = comparator.compare_models(features, labels) # 加载最佳模型 best_model = comparator.load_best_model( os.path.join(comparator.results_dir, f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json") ) # 使用最佳模型进行预测 prediction = best_model.predict(features[0]) print(f"最佳模型预测结果: {prediction}")