"""
|
||
模型比较器模块 - 用于比较不同猫叫声意图分类模型的性能
|
||
|
||
该模块提供了比较DAG-HMM、深度学习、SVM和随机森林等不同分类方法的功能,
|
||
帮助用户选择最适合其数据集的模型。
|
||
"""
|
||
|
||
import os
|
||
import numpy as np
|
||
import json
|
||
import matplotlib.pyplot as plt
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
import time
|
||
from datetime import datetime
|
||
|
||
from src.cat_intent_classifier_v2 import CatIntentClassifier
|
||
from src.dag_hmm_classifier import DAGHMMClassifier
|
||
|
||
class ModelComparator:
    """Model comparator for cat-meow intent classification models.

    Trains every candidate model on the same stratified train/test
    split, records evaluation metrics and wall-clock timings, persists
    the comparison as JSON, and renders a bar-chart summary.
    """

    def __init__(self, results_dir: str = "./comparison_results"):
        """
        Initialize the comparator.

        Args:
            results_dir: directory where comparison results, saved models
                and plots are written (created if missing).
        """
        self.results_dir = results_dir
        os.makedirs(results_dir, exist_ok=True)

        # Registry of supported model types:
        # key -> {display name, classifier class, constructor kwargs}.
        self.model_types = {
            "dag_hmm": {
                "name": "DAG-HMM",
                "class": DAGHMMClassifier,
                "params": {"n_states": 5, "n_mix": 3}
            },
            "dl": {
                "name": "深度学习",
                "class": CatIntentClassifier,
                "params": {}
            }
        }

    def compare_models(self, features: List[np.ndarray], labels: List[str],
                       model_types: Optional[List[str]] = None, test_size: float = 0.2,
                       cat_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Train and evaluate every requested model on the same data split.

        Args:
            features: list of feature sequences (one array per sample).
            labels: class label for each sample, parallel to ``features``.
            model_types: keys into ``self.model_types`` to compare;
                defaults to all registered models.
            test_size: fraction of samples held out for evaluation.
            cat_name: optional cat identifier forwarded to ``save_model``
                (None means a generic, cat-agnostic model).

        Returns:
            Comparison dict with per-model metrics/timings/paths, the
            best model key, and dataset statistics.

        Raises:
            ValueError: if an unknown model type is requested.
        """
        if model_types is None:
            model_types = list(self.model_types.keys())

        # Validate requested model types up front, before any training.
        for model_type in model_types:
            if model_type not in self.model_types:
                raise ValueError(f"不支持的模型类型: {model_type}")

        # Stratified train/test split.
        # BUG FIX: the original discarded the train side of the split and
        # trained on the FULL dataset (train_features = features), so every
        # test sample was also seen during training — data leakage that
        # inflated the reported evaluation metrics.
        from sklearn.model_selection import train_test_split
        train_features, test_features, train_labels, test_labels = train_test_split(
            features, labels, test_size=test_size, random_state=42, stratify=labels
        )
        print(f"训练集大小: {len(train_features)}, 测试集大小: {len(test_features)}")

        results = {
            "models": {},
            "best_model": None,
            "comparison_time": datetime.now().isoformat(),
            "dataset_info": {
                "total_samples": len(features),
                "train_samples": len(train_features),
                "test_samples": len(test_features),
                "classes": sorted(set(labels)),
                "class_distribution": {label: labels.count(label) for label in set(labels)}
            }
        }

        # Train and evaluate each model; a failure is recorded, not fatal,
        # so the remaining models are still compared.
        for model_type in model_types:
            model_info = self.model_types[model_type]
            model_name = model_info["name"]
            model_class = model_info["class"]
            model_params = model_info["params"]

            print(f"\n开始训练和评估 {model_name} 模型...")

            try:
                model = model_class(**model_params)

                # Wall-clock training time.
                train_start_time = time.time()
                train_metrics = model.train(train_features, train_labels)
                train_time = time.time() - train_start_time

                # Wall-clock evaluation time on the held-out set.
                eval_start_time = time.time()
                eval_metrics = model.evaluate(test_features, test_labels)
                eval_time = time.time() - eval_start_time

                # Persist the trained model so the best one can be reloaded
                # later via load_best_model().
                model_dir = os.path.join(self.results_dir, "models")
                os.makedirs(model_dir, exist_ok=True)
                model_paths = model.save_model(model_dir, cat_name)

                results["models"][model_type] = {
                    "name": model_name,
                    "train_metrics": train_metrics,
                    "eval_metrics": eval_metrics,
                    "train_time": train_time,
                    "eval_time": eval_time,
                    "model_paths": model_paths
                }

                print(f"{model_name} 模型训练完成,评估指标: {eval_metrics}")

            except Exception as e:
                print(f"{model_name} 模型训练或评估失败: {e}")
                results["models"][model_type] = {
                    "name": model_name,
                    "error": str(e)
                }

        # Pick the model with the highest held-out accuracy.
        best_model = None
        best_accuracy = -1
        for model_type, model_result in results["models"].items():
            if "eval_metrics" in model_result and "accuracy" in model_result["eval_metrics"]:
                accuracy = model_result["eval_metrics"]["accuracy"]
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model = model_type

        results["best_model"] = best_model

        # Persist the comparison as a timestamped JSON file.
        result_path = os.path.join(
            self.results_dir,
            f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        )

        def _to_builtin(obj):
            """json.dump ``default`` hook: convert numpy values to native types."""
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        # BUG FIX: the original converted numpy values only at the TOP level
        # of the results dict ({k: convert(v)}), so any numpy scalar nested
        # inside the metrics dicts crashed json.dump. The ``default`` hook is
        # applied recursively by the encoder. ensure_ascii=False keeps the
        # Chinese labels readable in the output file.
        with open(result_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False, default=_to_builtin)

        print(f"\n比较结果已保存到: {result_path}")

        # Render the comparison chart alongside the JSON.
        self.visualize_comparison(results)

        return results

    def visualize_comparison(self, results: Dict[str, Any]) -> str:
        """
        Render a two-panel bar chart (metrics + training time) for a
        comparison and save it as a PNG.

        Args:
            results: comparison dict produced by ``compare_models``.

        Returns:
            Path of the saved PNG file.
        """
        # Collect metrics for every model that was evaluated successfully
        # (models that errored have no "eval_metrics" key and are skipped).
        model_names = []
        accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        train_times = []

        for model_type, model_result in results["models"].items():
            if "eval_metrics" in model_result:
                model_names.append(model_result["name"])

                metrics = model_result["eval_metrics"]
                accuracies.append(metrics.get("accuracy", 0))
                precisions.append(metrics.get("precision", 0))
                recalls.append(metrics.get("recall", 0))
                f1_scores.append(metrics.get("f1", 0))

                train_times.append(model_result.get("train_time", 0))

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

        # Left panel: grouped bars, one group per model, four metrics each.
        x = np.arange(len(model_names))
        width = 0.2

        ax1.bar(x - width * 1.5, accuracies, width, label='准确率')
        ax1.bar(x - width / 2, precisions, width, label='精确率')
        ax1.bar(x + width / 2, recalls, width, label='召回率')
        ax1.bar(x + width * 1.5, f1_scores, width, label='F1分数')

        ax1.set_ylabel('得分')
        ax1.set_title('模型性能比较')
        ax1.set_xticks(x)
        ax1.set_xticklabels(model_names)
        ax1.legend()
        ax1.set_ylim(0, 1.1)

        # Numeric labels above every bar.
        for i, v in enumerate(accuracies):
            ax1.text(i - width * 1.5, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)
        for i, v in enumerate(precisions):
            ax1.text(i - width / 2, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)
        for i, v in enumerate(recalls):
            ax1.text(i + width / 2, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)
        for i, v in enumerate(f1_scores):
            ax1.text(i + width * 1.5, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)

        # Right panel: training time per model.
        ax2.bar(model_names, train_times, color='skyblue')
        ax2.set_ylabel('时间 (秒)')
        ax2.set_title('模型训练时间比较')

        for i, v in enumerate(train_times):
            ax2.text(i, v + 0.1, f'{v:.1f}s', ha='center', va='bottom')

        # Highlight the best model's tick label in bold red on both panels.
        best_model = results.get("best_model")
        if best_model and best_model in results["models"]:
            best_model_name = results["models"][best_model]["name"]
            best_index = model_names.index(best_model_name)

            ax1.get_xticklabels()[best_index].set_color('red')
            ax1.get_xticklabels()[best_index].set_weight('bold')
            ax2.get_xticklabels()[best_index].set_color('red')
            ax2.get_xticklabels()[best_index].set_weight('bold')

        plt.suptitle('猫叫声意图分类模型比较', fontsize=16)

        # Save as a timestamped PNG next to the JSON results.
        plot_path = os.path.join(
            self.results_dir,
            f"comparison_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
        )
        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.savefig(plot_path, dpi=300)
        plt.close()

        print(f"比较图表已保存到: {plot_path}")

        return plot_path

    def load_best_model(self, comparison_result_path: str, cat_name: Optional[str] = None) -> Any:
        """
        Reload the best-performing model referenced by a saved comparison.

        Args:
            comparison_result_path: path of a comparison JSON file written
                by ``compare_models``.
            cat_name: optional cat identifier forwarded to ``load_model``
                (None means the generic model).

        Returns:
            The loaded model instance.

        Raises:
            ValueError: if the results record no best model, or no saved
                paths for it.
        """
        with open(comparison_result_path, 'r', encoding='utf-8') as f:
            results = json.load(f)

        best_model_type = results.get("best_model")
        if not best_model_type:
            raise ValueError("比较结果中没有最佳模型")

        best_model_info = results["models"].get(best_model_type)
        if not best_model_info or "model_paths" not in best_model_info:
            raise ValueError(f"无法获取最佳模型 {best_model_type} 的路径信息")

        # Rebuild the model with the same constructor params used during
        # the comparison, then restore its trained state from disk.
        model_class = self.model_types[best_model_type]["class"]
        model_params = self.model_types[best_model_type]["params"]
        model = model_class(**model_params)

        model_dir = os.path.dirname(best_model_info["model_paths"]["model"])
        model.load_model(model_dir, cat_name)

        return model
|
||
|
||
|
||
# Example usage: build a synthetic dataset, compare the registered models,
# then reload and exercise the best one.
if __name__ == "__main__":
    np.random.seed(42)
    n_samples = 50
    n_features = 1024
    n_timesteps = 10

    # Generate random feature sequences split into three equal classes.
    features = []
    labels = []
    for i in range(n_samples):
        features.append(np.random.randn(n_timesteps, n_features))
        if i < n_samples / 3:
            labels.append("快乐")
        elif i < 2 * n_samples / 3:
            labels.append("愤怒")
        else:
            labels.append("饥饿")

    comparator = ModelComparator()
    results = comparator.compare_models(features, labels)

    # BUG FIX: the original rebuilt the result filename with a *fresh*
    # datetime.now() timestamp, which almost never matches the file that
    # compare_models() actually wrote moments earlier. Locate the newest
    # comparison file on disk instead — the timestamped names
    # (comparison_YYYYMMDD_HHMMSS.json) sort chronologically.
    candidates = [
        name for name in os.listdir(comparator.results_dir)
        if name.startswith("comparison_") and name.endswith(".json")
    ]
    if not candidates:
        raise FileNotFoundError("没有找到比较结果文件")
    latest_result = max(candidates)

    best_model = comparator.load_best_model(
        os.path.join(comparator.results_dir, latest_result)
    )

    # Run a prediction with the reloaded best model.
    prediction = best_model.predict(features[0])
    print(f"最佳模型预测结果: {prediction}")
|