Files
petshy/utils/optimization_manager.py
2025-10-08 20:39:09 +08:00

461 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
优化管理器 - 统一管理所有优化模块的配置和状态
该模块提供了一个统一的接口来管理和配置所有的优化功能,
包括DAG-HMM优化、特征融合优化和HMM参数优化。
"""
import os
import json
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
@dataclass
class OptimizationConfig:
    """Flat, typed snapshot of the optimization settings.

    Each field carries the default that is used when the corresponding
    setting is not supplied. Fields are grouped by the optimization
    module they belong to.
    """

    # --- Global switches ---
    enable_optimizations: bool = True
    optimization_level: str = "full"

    # --- DAG-HMM optimization ---
    dag_hmm_enabled: bool = True
    max_states: int = 10
    max_gaussians: int = 5
    cv_folds: int = 3

    # --- Feature-fusion optimization ---
    feature_fusion_enabled: bool = True
    adaptive_learning: bool = True
    feature_selection: bool = True
    pca_components: int = 50

    # --- HMM parameter optimization ---
    hmm_optimization_enabled: bool = True
    optimization_method: str = "grid_search"
    early_stopping: bool = True

    # --- Detector optimization ---
    detector_optimization_enabled: bool = True
    use_optimized_fusion: bool = True
    default_model: str = "svm"
class OptimizationManager:
    """
    Central manager for optimization configuration, status and metrics.

    The manager loads a JSON configuration (falling back to built-in
    defaults), exposes helpers to query and toggle individual optimization
    modules, keeps the latest status per optimization type, accumulates
    performance metrics per component, and can export a combined report.

    Attributes:
        config_path: Path of the JSON configuration file.
        config: The loaded (or default) configuration dictionary.
        optimization_status: Latest status dict per optimization type.
        performance_metrics: List of recorded metric dicts per component.
        logger: Logger named "OptimizationManager".
    """

    # Maps the short optimization type names accepted by the public methods
    # to the top-level section keys used inside the configuration dict.
    _TYPE_TO_CONFIG_KEY = {
        "dag_hmm": "dag_hmm_optimization",
        "feature_fusion": "feature_fusion_optimization",
        "hmm_parameter": "hmm_parameter_optimization",
        "detector": "detector_optimization",
    }

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the manager and load its configuration.

        Args:
            config_path: Path of the configuration file. When omitted (or
                empty) the default path from _get_default_config_path()
                is used.
        """
        self.config_path = config_path or self._get_default_config_path()
        self.config = self._load_config()
        self.optimization_status = {}
        self.performance_metrics = {}
        # Logging is configured after loading because the log level is
        # read from the configuration itself.
        self._setup_logging()
        self.logger.info("优化管理器已初始化")

    def _get_default_config_path(self) -> str:
        """Return <parent of this file's directory>/config/optimization_config.json."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(current_dir)
        return os.path.join(project_root, "config", "optimization_config.json")

    def _setup_logging(self):
        """
        Configure logging and create self.logger.

        The level comes from config["logging"]["log_level"]. Level names are
        matched case-insensitively; an integer level is used as-is. Any
        unrecognised value falls back to INFO instead of raising.
        """
        configured = self.config.get("logging", {}).get("log_level", "INFO")
        if isinstance(configured, int) and not isinstance(configured, bool):
            level = configured
        else:
            level = getattr(logging, str(configured).upper(), None)
            # getattr may find a non-level attribute (or nothing at all);
            # only genuine integer level constants are accepted.
            if not isinstance(level, int):
                level = logging.INFO
        logging.basicConfig(
            level=level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger("OptimizationManager")

    def _load_config(self) -> Dict[str, Any]:
        """
        Load the configuration file.

        Returns:
            The parsed configuration dictionary, or the default
            configuration when the file is missing, unreadable, not valid
            JSON, or does not contain a JSON object at its root.
        """
        # print() is used here because self.logger does not exist yet:
        # logging is configured from the loaded configuration.
        if not os.path.exists(self.config_path):
            print(f"配置文件不存在: {self.config_path}")
            return self._get_default_config()
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            if not isinstance(config, dict):
                # Every consumer calls config.get(...), so a list/scalar
                # root would only fail later with a confusing error.
                raise ValueError("top-level JSON value must be an object")
            return config
        except Exception as e:
            # Deliberately best-effort: any failure falls back to defaults.
            print(f"加载配置文件失败: {e}")
            return self._get_default_config()

    def _get_default_config(self) -> Dict[str, Any]:
        """Return a freshly built default configuration dictionary."""
        return {
            "optimization_settings": {
                "enable_optimizations": True,
                "optimization_level": "full"
            },
            "dag_hmm_optimization": {
                "enabled": True,
                "max_states": 10,
                "max_gaussians": 5,
                "cv_folds": 3
            },
            "feature_fusion_optimization": {
                "enabled": True,
                "adaptive_learning": True,
                "feature_selection": True,
                "pca_components": 50,
                "initial_weights": {
                    "temporal_modulation": 0.2,
                    "mfcc": 0.3,
                    "yamnet": 0.5
                }
            },
            "hmm_parameter_optimization": {
                "enabled": True,
                "optimization_methods": ["grid_search"],
                "early_stopping": True
            },
            "detector_optimization": {
                "enabled": True,
                "use_optimized_fusion": True,
                "default_model": "svm"
            }
        }

    @staticmethod
    def _write_json(path: str, payload: Any) -> None:
        """Write payload as indented UTF-8 JSON to path, creating parent dirs."""
        parent_dir = os.path.dirname(path)
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

    def save_config(self) -> None:
        """Write self.config to self.config_path; failures are logged, not raised."""
        try:
            self._write_json(self.config_path, self.config)
            self.logger.info(f"配置已保存到: {self.config_path}")
        except Exception as e:
            self.logger.error(f"保存配置失败: {e}")

    def get_optimization_config(self) -> "OptimizationConfig":
        """
        Build an OptimizationConfig object from the configuration dict.

        Missing sections or keys fall back to the same defaults used by
        the dataclass. Only the first entry of "optimization_methods" is
        used; an empty or missing list yields "grid_search", and a plain
        string value is used as the method name directly.

        Returns:
            The populated OptimizationConfig instance.
        """
        opt_settings = self.config.get("optimization_settings", {})
        dag_hmm_config = self.config.get("dag_hmm_optimization", {})
        fusion_config = self.config.get("feature_fusion_optimization", {})
        hmm_config = self.config.get("hmm_parameter_optimization", {})
        detector_config = self.config.get("detector_optimization", {})

        methods = hmm_config.get("optimization_methods", ["grid_search"])
        if isinstance(methods, str):
            # Indexing a string would silently return its first character.
            optimization_method = methods
        elif methods:
            optimization_method = methods[0]
        else:
            optimization_method = "grid_search"

        return OptimizationConfig(
            enable_optimizations=opt_settings.get("enable_optimizations", True),
            optimization_level=opt_settings.get("optimization_level", "full"),
            dag_hmm_enabled=dag_hmm_config.get("enabled", True),
            max_states=dag_hmm_config.get("max_states", 10),
            max_gaussians=dag_hmm_config.get("max_gaussians", 5),
            cv_folds=dag_hmm_config.get("cv_folds", 3),
            feature_fusion_enabled=fusion_config.get("enabled", True),
            adaptive_learning=fusion_config.get("adaptive_learning", True),
            feature_selection=fusion_config.get("feature_selection", True),
            pca_components=fusion_config.get("pca_components", 50),
            hmm_optimization_enabled=hmm_config.get("enabled", True),
            optimization_method=optimization_method,
            early_stopping=hmm_config.get("early_stopping", True),
            detector_optimization_enabled=detector_config.get("enabled", True),
            use_optimized_fusion=detector_config.get("use_optimized_fusion", True),
            default_model=detector_config.get("default_model", "svm")
        )

    def is_optimization_enabled(self, optimization_type: str) -> bool:
        """
        Check whether a specific optimization is enabled.

        Args:
            optimization_type: One of "dag_hmm", "feature_fusion",
                "hmm_parameter" or "detector".

        Returns:
            False when optimizations are globally disabled or the type is
            unknown; otherwise the section's "enabled" flag (True when the
            flag is absent).
        """
        if not self.config.get("optimization_settings", {}).get("enable_optimizations", True):
            return False
        config_key = self._TYPE_TO_CONFIG_KEY.get(optimization_type)
        if config_key:
            return self.config.get(config_key, {}).get("enabled", True)
        return False

    def _set_optimization_enabled(self, optimization_type: str, enabled: bool) -> bool:
        """Set the "enabled" flag for a known type; return False if the type is unknown."""
        config_key = self._TYPE_TO_CONFIG_KEY.get(optimization_type)
        if not config_key:
            self.logger.warning(f"未知的优化类型: {optimization_type}")
            return False
        self.config.setdefault(config_key, {})["enabled"] = enabled
        return True

    def enable_optimization(self, optimization_type: str) -> None:
        """
        Enable a specific optimization in the in-memory configuration.

        Unknown types leave the configuration untouched (a warning is logged).

        Args:
            optimization_type: One of "dag_hmm", "feature_fusion",
                "hmm_parameter" or "detector".
        """
        if self._set_optimization_enabled(optimization_type, True):
            self.logger.info(f"已启用 {optimization_type} 优化")

    def disable_optimization(self, optimization_type: str) -> None:
        """
        Disable a specific optimization in the in-memory configuration.

        Unknown types leave the configuration untouched (a warning is logged).

        Args:
            optimization_type: One of "dag_hmm", "feature_fusion",
                "hmm_parameter" or "detector".
        """
        if self._set_optimization_enabled(optimization_type, False):
            self.logger.info(f"已禁用 {optimization_type} 优化")

    def update_optimization_status(self, optimization_type: str, status: Dict[str, Any]) -> None:
        """
        Replace the stored status for an optimization type.

        The stored entry is a copy of status with a "timestamp" key added
        (overriding any "timestamp" already present in status).

        Args:
            optimization_type: Key under which the status is stored.
            status: Status information to store.
        """
        self.optimization_status[optimization_type] = {
            **status,
            "timestamp": self._get_timestamp()
        }
        if self.config.get("logging", {}).get("log_optimization_process", True):
            self.logger.info(f"{optimization_type} 优化状态更新: {status}")

    def record_performance_metrics(self, component: str, metrics: Dict[str, Any]) -> None:
        """
        Append a timestamped metrics record for a component.

        Args:
            component: Component name the metrics belong to.
            metrics: Metric names mapped to their values.
        """
        self.performance_metrics.setdefault(component, []).append({
            **metrics,
            "timestamp": self._get_timestamp()
        })
        if self.config.get("logging", {}).get("log_performance_metrics", True):
            self.logger.info(f"{component} 性能指标: {metrics}")

    def get_performance_summary(self) -> Dict[str, Any]:
        """
        Summarise recorded metrics per component.

        Returns:
            For every component with at least one record, a dict with the
            most recent record ("latest_metrics") and the number of records
            ("total_records").
        """
        return {
            component: {
                "latest_metrics": records[-1],
                "total_records": len(records)
            }
            for component, records in self.performance_metrics.items()
            if records
        }

    def check_performance_targets(self) -> Dict[str, bool]:
        """
        Compare the latest metrics against config["performance_targets"].

        Only targets present in the configuration are evaluated. A target
        with no recorded metric counts as not reached.

        Returns:
            Mapping of target name to whether it has been reached.
        """
        targets = self.config.get("performance_targets", {})
        # (target key in config, component providing the metric, metric name)
        checks = (
            ("cat_detection_accuracy", "detector", "accuracy"),
            ("intent_classification_accuracy", "classifier", "accuracy"),
        )
        results = {}
        for target_key, component, metric_name in checks:
            if target_key not in targets:
                continue
            current = self._get_latest_metric(component, metric_name)
            results[target_key] = current is not None and current >= targets[target_key]
        return results

    def _get_latest_metric(self, component: str, metric_name: str) -> Optional[float]:
        """Return metric_name from the component's newest record, or None if absent."""
        records = self.performance_metrics.get(component)
        if not records:
            return None
        return records[-1].get(metric_name)

    def _get_timestamp(self) -> str:
        """Return the current local time as an ISO-8601 string."""
        # Imported locally so the class does not rely on a module-level import.
        from datetime import datetime
        return datetime.now().isoformat()

    def get_system_status(self) -> Dict[str, Any]:
        """
        Collect the current overall state.

        Returns:
            A dict with the global enable flag and level, the per-module
            enable flags, the stored optimization status, the performance
            summary and the performance-target results.
        """
        config = self.get_optimization_config()
        return {
            "optimization_enabled": config.enable_optimizations,
            "optimization_level": config.optimization_level,
            "optimizations": {
                "dag_hmm": config.dag_hmm_enabled,
                "feature_fusion": config.feature_fusion_enabled,
                "hmm_parameter": config.hmm_optimization_enabled,
                "detector": config.detector_optimization_enabled
            },
            "optimization_status": self.optimization_status,
            "performance_summary": self.get_performance_summary(),
            "performance_targets": self.check_performance_targets()
        }

    def generate_optimization_report(self) -> Dict[str, Any]:
        """
        Build the full optimization report.

        Returns:
            A dict containing the configuration, system status, all
            recorded performance metrics, the optimization status and a
            generation timestamp. The config, metrics and status entries
            are the manager's own objects, not copies.
        """
        return {
            "config": self.config,
            "system_status": self.get_system_status(),
            "performance_metrics": self.performance_metrics,
            "optimization_status": self.optimization_status,
            "timestamp": self._get_timestamp()
        }

    def export_report(self, output_path: str) -> None:
        """
        Write the optimization report to a JSON file.

        Write failures are logged rather than raised. A path without a
        directory component is written into the current working directory.

        Args:
            output_path: Destination file path.
        """
        report = self.generate_optimization_report()
        try:
            self._write_json(output_path, report)
            self.logger.info(f"优化报告已导出到: {output_path}")
        except Exception as e:
            self.logger.error(f"导出报告失败: {e}")
