460 lines
17 KiB
Python
460 lines
17 KiB
Python
"""
Main program - entry point of the optimized Cat Translator V2 system.
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import numpy as np
|
||
import librosa
|
||
import sounddevice as sd
|
||
import time
|
||
import json
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
|
||
from src.audio_input import AudioInput
|
||
from src.hybrid_feature_extractor import HybridFeatureExtractor
|
||
from src.dag_hmm_classifier_v2 import DAGHMMClassifierV2
|
||
from src.cat_sound_detector import CatSoundDetector
|
||
from src.sample_collector import SampleCollector
|
||
from src.statistical_silence_detector import StatisticalSilenceDetector
|
||
|
||
class OptimizedCatTranslator:
    """
    Optimized cat translator.

    Wires together audio input, hybrid feature extraction, a species sound
    detector and a DAG-HMM intent classifier.  According to the original
    description, the overall system combines temporal-modulation features,
    statistical silence detection, hybrid feature extraction, an adjusted
    number of mel filters and DAG-HMM classification; those pieces live in
    the imported ``src`` modules rather than in this class.
    """

    def __init__(self,
                 detector_model_path: Optional[str] = "./models/cat_detector_svm.pkl",
                 intent_model_path: Optional[str] = "./models",
                 feature_type: str = "hybrid",
                 detector_threshold: float = 0.5):
        """
        Initialise the translator and load whichever models exist on disk.

        Args:
            detector_model_path: Path of the species sound detector model.
                When falsy or missing on disk, no detector is loaded.
            intent_model_path: Path of the intent classifier model.  When
                falsy or missing on disk, no intent classifier is loaded.
            feature_type: Feature type handed to the intent classifier; one
                of "temporal_modulation", "mfcc", "yamnet", "hybrid".
            detector_threshold: Minimum detector probability for a clip to
                count as a species sound.
        """
        self.audio_input = AudioInput()
        self.feature_extractor = HybridFeatureExtractor()
        self.detector_threshold = detector_threshold
        self.feature_type = feature_type
        # Maps the detector's predicted class id to a species name.
        self.species_labels = {
            0: "none",
            1: "cat",
            2: "dog",
            3: "pig",
        }

        # Load the species sound detector.
        if detector_model_path and os.path.exists(detector_model_path):
            self.cat_detector = CatSoundDetector()
            self.cat_detector.load_model(detector_model_path)
            print(f"猫叫声检测器已从 {detector_model_path} 加载")
        else:
            self.cat_detector = None
            print("未加载猫叫声检测器,将使用YAMNet进行检测")

        # Load the intent classifier.
        if intent_model_path and os.path.exists(intent_model_path):
            self.intent_classifier = DAGHMMClassifierV2(feature_type=feature_type)
            self.intent_classifier.load_model(intent_model_path)
            print(f"意图分类器已从 {intent_model_path} 加载")
        else:
            self.intent_classifier = None
            print("未加载意图分类器,将只进行猫叫声检测")

    def analyze_file(self, file_path: str) -> Dict[str, Any]:
        """
        Load an audio file and analyse it.

        Args:
            file_path: Path of the audio file.

        Returns:
            The result dictionary produced by :meth:`analyze_audio`.
        """
        print(f"分析音频文件: {file_path}")

        audio, sr = self.audio_input.load_from_file(file_path)
        return self.analyze_audio(audio, sr)

    def analyze_audio(self, audio: np.ndarray, sr: int = 16000) -> Dict[str, Any]:
        """
        Detect the species of a clip and, when one is detected, its intent.

        Args:
            audio: Audio samples.
            sr: Sample rate.  Accepted for interface compatibility; it is
                not used by this method.

        Returns:
            A dict with keys ``species_labels`` (species name),
            ``is_species_sound`` (bool), ``confidence`` (float) and
            ``intent_result`` (classifier output, or ``None`` when no species
            sound was detected or no intent classifier is loaded).

        Raises:
            ValueError: If no species sound detector has been loaded.
        """
        # There is no fallback detector implemented, so fail loudly.
        if not self.cat_detector:
            raise ValueError("未初始化物种叫声检测器")

        detector_result = self.cat_detector.predict(audio)
        predicted_class = detector_result["pred"]
        confidence = detector_result["prob"]
        # Class 0 is "none"; anything else must also clear the threshold.
        is_species_sound = predicted_class != 0 and confidence > self.detector_threshold
        species_labels = self.species_labels[predicted_class]

        # Only run the intent classifier on clips accepted by the detector.
        intent_result = None
        if is_species_sound and self.intent_classifier:
            intent_result = self.intent_classifier.predict(audio, species_labels)

        return {
            "species_labels": species_labels,
            "is_species_sound": bool(is_species_sound),
            "confidence": float(confidence),
            "intent_result": intent_result,
        }

    def start_live_analysis(self,
                            duration: float = 3.0,
                            interval: float = 1.0,
                            device: Optional[int] = None):
        """
        Record and analyse microphone audio in a loop until Ctrl+C.

        Args:
            duration: Length of each recording, in seconds.
            interval: Pause between two analyses, in seconds.
            device: Recording device id.
        """
        print("开始实时分析,按Ctrl+C停止...")
        print(f"录音持续时间: {duration}秒,分析间隔: {interval}秒")

        try:
            while True:
                print("\n录音中...")
                audio = self.audio_input.record_audio(duration=duration, device=device)

                result = self.analyze_audio(audio)

                # analyze_audio reports detection under "is_species_sound".
                if result["is_species_sound"]:
                    print(f"检测到 {result['species_labels']} 叫声! 置信度: {result['confidence']:.4f}")
                    intent = result["intent_result"]
                    if intent:
                        # NOTE(review): the `analyze` CLI branch reads the
                        # label from "winner"; "class_name" is kept as a
                        # fallback for the schema this loop used before.
                        # Confirm against DAGHMMClassifierV2.predict.
                        winner = intent.get("winner") or intent.get("class_name")
                        if winner:
                            print(f"意图: {winner} (置信度: {intent['confidence']:.4f})")
                        else:
                            print("⚠️特征学习中。。。")
                        probabilities = intent.get("probabilities")
                        if probabilities:
                            print("所有类别概率:")
                            for cls, prob in probabilities.items():
                                print(f" {cls}: {prob:.4f}")
                else:
                    print(f"未检测到物种叫声。置信度: {result['confidence']:.4f}")

                time.sleep(interval)

        except KeyboardInterrupt:
            print("\n实时分析已停止")

    def add_sample(self,
                   file_path: str,
                   label: str,
                   is_cat_sound: bool = True,
                   cat_name: Optional[str] = None) -> Dict[str, Any]:
        """
        Extract features from an audio file and store them as a JSON sample.

        The sample is written to ``samples/<cat_name or "default">/``.

        Args:
            file_path: Path of the audio file.
            label: Label of the sample.
            is_cat_sound: Whether the clip is a cat sound.
            cat_name: Name of the cat; selects the sample sub-directory.

        Returns:
            A dict with ``sample_id`` and ``sample_path``.
        """
        print(f"添加样本: {file_path}, 标签: {label}, 是否猫叫声: {is_cat_sound}")

        audio, _sr = self.audio_input.load_from_file(file_path)
        hybrid_features = self.feature_extractor.extract_hybrid_features(audio)

        samples_dir = os.path.join("samples", cat_name if cat_name else "default")
        os.makedirs(samples_dir, exist_ok=True)

        # Ids are second-resolution timestamps; bump the id when a sample
        # with the same id already exists so it is not silently overwritten.
        timestamp = int(time.time())
        sample_id = timestamp
        sample_path = os.path.join(samples_dir, f"sample_{sample_id}.json")
        while os.path.exists(sample_path):
            sample_id += 1
            sample_path = os.path.join(samples_dir, f"sample_{sample_id}.json")

        sample_data = {
            "features": hybrid_features.tolist(),
            "label": label,
            "is_cat_sound": is_cat_sound,
            "cat_name": cat_name,
            "file_path": file_path,
            "timestamp": timestamp,
        }

        with open(sample_path, "w", encoding="utf-8") as f:
            json.dump(sample_data, f)

        print(f"样本已保存到 {sample_path}")

        return {
            "sample_id": sample_id,
            "sample_path": sample_path,
        }

    def train_detector(self,
                       model_type: str = "svm",
                       output_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Train the species sound detector on the collected samples.

        Args:
            model_type: Model type; one of "svm", "rf", "nn".
            output_path: Directory the model is saved into as
                ``cat_detector_<model_type>.pkl``.  When omitted the trained
                model is not saved.

        Returns:
            The metrics returned by the detector's ``train`` call.
        """
        print(f"训练物种叫声检测器,模型类型: {model_type}")

        species_sounds_audio = {
            "cat_sounds": [],
            "dog_sounds": [],
            "non_sounds": [],
        }
        collector = SampleCollector()

        # Load every collected clip of each category at 16 kHz.
        for category, clips in species_sounds_audio.items():
            for meta in collector.metadata[category].values():
                clips.append(librosa.load(meta["target_path"], sr=16000)[0])

        sample_counts = collector.get_sample_counts()
        print(f"猫叫声样本数量: {sample_counts['cat_sounds']}")
        print(f"狗叫声样本数量: {sample_counts['dog_sounds']}")
        print(f"非物种叫声样本数量: {sample_counts['non_sounds']}")

        detector = CatSoundDetector(model_type=model_type)
        metrics = detector.train(species_sounds_audio, validation_split=0.2)

        print("\n评估指标:")
        print(f"训练集准确率: {metrics['train_accuracy']:.4f}")
        print(f"测试集准确率: {metrics['val_accuracy']:.4f}")
        print(f"测试集精确率: {metrics['val_precision']:.4f}")
        print(f"测试集召回率: {metrics['val_recall']:.4f}")
        print(f"测试集F1得分: {metrics['val_f1']:.4f}")

        # Saving is optional, matching train_intent_classifier; joining a
        # None path would otherwise raise after the training already ran.
        if output_path:
            os.makedirs(output_path, exist_ok=True)
            model_path = os.path.join(output_path, f"cat_detector_{model_type}.pkl")
            detector.save_model(model_path)
            print(f"模型已保存到: {model_path}")
        else:
            print("未指定输出路径,模型未保存")

        return metrics

    def train_intent_classifier(self,
                                samples_dir: str,
                                feature_type: str = "hybrid",
                                output_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Train the intent classifier from a directory of labelled clips.

        Every sub-directory of ``samples_dir`` is one intent class and its
        name is used as the label.  Files ending in ``.wav`` or ``.mp3``
        (case-insensitive) are loaded at 16 kHz.

        Args:
            samples_dir: Root directory of the samples.
            feature_type: Feature type; one of "temporal_modulation",
                "mfcc", "yamnet", "hybrid".
            output_path: Where to save the trained model.  When omitted the
                model is not saved.

        Returns:
            The metrics returned by the classifier's ``fit`` call, or a dict
            with a NaN ``train_accuracy`` and a ``message`` when fewer than
            two classes (or no clips) were found.
        """
        print(f"训练意图分类器,特征类型: {feature_type}")

        audio_files = []
        labels = []
        # Sorted traversal keeps the sample order reproducible across runs.
        for intent_dir in sorted(os.listdir(samples_dir)):
            intent_path = os.path.join(samples_dir, intent_dir)
            if not os.path.isdir(intent_path):
                continue
            for file_name in sorted(os.listdir(intent_path)):
                if not file_name.lower().endswith((".wav", ".mp3")):
                    continue
                audio_path = os.path.join(intent_path, file_name)
                audio, _sr = librosa.load(audio_path, sr=16000)
                if audio.size > 0:  # skip clips that decoded to nothing
                    audio_files.append(audio)
                    labels.append(intent_dir)
                else:
                    print(f"警告: 音频文件 {audio_path} 为空,跳过。")

        print(f"加载了 {len(audio_files)} 个样本,共 {len(set(labels))} 个意图类别")

        # A classifier needs at least two classes to be trained.
        if not audio_files or len(set(labels)) < 2:
            print("错误: 训练意图分类器所需样本或类别不足,跳过训练。")
            return {"train_accuracy": float("nan"), "message": "样本或类别不足"}

        classifier = DAGHMMClassifierV2(feature_type=feature_type)
        metrics = classifier.fit(audio_files, labels)

        if output_path:
            classifier.save_model(output_path)
            print(f"模型已保存到 {output_path}")

        return metrics
||
|
||
|
||
def _add_model_arguments(subparser, feature_default):
    """Attach the model/feature/threshold options shared by `analyze` and `live`."""
    subparser.add_argument("--detector", help="猫叫声检测器模型路径", default="./models/cat_detector_svm.pkl")
    subparser.add_argument("--intent-model", help="意图分类器模型路径", default="./models")
    subparser.add_argument("--feature-type", default=feature_default,
                           choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
                           help="特征类型")
    subparser.add_argument("--threshold", type=float, default=0.5, help="猫叫声检测阈值")


def _build_parser():
    """Build the command-line parser with its five sub-commands."""
    parser = argparse.ArgumentParser(description="优化后的猫咪翻译器V2")
    subparsers = parser.add_subparsers(dest="command", help="命令")

    # analyze: analyse one audio file.
    analyze_parser = subparsers.add_parser("analyze", help="分析音频文件")
    analyze_parser.add_argument("file", help="音频文件路径")
    _add_model_arguments(analyze_parser, "hybrid")

    # live: analyse microphone input in a loop.
    live_parser = subparsers.add_parser("live", help="实时分析麦克风输入")
    _add_model_arguments(live_parser, "temporal_modulation")
    live_parser.add_argument("--duration", type=float, default=3.0, help="每次录音持续时间(秒)")
    live_parser.add_argument("--interval", type=float, default=1.0, help="分析间隔时间(秒)")
    live_parser.add_argument("--device", type=int, help="录音设备ID")

    # add-sample: store a training sample.
    add_sample_parser = subparsers.add_parser("add-sample", help="添加训练样本")
    add_sample_parser.add_argument("file", help="音频文件路径")
    add_sample_parser.add_argument("label", help="标签")
    add_sample_parser.add_argument("--is-cat-sound", action="store_true", help="是否为猫叫声")
    add_sample_parser.add_argument("--cat", help="猫咪名称")

    # train-detector: train the species sound detector.
    train_detector_parser = subparsers.add_parser("train-detector", help="训练猫叫声检测器")
    train_detector_parser.add_argument("--model-type", default="svm", choices=["svm", "rf", "nn"], help="模型类型")
    train_detector_parser.add_argument("--output", default="./models", help="输出路径")

    # train-intent: train the intent classifier.
    train_intent_parser = subparsers.add_parser("train-intent", help="训练意图分类器")
    train_intent_parser.add_argument("--samples", required=True, help="样本目录")
    train_intent_parser.add_argument("--feature-type", default="hybrid",
                                     choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
                                     help="特征类型")
    train_intent_parser.add_argument("--output", help="输出路径")

    return parser


def _report_analysis(result):
    """Print the outcome of a single `analyze` run."""
    if not result["is_species_sound"]:
        print(f"未检测到物种叫声。置信度: {result['confidence']:.4f}")
        return

    print(f"检测到 {result['species_labels']} 叫声! 置信度: {result['confidence']:.4f}")
    intent = result["intent_result"]
    if not intent:
        return
    if intent['winner']:
        print(f"意图: {intent['winner']} (置信度: {intent['confidence']:.4f})")
    else:
        print("⚠️特征学习中。。。")
        # NOTE(review): the raw dump is assumed to belong to the no-winner
        # branch; the original nesting could not be recovered - confirm.
        print(intent)


def main(argv: Optional[List[str]] = None):
    """
    Command-line entry point.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``.  ``None``
            (the default) keeps the usual command-line behaviour.

    Sub-commands: ``analyze``, ``live``, ``add-sample``, ``train-detector``
    and ``train-intent``.  Without a sub-command the help text is printed.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.command in ("analyze", "live"):
        translator = OptimizedCatTranslator(
            detector_model_path=args.detector,
            intent_model_path=args.intent_model,
            feature_type=args.feature_type,
            detector_threshold=args.threshold,
        )
        if args.command == "analyze":
            _report_analysis(translator.analyze_file(args.file))
        else:
            translator.start_live_analysis(
                duration=args.duration,
                interval=args.interval,
                device=args.device,
            )

    elif args.command in ("add-sample", "train-detector", "train-intent"):
        # These commands never use the loaded models, so skip loading them;
        # a stale or incompatible model file must not block retraining.
        translator = OptimizedCatTranslator(
            detector_model_path=None,
            intent_model_path=None,
        )
        if args.command == "add-sample":
            result = translator.add_sample(
                file_path=args.file,
                label=args.label,
                is_cat_sound=args.is_cat_sound,
                cat_name=args.cat,
            )
            print(f"样本已添加,ID: {result['sample_id']}")
        elif args.command == "train-detector":
            translator.train_detector(
                model_type=args.model_type,
                output_path=args.output,
            )
            print("训练完成")
        else:
            translator.train_intent_classifier(
                samples_dir=args.samples,
                feature_type=args.feature_type,
                output_path=args.output,
            )
            print("训练完成")

    else:
        parser.print_help()


if __name__ == "__main__":
    main()
|