85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
"""
|
||
猫叫声检测器集成模块 - 将专用猫叫声检测器集成到主系统
|
||
"""
|
||
|
||
import os
|
||
import numpy as np
|
||
from typing import Dict, Any, Optional
|
||
|
||
from src.hybrid_feature_extractor import HybridFeatureExtractor
|
||
from src.cat_sound_detector import CatSoundDetector
|
||
|
||
class IntegratedCatDetector:
|
||
"""集成猫叫声检测器类,结合YAMNet和专用检测器"""
|
||
|
||
def __init__(self, detector_model_path: Optional[str] = None,
|
||
threshold: float = 0.5, fallback_threshold: float = 0.1):
|
||
"""
|
||
初始化集成猫叫声检测器
|
||
|
||
参数:
|
||
detector_model_path: 专用检测器模型路径,如果为None则仅使用YAMNet
|
||
threshold: 专用检测器阈值
|
||
fallback_threshold: YAMNet回退阈值
|
||
"""
|
||
self.feature_extractor = HybridFeatureExtractor()
|
||
self.detector = None
|
||
self.threshold = threshold
|
||
self.fallback_threshold = fallback_threshold
|
||
|
||
# 如果提供了模型路径,加载专用检测器
|
||
if detector_model_path and os.path.exists(detector_model_path):
|
||
try:
|
||
self.detector = CatSoundDetector(model_path=detector_model_path)
|
||
print(f"已加载专用猫叫声检测器: {detector_model_path}")
|
||
except Exception as e:
|
||
print(f"加载专用猫叫声检测器失败: {e}")
|
||
print("将使用YAMNet作为回退方案")
|
||
|
||
def detect(self, audio_data: np.ndarray) -> Dict[str, Any]:
|
||
"""
|
||
检测音频是否包含猫叫声
|
||
|
||
参数:
|
||
audio_data: 音频数据
|
||
|
||
返回:
|
||
result: 检测结果
|
||
"""
|
||
# 提取YAMNet特征
|
||
features = self.feature_extractor.process_audio(audio_data)
|
||
|
||
# 获取YAMNet的猫叫声检测结果
|
||
yamnet_detection = features["cat_detection"]
|
||
|
||
# 如果有专用检测器,使用它进行检测
|
||
if self.detector is not None:
|
||
# 使用平均嵌入向量
|
||
embedding_mean = np.mean(features["embeddings"], axis=0)
|
||
|
||
# 使用专用检测器预测
|
||
detector_result = self.detector.predict(embedding_mean)
|
||
|
||
# 合并结果
|
||
result = {
|
||
'detected': detector_result['detected'] or (detector_result['confidence'] > self.threshold),
|
||
'confidence': detector_result['confidence'],
|
||
'yamnet_confidence': yamnet_detection['confidence'],
|
||
'yamnet_detected': yamnet_detection['detected'],
|
||
'using_specialized_detector': True
|
||
}
|
||
else:
|
||
# 仅使用YAMNet结果
|
||
result = {
|
||
'detected': yamnet_detection['detected'] or (yamnet_detection['confidence'] > self.fallback_threshold),
|
||
'confidence': yamnet_detection['confidence'],
|
||
'yamnet_confidence': yamnet_detection['confidence'],
|
||
'yamnet_detected': yamnet_detection['detected'],
|
||
'using_specialized_detector': False
|
||
}
|
||
|
||
# 添加YAMNet检测到的类别
|
||
result['top_categories'] = features["top_categories"]
|
||
|
||
return result
|