feat: first commit
src/model_comparator.py (normal file, 366 lines added)
@@ -0,0 +1,366 @@
"""
Model comparator module - compares the performance of different cat meow intent classification models.

This module provides functionality to compare classification approaches such as
DAG-HMM, deep learning, SVM, and random forest, helping users choose the model
best suited to their dataset.
"""

import os
import numpy as np
import json
import matplotlib.pyplot as plt
from typing import Dict, Any, List, Optional, Tuple
import time
from datetime import datetime

from src.cat_intent_classifier_v2 import CatIntentClassifier
from src.dag_hmm_classifier import DAGHMMClassifier
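

# Note: the two classifier classes are used below through an assumed common
# interface -- train(features, labels) and evaluate(features, labels) returning a
# metrics dict, save_model(dir, cat_name) returning a dict of file paths (including
# a "model" key), load_model(dir, cat_name), and predict(feature). This mirrors how
# the methods are called in this module; adjust if the actual classifier APIs differ.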
class ModelComparator:
    """Model comparator for benchmarking different cat meow intent classification models."""

    def __init__(self, results_dir: str = "./comparison_results"):
        """
        Initialize the model comparator.

        Args:
            results_dir: Directory in which comparison results are saved.
        """
        self.results_dir = results_dir
        os.makedirs(results_dir, exist_ok=True)

        # Supported model types
        self.model_types = {
            "dag_hmm": {
                "name": "DAG-HMM",
                "class": DAGHMMClassifier,
                "params": {"n_states": 5, "n_mix": 3}
            },
            "dl": {
                "name": "Deep Learning",
                "class": CatIntentClassifier,
                "params": {}
            }
        }

    def compare_models(self, features: List[np.ndarray], labels: List[str],
                       model_types: Optional[List[str]] = None, test_size: float = 0.2,
                       cat_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Compare the performance of different models.

        Args:
            features: List of feature sequences.
            labels: List of labels.
            model_types: Model types to compare; defaults to all supported models.
            test_size: Fraction of the data held out for testing.
            cat_name: Cat name; defaults to None (generic model).

        Returns:
            results: Comparison results.
        """
        if model_types is None:
            model_types = list(self.model_types.keys())

        # Validate the requested model types
        for model_type in model_types:
            if model_type not in self.model_types:
                raise ValueError(f"Unsupported model type: {model_type}")

        # Split into training and test sets (the held-out test data is never used
        # for training, so the evaluation reflects generalization)
        from sklearn.model_selection import train_test_split
        train_features, test_features, train_labels, test_labels = train_test_split(
            features, labels, test_size=test_size, random_state=42, stratify=labels
        )
        print(f"Training set size: {len(train_features)}, test set size: {len(test_features)}")
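
        # Note: with stratify=labels, scikit-learn requires every class to have at
        # least two samples; on very small datasets the split raises a ValueError,
        # so collect at least a few recordings per intent class first.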

        # Comparison results
        results = {
            "models": {},
            "best_model": None,
            "comparison_time": datetime.now().isoformat(),
            "dataset_info": {
                "total_samples": len(features),
                "train_samples": len(train_features),
                "test_samples": len(test_features),
                "classes": sorted(set(labels)),
                "class_distribution": {label: labels.count(label) for label in set(labels)}
            }
        }

        # Train and evaluate each model
        for model_type in model_types:
            model_info = self.model_types[model_type]
            model_name = model_info["name"]
            model_class = model_info["class"]
            model_params = model_info["params"]

            print(f"\nTraining and evaluating the {model_name} model...")

            try:
                # Create the model
                model = model_class(**model_params)

                # Train the model and record how long training takes
                train_start_time = time.time()
                train_metrics = model.train(train_features, train_labels)
                train_time = time.time() - train_start_time

                # Evaluate the model and record how long evaluation takes
                eval_start_time = time.time()
                eval_metrics = model.evaluate(test_features, test_labels)
                eval_time = time.time() - eval_start_time

                # Save the model
                model_dir = os.path.join(self.results_dir, "models")
                os.makedirs(model_dir, exist_ok=True)
                model_paths = model.save_model(model_dir, cat_name)

                # Record the results
                results["models"][model_type] = {
                    "name": model_name,
                    "train_metrics": train_metrics,
                    "eval_metrics": eval_metrics,
                    "train_time": train_time,
                    "eval_time": eval_time,
                    "model_paths": model_paths
                }

                print(f"{model_name} training finished, evaluation metrics: {eval_metrics}")

            except Exception as e:
                print(f"Training or evaluation of the {model_name} model failed: {e}")
                results["models"][model_type] = {
                    "name": model_name,
                    "error": str(e)
                }

        # Determine the best model by test-set accuracy
        best_model = None
        best_accuracy = -1

        for model_type, model_result in results["models"].items():
            if "eval_metrics" in model_result and "accuracy" in model_result["eval_metrics"]:
                accuracy = model_result["eval_metrics"]["accuracy"]
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model = model_type

        results["best_model"] = best_model

        # Save the comparison results
        result_path = os.path.join(
            self.results_dir,
            f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        )

        with open(result_path, 'w') as f:
            # Convert numpy values to native Python types during serialization;
            # json.dump calls this only for objects it cannot serialize itself,
            # so numpy values nested inside the metrics dicts are covered too.
            def convert_numpy(obj):
                if isinstance(obj, np.integer):
                    return int(obj)
                elif isinstance(obj, np.floating):
                    return float(obj)
                elif isinstance(obj, np.ndarray):
                    return obj.tolist()
                raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

            json.dump(results, f, indent=2, default=convert_numpy)

        print(f"\nComparison results saved to: {result_path}")

        # Visualize the comparison results
        self.visualize_comparison(results)

        return results
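
    # For reference, the JSON written by compare_models() has roughly this shape
    # (illustrative sketch only; the exact metric keys depend on what each
    # classifier's train()/evaluate() returns, and the values here are made up):
    #
    # {
    #   "models": {
    #     "dag_hmm": {"name": "DAG-HMM", "train_metrics": {...}, "eval_metrics": {...},
    #                 "train_time": 12.3, "eval_time": 0.4, "model_paths": {...}},
    #     "dl": {"name": "Deep Learning", "error": "..."}   # present only on failure
    #   },
    #   "best_model": "dag_hmm",
    #   "comparison_time": "2024-01-01T12:00:00",
    #   "dataset_info": {"total_samples": 50, "train_samples": 40, "test_samples": 10,
    #                    "classes": [...], "class_distribution": {...}}
    # }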

    def visualize_comparison(self, results: Dict[str, Any]) -> str:
        """
        Visualize the comparison results.

        Args:
            results: Comparison results.

        Returns:
            plot_path: Path of the saved figure.
        """
        # Prepare the data
        model_names = []
        accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        train_times = []
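
        # Each classifier's evaluate() is assumed to expose "accuracy", "precision",
        # "recall" and "f1" entries; metrics.get(..., 0) below means any missing key
        # is simply plotted as zero instead of raising a KeyError.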
        for model_type, model_result in results["models"].items():
            if "eval_metrics" in model_result:
                model_names.append(model_result["name"])

                metrics = model_result["eval_metrics"]
                accuracies.append(metrics.get("accuracy", 0))
                precisions.append(metrics.get("precision", 0))
                recalls.append(metrics.get("recall", 0))
                f1_scores.append(metrics.get("f1", 0))

                train_times.append(model_result.get("train_time", 0))

        # Create the figure
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

        # Performance metrics plot
        x = np.arange(len(model_names))
        width = 0.2

        ax1.bar(x - width * 1.5, accuracies, width, label='Accuracy')
        ax1.bar(x - width / 2, precisions, width, label='Precision')
        ax1.bar(x + width / 2, recalls, width, label='Recall')
        ax1.bar(x + width * 1.5, f1_scores, width, label='F1 score')

        ax1.set_ylabel('Score')
        ax1.set_title('Model performance comparison')
        ax1.set_xticks(x)
        ax1.set_xticklabels(model_names)
        ax1.legend()
        ax1.set_ylim(0, 1.1)

        # Add a value label to each bar
        for i, v in enumerate(accuracies):
            ax1.text(i - width * 1.5, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)
        for i, v in enumerate(precisions):
            ax1.text(i - width / 2, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)
        for i, v in enumerate(recalls):
            ax1.text(i + width / 2, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)
        for i, v in enumerate(f1_scores):
            ax1.text(i + width * 1.5, v + 0.02, f'{v:.2f}', ha='center', va='bottom', fontsize=8)

        # Training time plot
        ax2.bar(model_names, train_times, color='skyblue')
        ax2.set_ylabel('Time (seconds)')
        ax2.set_title('Model training time comparison')

        # Add a value label to each bar
        for i, v in enumerate(train_times):
            ax2.text(i, v + 0.1, f'{v:.1f}s', ha='center', va='bottom')

        # Highlight the best model
        best_model = results.get("best_model")
        if best_model and best_model in results["models"]:
            best_model_name = results["models"][best_model]["name"]
            best_index = model_names.index(best_model_name)

            ax1.get_xticklabels()[best_index].set_color('red')
            ax1.get_xticklabels()[best_index].set_weight('bold')

            ax2.get_xticklabels()[best_index].set_color('red')
            ax2.get_xticklabels()[best_index].set_weight('bold')

        # Add an overall title
        plt.suptitle('Cat meow intent classification model comparison', fontsize=16)

        # Save the figure
        plot_path = os.path.join(
            self.results_dir,
            f"comparison_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
        )
        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.savefig(plot_path, dpi=300)
        plt.close()

        print(f"Comparison plot saved to: {plot_path}")

        return plot_path

    def load_best_model(self, comparison_result_path: str, cat_name: Optional[str] = None) -> Any:
        """
        Load the best model recorded in a comparison result.

        Args:
            comparison_result_path: Path to the comparison result file.
            cat_name: Cat name; defaults to None (generic model).

        Returns:
            model: The loaded model.
        """
        # Load the comparison results
        with open(comparison_result_path, 'r') as f:
            results = json.load(f)

        # Get the best model type
        best_model_type = results.get("best_model")
        if not best_model_type:
            raise ValueError("The comparison result does not contain a best model")

        # Get the best model's info
        best_model_info = results["models"].get(best_model_type)
        if not best_model_info or "model_paths" not in best_model_info:
            raise ValueError(f"Cannot obtain path information for the best model {best_model_type}")

        # Get the model class and parameters
        model_class = self.model_types[best_model_type]["class"]
        model_params = self.model_types[best_model_type]["params"]

        # Create the model
        model = model_class(**model_params)

        # Determine the model directory (save_model is expected to have stored the
        # main model file under the "model" key)
        model_dir = os.path.dirname(best_model_info["model_paths"]["model"])

        # Load the model
        model.load_model(model_dir, cat_name)

        return model
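
    # Example (hypothetical file name -- the actual name is produced by
    # compare_models() with a timestamp; the comparator must support the same
    # model types that were used when the comparison was run):
    #   comparator = ModelComparator()
    #   model = comparator.load_best_model(
    #       "./comparison_results/comparison_20240101_120000.json")
    #   prediction = model.predict(feature_sequence)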


# Example usage
if __name__ == "__main__":
    # Create some synthetic data
    np.random.seed(42)
    n_samples = 50
    n_features = 1024
    n_timesteps = 10

    # Generate feature sequences
    features = []
    labels = []

    for i in range(n_samples):
        # Generate a random feature sequence
        feature = np.random.randn(n_timesteps, n_features)
        features.append(feature)

        # Generate a label
        if i < n_samples / 3:
            labels.append("happy")
        elif i < 2 * n_samples / 3:
            labels.append("angry")
        else:
            labels.append("hungry")

    # Create the comparator
    comparator = ModelComparator()

    # Compare the models
    results = comparator.compare_models(features, labels)

    # Load the best model from the most recent comparison result file
    # (re-formatting datetime.now() here could yield a different timestamp than
    # the one used when the file was saved, so pick the newest file instead)
    import glob
    result_files = sorted(glob.glob(os.path.join(comparator.results_dir, "comparison_*.json")))
    best_model = comparator.load_best_model(result_files[-1])

    # Predict with the best model
    prediction = best_model.predict(features[0])
    print(f"Best model prediction: {prediction}")