# Module-wide shared manager instance; created lazily on first request.
_optimization_manager = None


def get_optimization_manager(config_path: Optional[str] = None) -> OptimizationManager:
    """
    Return the shared OptimizationManager, creating it on first use.

    Args:
        config_path: Configuration file path passed to OptimizationManager
            when the shared instance is created. It is not consulted when
            an instance already exists.

    Returns:
        The shared OptimizationManager instance.
    """
    global _optimization_manager
    if _optimization_manager is not None:
        return _optimization_manager
    _optimization_manager = OptimizationManager(config_path)
    return _optimization_manager
def reset_optimization_manager():
    """Discard the shared manager by setting the module-level instance to None."""
    global _optimization_manager
    _optimization_manager = None
# Manual smoke test / usage demo, executed only when run as a script.
if __name__ == "__main__":
    # Create a manager using the default configuration path.
    manager = OptimizationManager()

    # Read the typed configuration object.
    config = manager.get_optimization_config()
    print("优化配置:", config)

    # Query whether individual optimizations are enabled.
    print("DAG-HMM优化启用:", manager.is_optimization_enabled("dag_hmm"))
    print("特征融合优化启用:", manager.is_optimization_enabled("feature_fusion"))

    # Record sample performance metrics for two components.
    manager.record_performance_metrics("detector", {
        "accuracy": 0.95,
        "precision": 0.93,
        "recall": 0.97
    })
    manager.record_performance_metrics("classifier", {
        "accuracy": 0.92,
        "f1": 0.91
    })

    # Show the aggregated system status.
    # The messages below start with a real newline ("\n"); the previous
    # "\\n" printed a literal backslash followed by "n".
    status = manager.get_system_status()
    print("\n系统状态:", status)

    # Check the configured performance targets.
    targets = manager.check_performance_targets()
    print("\n性能目标达成情况:", targets)

    # Build the full report and show how many top-level sections it has.
    report = manager.generate_optimization_report()
    print("\n优化报告生成完成包含", len(report), "个部分")