feat: first commit
This commit is contained in:
84
src/integrated_detector.py
Normal file
84
src/integrated_detector.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
猫叫声检测器集成模块 - 将专用猫叫声检测器集成到主系统
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from src.hybrid_feature_extractor import HybridFeatureExtractor
|
||||
from src.cat_sound_detector import CatSoundDetector
|
||||
|
||||
class IntegratedCatDetector:
|
||||
"""集成猫叫声检测器类,结合YAMNet和专用检测器"""
|
||||
|
||||
def __init__(self, detector_model_path: Optional[str] = None,
|
||||
threshold: float = 0.5, fallback_threshold: float = 0.1):
|
||||
"""
|
||||
初始化集成猫叫声检测器
|
||||
|
||||
参数:
|
||||
detector_model_path: 专用检测器模型路径,如果为None则仅使用YAMNet
|
||||
threshold: 专用检测器阈值
|
||||
fallback_threshold: YAMNet回退阈值
|
||||
"""
|
||||
self.feature_extractor = HybridFeatureExtractor()
|
||||
self.detector = None
|
||||
self.threshold = threshold
|
||||
self.fallback_threshold = fallback_threshold
|
||||
|
||||
# 如果提供了模型路径,加载专用检测器
|
||||
if detector_model_path and os.path.exists(detector_model_path):
|
||||
try:
|
||||
self.detector = CatSoundDetector(model_path=detector_model_path)
|
||||
print(f"已加载专用猫叫声检测器: {detector_model_path}")
|
||||
except Exception as e:
|
||||
print(f"加载专用猫叫声检测器失败: {e}")
|
||||
print("将使用YAMNet作为回退方案")
|
||||
|
||||
def detect(self, audio_data: np.ndarray) -> Dict[str, Any]:
|
||||
"""
|
||||
检测音频是否包含猫叫声
|
||||
|
||||
参数:
|
||||
audio_data: 音频数据
|
||||
|
||||
返回:
|
||||
result: 检测结果
|
||||
"""
|
||||
# 提取YAMNet特征
|
||||
features = self.feature_extractor.process_audio(audio_data)
|
||||
|
||||
# 获取YAMNet的猫叫声检测结果
|
||||
yamnet_detection = features["cat_detection"]
|
||||
|
||||
# 如果有专用检测器,使用它进行检测
|
||||
if self.detector is not None:
|
||||
# 使用平均嵌入向量
|
||||
embedding_mean = np.mean(features["embeddings"], axis=0)
|
||||
|
||||
# 使用专用检测器预测
|
||||
detector_result = self.detector.predict(embedding_mean)
|
||||
|
||||
# 合并结果
|
||||
result = {
|
||||
'detected': detector_result['detected'] or (detector_result['confidence'] > self.threshold),
|
||||
'confidence': detector_result['confidence'],
|
||||
'yamnet_confidence': yamnet_detection['confidence'],
|
||||
'yamnet_detected': yamnet_detection['detected'],
|
||||
'using_specialized_detector': True
|
||||
}
|
||||
else:
|
||||
# 仅使用YAMNet结果
|
||||
result = {
|
||||
'detected': yamnet_detection['detected'] or (yamnet_detection['confidence'] > self.fallback_threshold),
|
||||
'confidence': yamnet_detection['confidence'],
|
||||
'yamnet_confidence': yamnet_detection['confidence'],
|
||||
'yamnet_detected': yamnet_detection['detected'],
|
||||
'using_specialized_detector': False
|
||||
}
|
||||
|
||||
# 添加YAMNet检测到的类别
|
||||
result['top_categories'] = features["top_categories"]
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user