Files
petshy/src/model_comparator.py
2025-10-08 20:39:09 +08:00

367 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模型比较器模块 - 用于比较不同猫叫声意图分类模型的性能
该模块提供了比较DAG-HMM、深度学习、SVM和随机森林等不同分类方法的功能
帮助用户选择最适合其数据集的模型。
"""
import os
import numpy as np
import json
import matplotlib.pyplot as plt
from typing import Dict, Any, List, Optional, Tuple
import time
from datetime import datetime
from src.cat_intent_classifier_v2 import CatIntentClassifier
from src.dag_hmm_classifier import DAGHMMClassifier
class ModelComparator:
    """Compare the performance of different cat-vocalization intent classifiers.

    Trains and evaluates every configured model type (DAG-HMM, deep
    learning, ...) on the same train/test split, persists the metrics as
    JSON, renders a comparison chart, and can reload the best model.
    """

    def __init__(self, results_dir: str = "./comparison_results"):
        """
        Initialize the model comparator.

        Args:
            results_dir: Directory where comparison results, plots and
                trained models are stored (created if missing).
        """
        self.results_dir = results_dir
        os.makedirs(results_dir, exist_ok=True)
        # Registry of supported model types: display name, implementing
        # class and constructor keyword arguments.
        self.model_types = {
            "dag_hmm": {
                "name": "DAG-HMM",
                "class": DAGHMMClassifier,
                "params": {"n_states": 5, "n_mix": 3}
            },
            "dl": {
                "name": "深度学习",
                "class": CatIntentClassifier,
                "params": {}
            }
        }

    def compare_models(self, features: List[np.ndarray], labels: List[str],
                       model_types: List[str] = None, test_size: float = 0.2,
                       cat_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Train and evaluate each requested model on the same data split.

        Args:
            features: List of feature sequences (one array per sample).
            labels: Class label for each sample.
            model_types: Model-type keys to compare; defaults to all
                supported types.
            test_size: Fraction of samples held out for evaluation.
            cat_name: Cat name for per-cat models; None means a generic model.

        Returns:
            Dict with per-model metrics/timings, dataset statistics and the
            key of the best model (highest test accuracy).

        Raises:
            ValueError: If an unsupported model type is requested.
        """
        if model_types is None:
            model_types = list(self.model_types.keys())
        # Validate requested model types up front.
        for model_type in model_types:
            if model_type not in self.model_types:
                raise ValueError(f"不支持的模型类型: {model_type}")
        # Split into *disjoint* train and test sets.  The original code
        # trained on the full dataset, so every test sample leaked into
        # training and all evaluation metrics were inflated.
        from sklearn.model_selection import train_test_split
        train_features, test_features, train_labels, test_labels = train_test_split(
            features, labels, test_size=test_size, random_state=42, stratify=labels
        )
        print(f"训练集大小: {len(train_features)}, 测试集大小: {len(test_features)}")
        results = {
            "models": {},
            "best_model": None,
            "comparison_time": datetime.now().isoformat(),
            "dataset_info": {
                "total_samples": len(features),
                "train_samples": len(train_features),
                "test_samples": len(test_features),
                "classes": sorted(set(labels)),
                "class_distribution": {label: labels.count(label) for label in set(labels)}
            }
        }
        # Train, evaluate and persist each model.  A failure of one model
        # is recorded but does not abort the whole comparison.
        for model_type in model_types:
            model_info = self.model_types[model_type]
            model_name = model_info["name"]
            print(f"\n开始训练和评估 {model_name} 模型...")
            try:
                model = model_info["class"](**model_info["params"])
                train_start_time = time.time()
                train_metrics = model.train(train_features, train_labels)
                train_time = time.time() - train_start_time
                eval_start_time = time.time()
                eval_metrics = model.evaluate(test_features, test_labels)
                eval_time = time.time() - eval_start_time
                # Persist the trained model so the best one can be reloaded
                # later via load_best_model().
                model_dir = os.path.join(self.results_dir, "models")
                os.makedirs(model_dir, exist_ok=True)
                model_paths = model.save_model(model_dir, cat_name)
                results["models"][model_type] = {
                    "name": model_name,
                    "train_metrics": train_metrics,
                    "eval_metrics": eval_metrics,
                    "train_time": train_time,
                    "eval_time": eval_time,
                    "model_paths": model_paths
                }
                print(f"{model_name} 模型训练完成,评估指标: {eval_metrics}")
            except Exception as e:
                print(f"{model_name} 模型训练或评估失败: {e}")
                results["models"][model_type] = {
                    "name": model_name,
                    "error": str(e)
                }
        # Pick the model with the highest test accuracy.
        best_model = None
        best_accuracy = -1
        for model_type, model_result in results["models"].items():
            accuracy = model_result.get("eval_metrics", {}).get("accuracy")
            if accuracy is not None and accuracy > best_accuracy:
                best_accuracy = accuracy
                best_model = model_type
        results["best_model"] = best_model
        # Save the comparison as UTF-8 JSON (metrics contain Chinese text).
        result_path = os.path.join(
            self.results_dir,
            f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        )
        with open(result_path, 'w', encoding='utf-8') as f:
            # json.dump's ``default`` hook is called for every value json
            # cannot serialize on its own, so nested numpy scalars/arrays
            # are converted too (the original only converted top-level
            # values and would crash on nested numpy metrics).
            def convert_numpy(obj):
                if isinstance(obj, np.integer):
                    return int(obj)
                if isinstance(obj, np.floating):
                    return float(obj)
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                raise TypeError(
                    f"Object of type {type(obj).__name__} is not JSON serializable"
                )
            json.dump(results, f, indent=2, ensure_ascii=False, default=convert_numpy)
        print(f"\n比较结果已保存到: {result_path}")
        self.visualize_comparison(results)
        return results

    def visualize_comparison(self, results: Dict[str, Any]) -> str:
        """
        Render a bar-chart comparison of model metrics and training times.

        Args:
            results: Comparison results produced by compare_models().

        Returns:
            Path of the saved PNG chart.
        """
        # Collect metrics from every model that evaluated successfully
        # (models that recorded an "error" are skipped).
        model_names = []
        accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        train_times = []
        for model_result in results["models"].values():
            if "eval_metrics" in model_result:
                metrics = model_result["eval_metrics"]
                model_names.append(model_result["name"])
                accuracies.append(metrics.get("accuracy", 0))
                precisions.append(metrics.get("precision", 0))
                recalls.append(metrics.get("recall", 0))
                f1_scores.append(metrics.get("f1", 0))
                train_times.append(model_result.get("train_time", 0))
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
        # Grouped bars on ax1: one group per model, one bar per metric,
        # each bar annotated with its value.
        x = np.arange(len(model_names))
        width = 0.2
        metric_series = [
            (-width * 1.5, accuracies, '准确率'),
            (-width / 2, precisions, '精确率'),
            (width / 2, recalls, '召回率'),
            (width * 1.5, f1_scores, 'F1分数'),
        ]
        for offset, values, label in metric_series:
            ax1.bar(x + offset, values, width, label=label)
            for i, v in enumerate(values):
                ax1.text(i + offset, v + 0.02, f'{v:.2f}',
                         ha='center', va='bottom', fontsize=8)
        ax1.set_ylabel('得分')
        ax1.set_title('模型性能比较')
        ax1.set_xticks(x)
        ax1.set_xticklabels(model_names)
        ax1.legend()
        ax1.set_ylim(0, 1.1)
        # Training-time bars on ax2, annotated in seconds.
        ax2.bar(model_names, train_times, color='skyblue')
        ax2.set_ylabel('时间 (秒)')
        ax2.set_title('模型训练时间比较')
        for i, v in enumerate(train_times):
            ax2.text(i, v + 0.1, f'{v:.1f}s', ha='center', va='bottom')
        # Highlight the best model's tick label in bold red on both axes.
        best_model = results.get("best_model")
        if best_model and best_model in results["models"]:
            best_model_name = results["models"][best_model]["name"]
            best_index = model_names.index(best_model_name)
            for ax in (ax1, ax2):
                ax.get_xticklabels()[best_index].set_color('red')
                ax.get_xticklabels()[best_index].set_weight('bold')
        plt.suptitle('猫叫声意图分类模型比较', fontsize=16)
        plot_path = os.path.join(
            self.results_dir,
            f"comparison_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
        )
        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.savefig(plot_path, dpi=300)
        plt.close()
        print(f"比较图表已保存到: {plot_path}")
        return plot_path

    def load_best_model(self, comparison_result_path: str, cat_name: Optional[str] = None) -> Any:
        """
        Load the best model recorded in a saved comparison result.

        Args:
            comparison_result_path: Path to a comparison JSON file written
                by compare_models().
            cat_name: Cat name for per-cat models; None means a generic model.

        Returns:
            The instantiated model with its saved state loaded.

        Raises:
            ValueError: If the file records no best model, or lacks the
                best model's saved path information.
        """
        with open(comparison_result_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
        best_model_type = results.get("best_model")
        if not best_model_type:
            raise ValueError("比较结果中没有最佳模型")
        best_model_info = results["models"].get(best_model_type)
        if not best_model_info or "model_paths" not in best_model_info:
            raise ValueError(f"无法获取最佳模型 {best_model_type} 的路径信息")
        # Recreate the model with the same constructor parameters, then
        # load its weights from the directory the comparison saved to.
        model_info = self.model_types[best_model_type]
        model = model_info["class"](**model_info["params"])
        model_dir = os.path.dirname(best_model_info["model_paths"]["model"])
        model.load_model(model_dir, cat_name)
        return model
# Example usage: smoke-test the comparator on a small synthetic dataset.
if __name__ == "__main__":
    np.random.seed(42)
    n_samples = 50
    n_features = 1024
    n_timesteps = 10
    # Random feature sequences standing in for extracted audio features,
    # split into three equally sized classes.
    features = []
    labels = []
    for i in range(n_samples):
        features.append(np.random.randn(n_timesteps, n_features))
        if i < n_samples / 3:
            labels.append("快乐")
        elif i < 2 * n_samples / 3:
            labels.append("愤怒")
        else:
            labels.append("饥饿")
    comparator = ModelComparator()
    results = comparator.compare_models(features, labels)
    # The original rebuilt the result filename from the *current* time,
    # which almost never matches the timestamp used when compare_models
    # saved the file.  Load the most recently written result instead
    # (timestamped names sort chronologically).
    result_files = sorted(
        name for name in os.listdir(comparator.results_dir)
        if name.startswith("comparison_") and name.endswith(".json")
    )
    best_model = comparator.load_best_model(
        os.path.join(comparator.results_dir, result_files[-1])
    )
    prediction = best_model.predict(features[0])
    print(f"最佳模型预测结果: {prediction}")