""" 猫叫声检测器集成模块 - 将专用猫叫声检测器集成到主系统 """ import os import numpy as np from typing import Dict, Any, Optional from src.hybrid_feature_extractor import HybridFeatureExtractor from src.cat_sound_detector import CatSoundDetector class IntegratedCatDetector: """集成猫叫声检测器类,结合YAMNet和专用检测器""" def __init__(self, detector_model_path: Optional[str] = None, threshold: float = 0.5, fallback_threshold: float = 0.1): """ 初始化集成猫叫声检测器 参数: detector_model_path: 专用检测器模型路径,如果为None则仅使用YAMNet threshold: 专用检测器阈值 fallback_threshold: YAMNet回退阈值 """ self.feature_extractor = HybridFeatureExtractor() self.detector = None self.threshold = threshold self.fallback_threshold = fallback_threshold # 如果提供了模型路径,加载专用检测器 if detector_model_path and os.path.exists(detector_model_path): try: self.detector = CatSoundDetector(model_path=detector_model_path) print(f"已加载专用猫叫声检测器: {detector_model_path}") except Exception as e: print(f"加载专用猫叫声检测器失败: {e}") print("将使用YAMNet作为回退方案") def detect(self, audio_data: np.ndarray) -> Dict[str, Any]: """ 检测音频是否包含猫叫声 参数: audio_data: 音频数据 返回: result: 检测结果 """ # 提取YAMNet特征 features = self.feature_extractor.process_audio(audio_data) # 获取YAMNet的猫叫声检测结果 yamnet_detection = features["cat_detection"] # 如果有专用检测器,使用它进行检测 if self.detector is not None: # 使用平均嵌入向量 embedding_mean = np.mean(features["embeddings"], axis=0) # 使用专用检测器预测 detector_result = self.detector.predict(embedding_mean) # 合并结果 result = { 'detected': detector_result['detected'] or (detector_result['confidence'] > self.threshold), 'confidence': detector_result['confidence'], 'yamnet_confidence': yamnet_detection['confidence'], 'yamnet_detected': yamnet_detection['detected'], 'using_specialized_detector': True } else: # 仅使用YAMNet结果 result = { 'detected': yamnet_detection['detected'] or (yamnet_detection['confidence'] > self.fallback_threshold), 'confidence': yamnet_detection['confidence'], 'yamnet_confidence': yamnet_detection['confidence'], 'yamnet_detected': yamnet_detection['detected'], 'using_specialized_detector': False } # 添加YAMNet检测到的类别 result['top_categories'] = features["top_categories"] return result