Files
petshy/src/integrated_detector.py
2025-10-08 20:39:09 +08:00

85 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
猫叫声检测器集成模块 - 将专用猫叫声检测器集成到主系统
"""
import os
import numpy as np
from typing import Dict, Any, Optional
from src.hybrid_feature_extractor import HybridFeatureExtractor
from src.cat_sound_detector import CatSoundDetector
class IntegratedCatDetector:
    """Cat-sound detector combining YAMNet output with an optional specialized detector.

    YAMNet features (via ``HybridFeatureExtractor``) are always computed. When a
    specialized ``CatSoundDetector`` model was loaded, its prediction on the
    axis-0 mean of the extracted embeddings drives the decision; otherwise the
    YAMNet cat-detection result is used as the fallback.
    """

    def __init__(self, detector_model_path: Optional[str] = None,
                 threshold: float = 0.5, fallback_threshold: float = 0.1):
        """Initialize the integrated cat-sound detector.

        Args:
            detector_model_path: Path to the specialized detector model. If None
                (or empty), only YAMNet is used. If the path does not exist or
                the model fails to load, a message is printed and YAMNet is used
                as the fallback.
            threshold: Confidence threshold applied to the specialized detector.
            fallback_threshold: Confidence threshold applied to YAMNet when it is
                used as the fallback.
        """
        self.feature_extractor = HybridFeatureExtractor()
        self.detector = None
        self.threshold = threshold
        self.fallback_threshold = fallback_threshold

        if not detector_model_path:
            # YAMNet-only mode requested; nothing to load.
            return

        if not os.path.exists(detector_model_path):
            # Report a missing model explicitly; otherwise a mistyped path would
            # silently switch detection to YAMNet and fallback_threshold.
            print(f"未找到专用猫叫声检测器模型: {detector_model_path}")
            print("将使用YAMNet作为回退方案")
            return

        try:
            self.detector = CatSoundDetector(model_path=detector_model_path)
        except Exception as e:
            # Deliberate best-effort load: any failure degrades to YAMNet
            # instead of crashing the caller.
            print(f"加载专用猫叫声检测器失败: {e}")
            print("将使用YAMNet作为回退方案")
        else:
            print(f"已加载专用猫叫声检测器: {detector_model_path}")

    def detect(self, audio_data: np.ndarray) -> Dict[str, Any]:
        """Detect whether the audio contains a cat sound.

        Args:
            audio_data: Audio samples, passed unchanged to
                ``HybridFeatureExtractor.process_audio`` (the expected format is
                defined by that extractor).

        Returns:
            A dict with these keys, in this order:
              - ``detected``: final decision (see note below).
              - ``confidence``: confidence reported by the backend that decided.
              - ``yamnet_confidence`` / ``yamnet_detected``: raw YAMNet result.
              - ``using_specialized_detector``: True when the specialized
                detector produced ``detected`` and ``confidence``.
              - ``top_categories``: ``features["top_categories"]`` as returned
                by the extractor.

        Note:
            ``detected`` is the backend's own ``detected`` flag OR
            ``confidence > threshold``. The threshold can therefore only add
            detections; it never suppresses one the backend already flagged.
        """
        features = self.feature_extractor.process_audio(audio_data)
        yamnet_detection = features["cat_detection"]

        if self.detector is not None:
            # Collapse embeddings along axis 0 into a single vector. Presumably
            # these are per-frame YAMNet embeddings (frames, dim); TODO confirm.
            embedding_mean = np.mean(features["embeddings"], axis=0)
            primary = self.detector.predict(embedding_mean)
            threshold = self.threshold
            using_specialized = True
        else:
            primary = yamnet_detection
            threshold = self.fallback_threshold
            using_specialized = False

        return {
            'detected': primary['detected'] or (primary['confidence'] > threshold),
            'confidence': primary['confidence'],
            'yamnet_confidence': yamnet_detection['confidence'],
            'yamnet_detected': yamnet_detection['detected'],
            'using_specialized_detector': using_specialized,
            'top_categories': features["top_categories"],
        }