Files
petshy/optimized_main.py
2025-10-08 20:39:09 +08:00

460 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
主程序 - 优化后的猫咪翻译器V2系统入口
"""
import os
import sys
import argparse
import numpy as np
import librosa
import sounddevice as sd
import time
import json
from typing import Dict, Any, List, Optional, Tuple
from src.audio_input import AudioInput
from src.hybrid_feature_extractor import HybridFeatureExtractor
from src.dag_hmm_classifier_v2 import DAGHMMClassifierV2
from src.cat_sound_detector import CatSoundDetector
from src.sample_collector import SampleCollector
from src.statistical_silence_detector import StatisticalSilenceDetector
class OptimizedCatTranslator:
"""
优化后的猫咪翻译器
集成了时序调制特征、统计静音检测、混合特征提取、
调整梅尔滤波器数量以及DAG-HMM与优化特征结合的系统。
"""
def __init__(self,
detector_model_path: Optional[str] = "./models/cat_detector_svm.pkl",
intent_model_path: Optional[str] = "./models",
feature_type: str = "hybrid",
detector_threshold: float = 0.5):
"""
初始化优化后的猫咪翻译器
参数:
detector_model_path: 猫叫声检测器模型路径
intent_model_path: 意图分类器模型路径
feature_type: 特征类型,可选"temporal_modulation", "mfcc", "yamnet", "hybrid"
detector_threshold: 叫声检测阈值
"""
self.audio_input = AudioInput()
self.feature_extractor = HybridFeatureExtractor()
self.detector_threshold = detector_threshold
self.feature_type = feature_type
self.species_labels = {
0: "none",
1: "cat",
2: "dog",
3: "pig",
}
# 加载猫叫声检测器
if detector_model_path and os.path.exists(detector_model_path):
self.cat_detector = CatSoundDetector()
self.cat_detector.load_model(detector_model_path)
print(f"猫叫声检测器已从 {detector_model_path} 加载")
else:
self.cat_detector = None
print("未加载猫叫声检测器将使用YAMNet进行检测")
# 加载意图分类器
if intent_model_path and os.path.exists(intent_model_path):
self.intent_classifier = DAGHMMClassifierV2(feature_type=feature_type)
self.intent_classifier.load_model(intent_model_path)
print(f"意图分类器已从 {intent_model_path} 加载")
else:
self.intent_classifier = None
print("未加载意图分类器,将只进行猫叫声检测")
def analyze_file(self, file_path: str) -> Dict[str, Any]:
"""
分析音频文件
参数:
file_path: 音频文件路径
返回:
result: 分析结果
"""
print(f"分析音频文件: {file_path}")
# 加载音频
audio, sr = self.audio_input.load_from_file(file_path)
# 分析音频
return self.analyze_audio(audio, sr)
def analyze_audio(self, audio: np.ndarray, sr: int = 16000) -> Dict[str, Any]:
"""
分析音频数据
参数:
audio: 音频数据
sr: 采样率
返回:
result: 分析结果
"""
# 1. 提取混合特征
# hybrid_features = self.feature_extractor.extract_hybrid_features(audio)
# 2. 检测物种叫声
if self.cat_detector:
# 使用优化后的物种叫声检测器
detector_result = self.cat_detector.predict(audio)
confidence = detector_result["prob"]
is_species_sound = detector_result["pred"] != 0 and confidence > self.detector_threshold
else:
# 使用YAMNet检测
raise ValueError("未初始化物种叫声检测器")
species_labels = self.species_labels[detector_result["pred"]]
# 3. 如果是猫叫声,进行意图分类
intent_result = None
if is_species_sound and self.intent_classifier:
intent_result = self.intent_classifier.predict(audio, species_labels)
# 4. 构建结果
result = {
"species_labels": species_labels,
"is_species_sound": bool(is_species_sound),
"confidence": float(confidence),
"intent_result": intent_result
}
return result
def start_live_analysis(self,
duration: float = 3.0,
interval: float = 1.0,
device: Optional[int] = None):
"""
开始实时分析
参数:
duration: 每次录音持续时间(秒)
interval: 分析间隔时间(秒)
device: 录音设备ID
"""
print(f"开始实时分析按Ctrl+C停止...")
print(f"录音持续时间: {duration}秒,分析间隔: {interval}")
try:
while True:
# 录音
print("\n录音中...")
audio = self.audio_input.record_audio(duration=duration, device=device)
# 分析
result = self.analyze_audio(audio)
# 输出结果
if result["is_cat_sound"]:
print(f"检测到猫叫声! 置信度: {result['confidence']:.4f}")
if result["intent_result"]:
intent = result["intent_result"]
print(f"意图: {intent['class_name']} (置信度: {intent['confidence']:.4f})")
print("所有类别概率:")
for cls, prob in intent["probabilities"].items():
print(f" {cls}: {prob:.4f}")
else:
print(f"未检测到猫叫声。置信度: {result['confidence']:.4f}")
# 等待
time.sleep(interval)
except KeyboardInterrupt:
print("\n实时分析已停止")
def add_sample(self,
file_path: str,
label: str,
is_cat_sound: bool = True,
cat_name: Optional[str] = None) -> Dict[str, Any]:
"""
添加训练样本
参数:
file_path: 音频文件路径
label: 标签
is_cat_sound: 是否为猫叫声
cat_name: 猫咪名称
返回:
result: 添加结果
"""
print(f"添加样本: {file_path}, 标签: {label}, 是否猫叫声: {is_cat_sound}")
# 加载音频
audio, sr = self.audio_input.load_from_file(file_path)
# 提取特征
hybrid_features = self.feature_extractor.extract_hybrid_features(audio)
# 保存样本
samples_dir = os.path.join("samples", cat_name if cat_name else "default")
os.makedirs(samples_dir, exist_ok=True)
# 生成样本ID
sample_id = int(time.time())
# 保存特征和元数据
sample_data = {
"features": hybrid_features.tolist(),
"label": label,
"is_cat_sound": is_cat_sound,
"cat_name": cat_name,
"file_path": file_path,
"timestamp": sample_id
}
sample_path = os.path.join(samples_dir, f"sample_{sample_id}.json")
with open(sample_path, "w") as f:
json.dump(sample_data, f)
print(f"样本已保存到 {sample_path}")
return {
"sample_id": sample_id,
"sample_path": sample_path
}
def train_detector(self,
model_type: str = "svm",
output_path: Optional[str] = None) -> Dict[str, Any]:
"""
训练猫叫声检测器
参数:
model_type: 模型类型,可选"svm", "rf", "nn"
output_path: 输出路径
返回:
metrics: 训练指标
"""
print(f"训练物种叫声检测器,模型类型: {model_type}")
species_sounds_audio = {
"cat_sounds": [],
"dog_sounds": [],
"non_sounds": [],
}
collector = SampleCollector()
for species_sounds in species_sounds_audio:
[
species_sounds_audio[species_sounds].append(librosa.load(file_path, sr=16000)[0]) for file_path in
[meta["target_path"] for _, meta in collector.metadata[species_sounds].items()]
]
# 获取样本数量
sample_counts = collector.get_sample_counts()
print(f"猫叫声样本数量: {sample_counts['cat_sounds']}")
print(f"狗叫声样本数量: {sample_counts['dog_sounds']}")
print(f"非物种叫声样本数量: {sample_counts['non_sounds']}")
# 初始化检测器
detector = CatSoundDetector(model_type=model_type)
# 准备训练数据
# 训练模型
metrics = detector.train(species_sounds_audio, validation_split=0.2)
# 输出评估指标
print("\n评估指标:")
print(f"训练集准确率: {metrics['train_accuracy']:.4f}")
# print(f"训练集精确率: {metrics['train_precision']:.4f}")
# print(f"训练集召回率: {metrics['train_recall']:.4f}")
# print(f"训练集F1得分: {metrics['train_f1']:.4f}")
print(f"测试集准确率: {metrics['val_accuracy']:.4f}")
print(f"测试集精确率: {metrics['val_precision']:.4f}")
print(f"测试集召回率: {metrics['val_recall']:.4f}")
print(f"测试集F1得分: {metrics['val_f1']:.4f}")
# 保存模型
model_path = os.path.join(output_path, f"cat_detector_{model_type}.pkl")
detector.save_model(model_path)
print(f"模型已保存到: {model_path}")
return metrics
def train_intent_classifier(self,
samples_dir: str,
feature_type: str = "hybrid",
output_path: Optional[str] = None) -> Dict[str, Any]:
"""
训练意图分类器
参数:
samples_dir: 样本目录
feature_type: 特征类型,可选"temporal_modulation", "mfcc", "yamnet", "hybrid"
output_path: 输出路径
返回:
metrics: 训练指标
"""
print(f"训练意图分类器,特征类型: {feature_type}")
# 加载样本
audio_files = []
labels = []
# 遍历样本目录下的所有子目录(每个子目录对应一个意图类别)
for intent_dir in os.listdir(samples_dir):
intent_path = os.path.join(samples_dir, intent_dir)
if os.path.isdir(intent_path):
for file in os.listdir(intent_path):
if file.endswith(".wav") or file.endswith(".WAV") or file.endswith(".mp3"):
audio_path = os.path.join(intent_path, file)
audio, sr = librosa.load(audio_path, sr=16000)
if audio.size > 0: # 确保音频数据不为空
audio_files.append(audio)
labels.append(intent_dir)
else:
print(f"警告: 音频文件 {audio_path} 为空,跳过。")
print(f"加载了 {len(audio_files)} 个样本,共 {len(set(labels))} 个意图类别")
if not audio_files or len(set(labels)) < 2: # 至少需要两个类别才能训练分类器
print("错误: 训练意图分类器所需样本或类别不足,跳过训练。")
return {"train_accuracy": float("nan"), "message": "样本或类别不足"}
# 初始化分类器
classifier = DAGHMMClassifierV2(feature_type=feature_type)
# 训练模型
metrics = classifier.fit(audio_files, labels)
# 保存模型
if output_path:
classifier.save_model(output_path)
print(f"模型已保存到 {output_path}")
return metrics
def main():
parser = argparse.ArgumentParser(description="优化后的猫咪翻译器V2")
# 子命令
subparsers = parser.add_subparsers(dest="command", help="命令")
# 分析命令
analyze_parser = subparsers.add_parser("analyze", help="分析音频文件")
analyze_parser.add_argument("file", help="音频文件路径")
analyze_parser.add_argument("--detector", help="猫叫声检测器模型路径", default="./models/cat_detector_svm.pkl")
analyze_parser.add_argument("--intent-model", help="意图分类器模型路径", default="./models")
analyze_parser.add_argument("--feature-type", default="hybrid",
choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
help="特征类型")
analyze_parser.add_argument("--threshold", type=float, default=0.5, help="猫叫声检测阈值")
# 实时分析命令
live_parser = subparsers.add_parser("live", help="实时分析麦克风输入")
live_parser.add_argument("--detector", help="猫叫声检测器模型路径", default="./models/cat_detector_svm.pkl")
live_parser.add_argument("--intent-model", help="意图分类器模型路径", default="./models")
live_parser.add_argument("--feature-type", default="temporal_modulation",
choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
help="特征类型")
live_parser.add_argument("--threshold", type=float, default=0.5, help="猫叫声检测阈值")
live_parser.add_argument("--duration", type=float, default=3.0, help="每次录音持续时间(秒)")
live_parser.add_argument("--interval", type=float, default=1.0, help="分析间隔时间(秒)")
live_parser.add_argument("--device", type=int, help="录音设备ID")
# 添加样本命令
add_sample_parser = subparsers.add_parser("add-sample", help="添加训练样本")
add_sample_parser.add_argument("file", help="音频文件路径")
add_sample_parser.add_argument("label", help="标签")
add_sample_parser.add_argument("--is-cat-sound", action="store_true", help="是否为猫叫声")
add_sample_parser.add_argument("--cat", help="猫咪名称")
# 训练检测器命令
train_detector_parser = subparsers.add_parser("train-detector", help="训练猫叫声检测器")
train_detector_parser.add_argument("--model-type", default="svm", choices=["svm", "rf", "nn"], help="模型类型")
train_detector_parser.add_argument("--output", default="./models", help="输出路径")
# 训练意图分类器命令
train_intent_parser = subparsers.add_parser("train-intent", help="训练意图分类器")
train_intent_parser.add_argument("--samples", required=True, help="样本目录")
train_intent_parser.add_argument("--feature-type", default="hybrid",
choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
help="特征类型")
train_intent_parser.add_argument("--output", help="输出路径")
args = parser.parse_args()
if args.command == "analyze":
translator = OptimizedCatTranslator(
detector_model_path=args.detector,
intent_model_path=args.intent_model,
feature_type=args.feature_type,
detector_threshold=args.threshold
)
result = translator.analyze_file(args.file)
# 输出结果
if result["is_species_sound"]:
print(f"检测到 {result['species_labels']} 叫声! 置信度: {result['confidence']:.4f}")
if result["intent_result"]:
intent = result["intent_result"]
if intent['winner']:
print(f"意图: {intent['winner']} (置信度: {intent['confidence']:.4f})")
else:
print("⚠️特征学习中。。。")
print(intent)
else:
print(f"未检测到物种叫声。置信度: {result['confidence']:.4f}")
elif args.command == "live":
translator = OptimizedCatTranslator(
detector_model_path=args.detector,
intent_model_path=args.intent_model,
feature_type=args.feature_type,
detector_threshold=args.threshold
)
translator.start_live_analysis(
duration=args.duration,
interval=args.interval,
device=args.device
)
elif args.command == "add-sample":
translator = OptimizedCatTranslator()
result = translator.add_sample(
file_path=args.file,
label=args.label,
is_cat_sound=args.is_cat_sound,
cat_name=args.cat
)
print(f"样本已添加ID: {result['sample_id']}")
elif args.command == "train-detector":
translator = OptimizedCatTranslator()
metrics = translator.train_detector(
model_type=args.model_type,
output_path=args.output
)
print(f"训练完成")
elif args.command == "train-intent":
translator = OptimizedCatTranslator()
metrics = translator.train_intent_classifier(
samples_dir=args.samples,
feature_type=args.feature_type,
output_path=args.output
)
print(f"训练完成")
else:
parser.print_help()
if __name__ == "__main__":
main()