feat: first commit
This commit is contained in:
460
utils/optimization_manager.py
Normal file
460
utils/optimization_manager.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
优化管理器 - 统一管理所有优化模块的配置和状态
|
||||
|
||||
该模块提供了一个统一的接口来管理和配置所有的优化功能,
|
||||
包括DAG-HMM优化、特征融合优化和HMM参数优化。
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class OptimizationConfig:
|
||||
"""优化配置数据类"""
|
||||
enable_optimizations: bool = True
|
||||
optimization_level: str = "full"
|
||||
|
||||
# DAG-HMM优化配置
|
||||
dag_hmm_enabled: bool = True
|
||||
max_states: int = 10
|
||||
max_gaussians: int = 5
|
||||
cv_folds: int = 3
|
||||
|
||||
# 特征融合优化配置
|
||||
feature_fusion_enabled: bool = True
|
||||
adaptive_learning: bool = True
|
||||
feature_selection: bool = True
|
||||
pca_components: int = 50
|
||||
|
||||
# HMM参数优化配置
|
||||
hmm_optimization_enabled: bool = True
|
||||
optimization_method: str = "grid_search"
|
||||
early_stopping: bool = True
|
||||
|
||||
# 检测器优化配置
|
||||
detector_optimization_enabled: bool = True
|
||||
use_optimized_fusion: bool = True
|
||||
default_model: str = "svm"
|
||||
|
||||
class OptimizationManager:
|
||||
"""
|
||||
优化管理器
|
||||
|
||||
统一管理所有优化模块的配置、状态和性能监控。
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""
|
||||
初始化优化管理器
|
||||
|
||||
参数:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
self.config_path = config_path or self._get_default_config_path()
|
||||
self.config = self._load_config()
|
||||
self.optimization_status = {}
|
||||
self.performance_metrics = {}
|
||||
|
||||
# 设置日志
|
||||
self._setup_logging()
|
||||
|
||||
self.logger.info("优化管理器已初始化")
|
||||
|
||||
def _get_default_config_path(self) -> str:
|
||||
"""获取默认配置文件路径"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
return os.path.join(project_root, "config", "optimization_config.json")
|
||||
|
||||
def _setup_logging(self):
|
||||
"""设置日志"""
|
||||
log_level = self.config.get("logging", {}).get("log_level", "INFO")
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
self.logger = logging.getLogger("OptimizationManager")
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
加载配置文件
|
||||
|
||||
返回:
|
||||
config: 配置字典
|
||||
"""
|
||||
if os.path.exists(self.config_path):
|
||||
try:
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
except Exception as e:
|
||||
print(f"加载配置文件失败: {e}")
|
||||
return self._get_default_config()
|
||||
else:
|
||||
print(f"配置文件不存在: {self.config_path}")
|
||||
return self._get_default_config()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""获取默认配置"""
|
||||
return {
|
||||
"optimization_settings": {
|
||||
"enable_optimizations": True,
|
||||
"optimization_level": "full"
|
||||
},
|
||||
"dag_hmm_optimization": {
|
||||
"enabled": True,
|
||||
"max_states": 10,
|
||||
"max_gaussians": 5,
|
||||
"cv_folds": 3
|
||||
},
|
||||
"feature_fusion_optimization": {
|
||||
"enabled": True,
|
||||
"adaptive_learning": True,
|
||||
"feature_selection": True,
|
||||
"pca_components": 50,
|
||||
"initial_weights": {
|
||||
"temporal_modulation": 0.2,
|
||||
"mfcc": 0.3,
|
||||
"yamnet": 0.5
|
||||
}
|
||||
},
|
||||
"hmm_parameter_optimization": {
|
||||
"enabled": True,
|
||||
"optimization_methods": ["grid_search"],
|
||||
"early_stopping": True
|
||||
},
|
||||
"detector_optimization": {
|
||||
"enabled": True,
|
||||
"use_optimized_fusion": True,
|
||||
"default_model": "svm"
|
||||
}
|
||||
}
|
||||
|
||||
def save_config(self) -> None:
|
||||
"""保存配置文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.config, f, indent=2, ensure_ascii=False)
|
||||
self.logger.info(f"配置已保存到: {self.config_path}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存配置失败: {e}")
|
||||
|
||||
def get_optimization_config(self) -> OptimizationConfig:
|
||||
"""
|
||||
获取优化配置对象
|
||||
|
||||
返回:
|
||||
config: 优化配置对象
|
||||
"""
|
||||
opt_settings = self.config.get("optimization_settings", {})
|
||||
dag_hmm_config = self.config.get("dag_hmm_optimization", {})
|
||||
fusion_config = self.config.get("feature_fusion_optimization", {})
|
||||
hmm_config = self.config.get("hmm_parameter_optimization", {})
|
||||
detector_config = self.config.get("detector_optimization", {})
|
||||
|
||||
return OptimizationConfig(
|
||||
enable_optimizations=opt_settings.get("enable_optimizations", True),
|
||||
optimization_level=opt_settings.get("optimization_level", "full"),
|
||||
|
||||
dag_hmm_enabled=dag_hmm_config.get("enabled", True),
|
||||
max_states=dag_hmm_config.get("max_states", 10),
|
||||
max_gaussians=dag_hmm_config.get("max_gaussians", 5),
|
||||
cv_folds=dag_hmm_config.get("cv_folds", 3),
|
||||
|
||||
feature_fusion_enabled=fusion_config.get("enabled", True),
|
||||
adaptive_learning=fusion_config.get("adaptive_learning", True),
|
||||
feature_selection=fusion_config.get("feature_selection", True),
|
||||
pca_components=fusion_config.get("pca_components", 50),
|
||||
|
||||
hmm_optimization_enabled=hmm_config.get("enabled", True),
|
||||
optimization_method=hmm_config.get("optimization_methods", ["grid_search"])[0],
|
||||
early_stopping=hmm_config.get("early_stopping", True),
|
||||
|
||||
detector_optimization_enabled=detector_config.get("enabled", True),
|
||||
use_optimized_fusion=detector_config.get("use_optimized_fusion", True),
|
||||
default_model=detector_config.get("default_model", "svm")
|
||||
)
|
||||
|
||||
def is_optimization_enabled(self, optimization_type: str) -> bool:
|
||||
"""
|
||||
检查特定优化是否启用
|
||||
|
||||
参数:
|
||||
optimization_type: 优化类型
|
||||
|
||||
返回:
|
||||
enabled: 是否启用
|
||||
"""
|
||||
if not self.config.get("optimization_settings", {}).get("enable_optimizations", True):
|
||||
return False
|
||||
|
||||
type_mapping = {
|
||||
"dag_hmm": "dag_hmm_optimization",
|
||||
"feature_fusion": "feature_fusion_optimization",
|
||||
"hmm_parameter": "hmm_parameter_optimization",
|
||||
"detector": "detector_optimization"
|
||||
}
|
||||
|
||||
config_key = type_mapping.get(optimization_type)
|
||||
if config_key:
|
||||
return self.config.get(config_key, {}).get("enabled", True)
|
||||
|
||||
return False
|
||||
|
||||
def enable_optimization(self, optimization_type: str) -> None:
|
||||
"""
|
||||
启用特定优化
|
||||
|
||||
参数:
|
||||
optimization_type: 优化类型
|
||||
"""
|
||||
type_mapping = {
|
||||
"dag_hmm": "dag_hmm_optimization",
|
||||
"feature_fusion": "feature_fusion_optimization",
|
||||
"hmm_parameter": "hmm_parameter_optimization",
|
||||
"detector": "detector_optimization"
|
||||
}
|
||||
|
||||
config_key = type_mapping.get(optimization_type)
|
||||
if config_key:
|
||||
if config_key not in self.config:
|
||||
self.config[config_key] = {}
|
||||
self.config[config_key]["enabled"] = True
|
||||
self.logger.info(f"已启用 {optimization_type} 优化")
|
||||
|
||||
def disable_optimization(self, optimization_type: str) -> None:
|
||||
"""
|
||||
禁用特定优化
|
||||
|
||||
参数:
|
||||
optimization_type: 优化类型
|
||||
"""
|
||||
type_mapping = {
|
||||
"dag_hmm": "dag_hmm_optimization",
|
||||
"feature_fusion": "feature_fusion_optimization",
|
||||
"hmm_parameter": "hmm_parameter_optimization",
|
||||
"detector": "detector_optimization"
|
||||
}
|
||||
|
||||
config_key = type_mapping.get(optimization_type)
|
||||
if config_key:
|
||||
if config_key not in self.config:
|
||||
self.config[config_key] = {}
|
||||
self.config[config_key]["enabled"] = False
|
||||
self.logger.info(f"已禁用 {optimization_type} 优化")
|
||||
|
||||
def update_optimization_status(self, optimization_type: str, status: Dict[str, Any]) -> None:
|
||||
"""
|
||||
更新优化状态
|
||||
|
||||
参数:
|
||||
optimization_type: 优化类型
|
||||
status: 状态信息
|
||||
"""
|
||||
self.optimization_status[optimization_type] = {
|
||||
**status,
|
||||
"timestamp": self._get_timestamp()
|
||||
}
|
||||
|
||||
if self.config.get("logging", {}).get("log_optimization_process", True):
|
||||
self.logger.info(f"{optimization_type} 优化状态更新: {status}")
|
||||
|
||||
def record_performance_metrics(self, component: str, metrics: Dict[str, Any]) -> None:
|
||||
"""
|
||||
记录性能指标
|
||||
|
||||
参数:
|
||||
component: 组件名称
|
||||
metrics: 性能指标
|
||||
"""
|
||||
if component not in self.performance_metrics:
|
||||
self.performance_metrics[component] = []
|
||||
|
||||
self.performance_metrics[component].append({
|
||||
**metrics,
|
||||
"timestamp": self._get_timestamp()
|
||||
})
|
||||
|
||||
if self.config.get("logging", {}).get("log_performance_metrics", True):
|
||||
self.logger.info(f"{component} 性能指标: {metrics}")
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取性能摘要
|
||||
|
||||
返回:
|
||||
summary: 性能摘要
|
||||
"""
|
||||
summary = {}
|
||||
|
||||
for component, metrics_list in self.performance_metrics.items():
|
||||
if metrics_list:
|
||||
latest_metrics = metrics_list[-1]
|
||||
summary[component] = {
|
||||
"latest_metrics": latest_metrics,
|
||||
"total_records": len(metrics_list)
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
def check_performance_targets(self) -> Dict[str, bool]:
|
||||
"""
|
||||
检查是否达到性能目标
|
||||
|
||||
返回:
|
||||
results: 目标达成情况
|
||||
"""
|
||||
targets = self.config.get("performance_targets", {})
|
||||
results = {}
|
||||
|
||||
# 检查猫叫声检测准确率
|
||||
if "cat_detection_accuracy" in targets:
|
||||
target = targets["cat_detection_accuracy"]
|
||||
current = self._get_latest_metric("detector", "accuracy")
|
||||
results["cat_detection_accuracy"] = current >= target if current is not None else False
|
||||
|
||||
# 检查意图分类准确率
|
||||
if "intent_classification_accuracy" in targets:
|
||||
target = targets["intent_classification_accuracy"]
|
||||
current = self._get_latest_metric("classifier", "accuracy")
|
||||
results["intent_classification_accuracy"] = current >= target if current is not None else False
|
||||
|
||||
return results
|
||||
|
||||
def _get_latest_metric(self, component: str, metric_name: str) -> Optional[float]:
|
||||
"""获取最新的指标值"""
|
||||
if component in self.performance_metrics and self.performance_metrics[component]:
|
||||
latest = self.performance_metrics[component][-1]
|
||||
return latest.get(metric_name)
|
||||
return None
|
||||
|
||||
def _get_timestamp(self) -> str:
|
||||
"""获取当前时间戳"""
|
||||
from datetime import datetime
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_system_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取系统状态
|
||||
|
||||
返回:
|
||||
status: 系统状态
|
||||
"""
|
||||
config = self.get_optimization_config()
|
||||
|
||||
return {
|
||||
"optimization_enabled": config.enable_optimizations,
|
||||
"optimization_level": config.optimization_level,
|
||||
"optimizations": {
|
||||
"dag_hmm": config.dag_hmm_enabled,
|
||||
"feature_fusion": config.feature_fusion_enabled,
|
||||
"hmm_parameter": config.hmm_optimization_enabled,
|
||||
"detector": config.detector_optimization_enabled
|
||||
},
|
||||
"optimization_status": self.optimization_status,
|
||||
"performance_summary": self.get_performance_summary(),
|
||||
"performance_targets": self.check_performance_targets()
|
||||
}
|
||||
|
||||
def generate_optimization_report(self) -> Dict[str, Any]:
|
||||
"""
|
||||
生成优化报告
|
||||
|
||||
返回:
|
||||
report: 优化报告
|
||||
"""
|
||||
return {
|
||||
"config": self.config,
|
||||
"system_status": self.get_system_status(),
|
||||
"performance_metrics": self.performance_metrics,
|
||||
"optimization_status": self.optimization_status,
|
||||
"timestamp": self._get_timestamp()
|
||||
}
|
||||
|
||||
def export_report(self, output_path: str) -> None:
|
||||
"""
|
||||
导出优化报告
|
||||
|
||||
参数:
|
||||
output_path: 输出路径
|
||||
"""
|
||||
report = self.generate_optimization_report()
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
self.logger.info(f"优化报告已导出到: {output_path}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"导出报告失败: {e}")
|
||||
|
||||
|
||||
# 全局优化管理器实例
|
||||
_optimization_manager = None
|
||||
|
||||
def get_optimization_manager(config_path: Optional[str] = None) -> OptimizationManager:
|
||||
"""
|
||||
获取全局优化管理器实例
|
||||
|
||||
参数:
|
||||
config_path: 配置文件路径
|
||||
|
||||
返回:
|
||||
manager: 优化管理器实例
|
||||
"""
|
||||
global _optimization_manager
|
||||
|
||||
if _optimization_manager is None:
|
||||
_optimization_manager = OptimizationManager(config_path)
|
||||
|
||||
return _optimization_manager
|
||||
|
||||
def reset_optimization_manager():
|
||||
"""重置全局优化管理器实例"""
|
||||
global _optimization_manager
|
||||
_optimization_manager = None
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
# 创建优化管理器
|
||||
manager = OptimizationManager()
|
||||
|
||||
# 获取配置
|
||||
config = manager.get_optimization_config()
|
||||
print("优化配置:", config)
|
||||
|
||||
# 检查优化状态
|
||||
print("DAG-HMM优化启用:", manager.is_optimization_enabled("dag_hmm"))
|
||||
print("特征融合优化启用:", manager.is_optimization_enabled("feature_fusion"))
|
||||
|
||||
# 记录性能指标
|
||||
manager.record_performance_metrics("detector", {
|
||||
"accuracy": 0.95,
|
||||
"precision": 0.93,
|
||||
"recall": 0.97
|
||||
})
|
||||
|
||||
manager.record_performance_metrics("classifier", {
|
||||
"accuracy": 0.92,
|
||||
"f1": 0.91
|
||||
})
|
||||
|
||||
# 获取系统状态
|
||||
status = manager.get_system_status()
|
||||
print("\\n系统状态:", status)
|
||||
|
||||
# 检查性能目标
|
||||
targets = manager.check_performance_targets()
|
||||
print("\\n性能目标达成情况:", targets)
|
||||
|
||||
# 生成报告
|
||||
report = manager.generate_optimization_report()
|
||||
print("\\n优化报告生成完成,包含", len(report), "个部分")
|
||||
|
||||
Reference in New Issue
Block a user