461 lines
15 KiB
Python
461 lines
15 KiB
Python
"""
|
||
优化管理器 - 统一管理所有优化模块的配置和状态
|
||
|
||
该模块提供了一个统一的接口来管理和配置所有的优化功能,
|
||
包括DAG-HMM优化、特征融合优化和HMM参数优化。
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
from typing import Dict, Any, Optional, List
|
||
from dataclasses import dataclass
|
||
|
||
@dataclass
class OptimizationConfig:
    """Flat, typed snapshot of the optimization settings.

    Every field has a default, so ``OptimizationConfig()`` yields a fully
    populated configuration. Fields are grouped by the optimization area
    they belong to; the grouping is purely organizational.
    """

    # Global switches.
    enable_optimizations: bool = True
    optimization_level: str = "full"

    # DAG-HMM optimization settings.
    dag_hmm_enabled: bool = True
    max_states: int = 10
    max_gaussians: int = 5
    cv_folds: int = 3

    # Feature-fusion optimization settings.
    feature_fusion_enabled: bool = True
    adaptive_learning: bool = True
    feature_selection: bool = True
    pca_components: int = 50

    # HMM parameter optimization settings.
    hmm_optimization_enabled: bool = True
    optimization_method: str = "grid_search"
    early_stopping: bool = True

    # Detector optimization settings.
    detector_optimization_enabled: bool = True
    use_optimized_fusion: bool = True
    default_model: str = "svm"
|
||
|
||
class OptimizationManager:
|
||
"""
|
||
优化管理器
|
||
|
||
统一管理所有优化模块的配置、状态和性能监控。
|
||
"""
|
||
|
||
def __init__(self, config_path: Optional[str] = None):
|
||
"""
|
||
初始化优化管理器
|
||
|
||
参数:
|
||
config_path: 配置文件路径
|
||
"""
|
||
self.config_path = config_path or self._get_default_config_path()
|
||
self.config = self._load_config()
|
||
self.optimization_status = {}
|
||
self.performance_metrics = {}
|
||
|
||
# 设置日志
|
||
self._setup_logging()
|
||
|
||
self.logger.info("优化管理器已初始化")
|
||
|
||
def _get_default_config_path(self) -> str:
|
||
"""获取默认配置文件路径"""
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
project_root = os.path.dirname(current_dir)
|
||
return os.path.join(project_root, "config", "optimization_config.json")
|
||
|
||
def _setup_logging(self):
|
||
"""设置日志"""
|
||
log_level = self.config.get("logging", {}).get("log_level", "INFO")
|
||
|
||
logging.basicConfig(
|
||
level=getattr(logging, log_level),
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
)
|
||
|
||
self.logger = logging.getLogger("OptimizationManager")
|
||
|
||
def _load_config(self) -> Dict[str, Any]:
|
||
"""
|
||
加载配置文件
|
||
|
||
返回:
|
||
config: 配置字典
|
||
"""
|
||
if os.path.exists(self.config_path):
|
||
try:
|
||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
return config
|
||
except Exception as e:
|
||
print(f"加载配置文件失败: {e}")
|
||
return self._get_default_config()
|
||
else:
|
||
print(f"配置文件不存在: {self.config_path}")
|
||
return self._get_default_config()
|
||
|
||
def _get_default_config(self) -> Dict[str, Any]:
|
||
"""获取默认配置"""
|
||
return {
|
||
"optimization_settings": {
|
||
"enable_optimizations": True,
|
||
"optimization_level": "full"
|
||
},
|
||
"dag_hmm_optimization": {
|
||
"enabled": True,
|
||
"max_states": 10,
|
||
"max_gaussians": 5,
|
||
"cv_folds": 3
|
||
},
|
||
"feature_fusion_optimization": {
|
||
"enabled": True,
|
||
"adaptive_learning": True,
|
||
"feature_selection": True,
|
||
"pca_components": 50,
|
||
"initial_weights": {
|
||
"temporal_modulation": 0.2,
|
||
"mfcc": 0.3,
|
||
"yamnet": 0.5
|
||
}
|
||
},
|
||
"hmm_parameter_optimization": {
|
||
"enabled": True,
|
||
"optimization_methods": ["grid_search"],
|
||
"early_stopping": True
|
||
},
|
||
"detector_optimization": {
|
||
"enabled": True,
|
||
"use_optimized_fusion": True,
|
||
"default_model": "svm"
|
||
}
|
||
}
|
||
|
||
def save_config(self) -> None:
|
||
"""保存配置文件"""
|
||
try:
|
||
os.makedirs(os.path.dirname(self.config_path), exist_ok=True)
|
||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||
json.dump(self.config, f, indent=2, ensure_ascii=False)
|
||
self.logger.info(f"配置已保存到: {self.config_path}")
|
||
except Exception as e:
|
||
self.logger.error(f"保存配置失败: {e}")
|
||
|
||
def get_optimization_config(self) -> OptimizationConfig:
|
||
"""
|
||
获取优化配置对象
|
||
|
||
返回:
|
||
config: 优化配置对象
|
||
"""
|
||
opt_settings = self.config.get("optimization_settings", {})
|
||
dag_hmm_config = self.config.get("dag_hmm_optimization", {})
|
||
fusion_config = self.config.get("feature_fusion_optimization", {})
|
||
hmm_config = self.config.get("hmm_parameter_optimization", {})
|
||
detector_config = self.config.get("detector_optimization", {})
|
||
|
||
return OptimizationConfig(
|
||
enable_optimizations=opt_settings.get("enable_optimizations", True),
|
||
optimization_level=opt_settings.get("optimization_level", "full"),
|
||
|
||
dag_hmm_enabled=dag_hmm_config.get("enabled", True),
|
||
max_states=dag_hmm_config.get("max_states", 10),
|
||
max_gaussians=dag_hmm_config.get("max_gaussians", 5),
|
||
cv_folds=dag_hmm_config.get("cv_folds", 3),
|
||
|
||
feature_fusion_enabled=fusion_config.get("enabled", True),
|
||
adaptive_learning=fusion_config.get("adaptive_learning", True),
|
||
feature_selection=fusion_config.get("feature_selection", True),
|
||
pca_components=fusion_config.get("pca_components", 50),
|
||
|
||
hmm_optimization_enabled=hmm_config.get("enabled", True),
|
||
optimization_method=hmm_config.get("optimization_methods", ["grid_search"])[0],
|
||
early_stopping=hmm_config.get("early_stopping", True),
|
||
|
||
detector_optimization_enabled=detector_config.get("enabled", True),
|
||
use_optimized_fusion=detector_config.get("use_optimized_fusion", True),
|
||
default_model=detector_config.get("default_model", "svm")
|
||
)
|
||
|
||
def is_optimization_enabled(self, optimization_type: str) -> bool:
|
||
"""
|
||
检查特定优化是否启用
|
||
|
||
参数:
|
||
optimization_type: 优化类型
|
||
|
||
返回:
|
||
enabled: 是否启用
|
||
"""
|
||
if not self.config.get("optimization_settings", {}).get("enable_optimizations", True):
|
||
return False
|
||
|
||
type_mapping = {
|
||
"dag_hmm": "dag_hmm_optimization",
|
||
"feature_fusion": "feature_fusion_optimization",
|
||
"hmm_parameter": "hmm_parameter_optimization",
|
||
"detector": "detector_optimization"
|
||
}
|
||
|
||
config_key = type_mapping.get(optimization_type)
|
||
if config_key:
|
||
return self.config.get(config_key, {}).get("enabled", True)
|
||
|
||
return False
|
||
|
||
def enable_optimization(self, optimization_type: str) -> None:
|
||
"""
|
||
启用特定优化
|
||
|
||
参数:
|
||
optimization_type: 优化类型
|
||
"""
|
||
type_mapping = {
|
||
"dag_hmm": "dag_hmm_optimization",
|
||
"feature_fusion": "feature_fusion_optimization",
|
||
"hmm_parameter": "hmm_parameter_optimization",
|
||
"detector": "detector_optimization"
|
||
}
|
||
|
||
config_key = type_mapping.get(optimization_type)
|
||
if config_key:
|
||
if config_key not in self.config:
|
||
self.config[config_key] = {}
|
||
self.config[config_key]["enabled"] = True
|
||
self.logger.info(f"已启用 {optimization_type} 优化")
|
||
|
||
def disable_optimization(self, optimization_type: str) -> None:
|
||
"""
|
||
禁用特定优化
|
||
|
||
参数:
|
||
optimization_type: 优化类型
|
||
"""
|
||
type_mapping = {
|
||
"dag_hmm": "dag_hmm_optimization",
|
||
"feature_fusion": "feature_fusion_optimization",
|
||
"hmm_parameter": "hmm_parameter_optimization",
|
||
"detector": "detector_optimization"
|
||
}
|
||
|
||
config_key = type_mapping.get(optimization_type)
|
||
if config_key:
|
||
if config_key not in self.config:
|
||
self.config[config_key] = {}
|
||
self.config[config_key]["enabled"] = False
|
||
self.logger.info(f"已禁用 {optimization_type} 优化")
|
||
|
||
def update_optimization_status(self, optimization_type: str, status: Dict[str, Any]) -> None:
|
||
"""
|
||
更新优化状态
|
||
|
||
参数:
|
||
optimization_type: 优化类型
|
||
status: 状态信息
|
||
"""
|
||
self.optimization_status[optimization_type] = {
|
||
**status,
|
||
"timestamp": self._get_timestamp()
|
||
}
|
||
|
||
if self.config.get("logging", {}).get("log_optimization_process", True):
|
||
self.logger.info(f"{optimization_type} 优化状态更新: {status}")
|
||
|
||
def record_performance_metrics(self, component: str, metrics: Dict[str, Any]) -> None:
|
||
"""
|
||
记录性能指标
|
||
|
||
参数:
|
||
component: 组件名称
|
||
metrics: 性能指标
|
||
"""
|
||
if component not in self.performance_metrics:
|
||
self.performance_metrics[component] = []
|
||
|
||
self.performance_metrics[component].append({
|
||
**metrics,
|
||
"timestamp": self._get_timestamp()
|
||
})
|
||
|
||
if self.config.get("logging", {}).get("log_performance_metrics", True):
|
||
self.logger.info(f"{component} 性能指标: {metrics}")
|
||
|
||
def get_performance_summary(self) -> Dict[str, Any]:
|
||
"""
|
||
获取性能摘要
|
||
|
||
返回:
|
||
summary: 性能摘要
|
||
"""
|
||
summary = {}
|
||
|
||
for component, metrics_list in self.performance_metrics.items():
|
||
if metrics_list:
|
||
latest_metrics = metrics_list[-1]
|
||
summary[component] = {
|
||
"latest_metrics": latest_metrics,
|
||
"total_records": len(metrics_list)
|
||
}
|
||
|
||
return summary
|
||
|
||
def check_performance_targets(self) -> Dict[str, bool]:
|
||
"""
|
||
检查是否达到性能目标
|
||
|
||
返回:
|
||
results: 目标达成情况
|
||
"""
|
||
targets = self.config.get("performance_targets", {})
|
||
results = {}
|
||
|
||
# 检查猫叫声检测准确率
|
||
if "cat_detection_accuracy" in targets:
|
||
target = targets["cat_detection_accuracy"]
|
||
current = self._get_latest_metric("detector", "accuracy")
|
||
results["cat_detection_accuracy"] = current >= target if current is not None else False
|
||
|
||
# 检查意图分类准确率
|
||
if "intent_classification_accuracy" in targets:
|
||
target = targets["intent_classification_accuracy"]
|
||
current = self._get_latest_metric("classifier", "accuracy")
|
||
results["intent_classification_accuracy"] = current >= target if current is not None else False
|
||
|
||
return results
|
||
|
||
def _get_latest_metric(self, component: str, metric_name: str) -> Optional[float]:
|
||
"""获取最新的指标值"""
|
||
if component in self.performance_metrics and self.performance_metrics[component]:
|
||
latest = self.performance_metrics[component][-1]
|
||
return latest.get(metric_name)
|
||
return None
|
||
|
||
def _get_timestamp(self) -> str:
|
||
"""获取当前时间戳"""
|
||
from datetime import datetime
|
||
return datetime.now().isoformat()
|
||
|
||
def get_system_status(self) -> Dict[str, Any]:
|
||
"""
|
||
获取系统状态
|
||
|
||
返回:
|
||
status: 系统状态
|
||
"""
|
||
config = self.get_optimization_config()
|
||
|
||
return {
|
||
"optimization_enabled": config.enable_optimizations,
|
||
"optimization_level": config.optimization_level,
|
||
"optimizations": {
|
||
"dag_hmm": config.dag_hmm_enabled,
|
||
"feature_fusion": config.feature_fusion_enabled,
|
||
"hmm_parameter": config.hmm_optimization_enabled,
|
||
"detector": config.detector_optimization_enabled
|
||
},
|
||
"optimization_status": self.optimization_status,
|
||
"performance_summary": self.get_performance_summary(),
|
||
"performance_targets": self.check_performance_targets()
|
||
}
|
||
|
||
def generate_optimization_report(self) -> Dict[str, Any]:
|
||
"""
|
||
生成优化报告
|
||
|
||
返回:
|
||
report: 优化报告
|
||
"""
|
||
return {
|
||
"config": self.config,
|
||
"system_status": self.get_system_status(),
|
||
"performance_metrics": self.performance_metrics,
|
||
"optimization_status": self.optimization_status,
|
||
"timestamp": self._get_timestamp()
|
||
}
|
||
|
||
def export_report(self, output_path: str) -> None:
|
||
"""
|
||
导出优化报告
|
||
|
||
参数:
|
||
output_path: 输出路径
|
||
"""
|
||
report = self.generate_optimization_report()
|
||
|
||
try:
|
||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||
with open(output_path, 'w', encoding='utf-8') as f:
|
||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||
self.logger.info(f"优化报告已导出到: {output_path}")
|
||
except Exception as e:
|
||
self.logger.error(f"导出报告失败: {e}")
|
||
|
||
|
||
# Process-wide OptimizationManager instance; created lazily on first request.
_optimization_manager = None


def get_optimization_manager(config_path: Optional[str] = None) -> OptimizationManager:
    """Return the shared ``OptimizationManager``, creating it on first use.

    Args:
        config_path: Configuration file path passed to the manager's
            constructor. It is only used when the shared instance does not
            exist yet; once an instance exists it is returned unchanged and
            this argument is ignored.

    Returns:
        The shared manager instance.
    """
    global _optimization_manager

    manager = _optimization_manager
    if manager is None:
        manager = OptimizationManager(config_path)
        _optimization_manager = manager

    return manager
|
||
|
||
def reset_optimization_manager() -> None:
    """Drop the shared manager instance.

    Sets the module-level ``_optimization_manager`` back to ``None``.
    """
    global _optimization_manager
    _optimization_manager = None
|
||
|
||
|
||
# Demo / smoke test, executed only when the module is run as a script.
if __name__ == "__main__":
    # Create a manager using the default configuration path.
    manager = OptimizationManager()

    # Show the typed configuration.
    config = manager.get_optimization_config()
    print("优化配置:", config)

    # Query individual optimization switches.
    print("DAG-HMM优化启用:", manager.is_optimization_enabled("dag_hmm"))
    print("特征融合优化启用:", manager.is_optimization_enabled("feature_fusion"))

    # Record sample performance metrics.
    manager.record_performance_metrics("detector", {
        "accuracy": 0.95,
        "precision": 0.93,
        "recall": 0.97
    })

    manager.record_performance_metrics("classifier", {
        "accuracy": 0.92,
        "f1": 0.91
    })

    # The section headers below start with a real newline; the previous
    # "\\n" printed a literal backslash followed by "n".
    status = manager.get_system_status()
    print("\n系统状态:", status)

    # Check the configured performance targets.
    targets = manager.check_performance_targets()
    print("\n性能目标达成情况:", targets)

    # Build the full report.
    report = manager.generate_optimization_report()
    print("\n优化报告生成完成,包含", len(report), "个部分")
|
||
|