feat: first commit
This commit is contained in:
459
optimized_main.py
Normal file
459
optimized_main.py
Normal file
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
Main program - entry point of the optimized Cat Translator V2 system
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
import librosa
|
||||
import sounddevice as sd
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from src.audio_input import AudioInput
|
||||
from src.hybrid_feature_extractor import HybridFeatureExtractor
|
||||
from src.dag_hmm_classifier_v2 import DAGHMMClassifierV2
|
||||
from src.cat_sound_detector import CatSoundDetector
|
||||
from src.sample_collector import SampleCollector
|
||||
from src.statistical_silence_detector import StatisticalSilenceDetector
|
||||
|
||||
class OptimizedCatTranslator:
    """Optimized cat translator.

    Ties together audio input, hybrid feature extraction, a species-sound
    detector and a DAG-HMM intent classifier (per the original description:
    temporal-modulation features, statistical silence detection, hybrid
    feature extraction and an adjusted mel-filter count).
    """

    def __init__(self,
                 detector_model_path: Optional[str] = "./models/cat_detector_svm.pkl",
                 intent_model_path: Optional[str] = "./models",
                 feature_type: str = "hybrid",
                 detector_threshold: float = 0.5):
        """Initialise the translator and load whichever models exist on disk.

        Args:
            detector_model_path: Path of the species-sound detector model.
                When falsy or missing on disk, no detector is loaded.
            intent_model_path: Path of the intent classifier model. When
                falsy or missing on disk, no intent classifier is loaded.
            feature_type: One of "temporal_modulation", "mfcc", "yamnet",
                "hybrid"; forwarded to the intent classifier.
            detector_threshold: Minimum detector probability for a clip to
                count as a species sound.
        """
        self.audio_input = AudioInput()
        # The extractor is always the hybrid one; feature_type only reaches
        # the intent classifier below.
        self.feature_extractor = HybridFeatureExtractor()
        self.detector_threshold = detector_threshold
        self.feature_type = feature_type
        # Maps the detector's integer prediction to a species name.
        self.species_labels = {
            0: "none",
            1: "cat",
            2: "dog",
            3: "pig",
        }

        # Load the species-sound detector.
        if detector_model_path and os.path.exists(detector_model_path):
            self.cat_detector = CatSoundDetector()
            self.cat_detector.load_model(detector_model_path)
            print(f"猫叫声检测器已从 {detector_model_path} 加载")
        else:
            # NOTE: despite this message, analyze_audio raises ValueError
            # when no detector is loaded; there is no fallback in this class.
            self.cat_detector = None
            print("未加载猫叫声检测器,将使用YAMNet进行检测")

        # Load the intent classifier.
        if intent_model_path and os.path.exists(intent_model_path):
            self.intent_classifier = DAGHMMClassifierV2(feature_type=feature_type)
            self.intent_classifier.load_model(intent_model_path)
            print(f"意图分类器已从 {intent_model_path} 加载")
        else:
            self.intent_classifier = None
            print("未加载意图分类器,将只进行猫叫声检测")

    def analyze_file(self, file_path: str) -> Dict[str, Any]:
        """Load an audio file and analyse it.

        Args:
            file_path: Path of the audio file.

        Returns:
            The result dictionary produced by ``analyze_audio``.
        """
        print(f"分析音频文件: {file_path}")
        audio, sr = self.audio_input.load_from_file(file_path)
        return self.analyze_audio(audio, sr)

    def analyze_audio(self, audio: np.ndarray, sr: int = 16000) -> Dict[str, Any]:
        """Detect the species in an audio clip and, if possible, its intent.

        Args:
            audio: Audio samples, handed unchanged to the detector.
            sr: Sample rate. Accepted for interface compatibility; this
                method does not use it.

        Returns:
            A dict with the keys ``species_labels`` (species name),
            ``is_species_sound`` (bool), ``confidence`` (float) and
            ``intent_result`` (the classifier output, or ``None``).

        Raises:
            ValueError: If no species-sound detector has been loaded.
        """
        if self.cat_detector is None:
            raise ValueError("未初始化物种叫声检测器")

        # 1. Detect which species (if any) produced the sound.
        detector_result = self.cat_detector.predict(audio)
        predicted = detector_result["pred"]
        confidence = detector_result["prob"]
        # Class 0 is "none"; anything else must also clear the threshold.
        is_species_sound = predicted != 0 and confidence > self.detector_threshold
        species_label = self.species_labels[predicted]

        # 2. Classify intent only for confirmed species sounds.
        intent_result = None
        if is_species_sound and self.intent_classifier:
            intent_result = self.intent_classifier.predict(audio, species_label)

        # 3. Assemble the result.
        return {
            "species_labels": species_label,
            "is_species_sound": bool(is_species_sound),
            "confidence": float(confidence),
            "intent_result": intent_result,
        }

    @staticmethod
    def _report_live_result(result: Dict[str, Any]) -> None:
        """Print one ``analyze_audio`` result for the live-analysis loop."""
        # analyze_audio reports detection under "is_species_sound".
        if not result["is_species_sound"]:
            print(f"未检测到猫叫声。置信度: {result['confidence']:.4f}")
            return

        print(f"检测到猫叫声! 置信度: {result['confidence']:.4f}")
        intent = result["intent_result"]
        if not intent:
            return

        # NOTE(review): the classifier's result schema is not visible here.
        # The CLI analyze path reads "winner" while this path used to read
        # "class_name" and "probabilities", so both spellings are accepted
        # - confirm against DAGHMMClassifierV2.predict.
        intent_name = intent.get("winner") or intent.get("class_name")
        if intent_name:
            print(f"意图: {intent_name} (置信度: {intent['confidence']:.4f})")
        else:
            print("⚠️特征学习中。。。")

        probabilities = intent.get("probabilities")
        if probabilities:
            print("所有类别概率:")
            for cls, prob in probabilities.items():
                print(f" {cls}: {prob:.4f}")

    def start_live_analysis(self,
                            duration: float = 3.0,
                            interval: float = 1.0,
                            device: Optional[int] = None):
        """Record and analyse audio in a loop until interrupted with Ctrl+C.

        Args:
            duration: Length of each recording, in seconds.
            interval: Pause between analyses, in seconds.
            device: Recording device id, forwarded to the audio input.
        """
        print("开始实时分析,按Ctrl+C停止...")
        print(f"录音持续时间: {duration}秒,分析间隔: {interval}秒")

        try:
            while True:
                print("\n录音中...")
                audio = self.audio_input.record_audio(duration=duration, device=device)
                self._report_live_result(self.analyze_audio(audio))
                time.sleep(interval)
        except KeyboardInterrupt:
            print("\n实时分析已停止")

    def add_sample(self,
                   file_path: str,
                   label: str,
                   is_cat_sound: bool = True,
                   cat_name: Optional[str] = None) -> Dict[str, Any]:
        """Extract features from an audio file and store them as a JSON sample.

        The sample is written to ``samples/<cat_name or "default">/``.

        Args:
            file_path: Path of the audio file.
            label: Label stored with the sample.
            is_cat_sound: Whether the clip is a cat sound; only recorded in
                the sample metadata.
            cat_name: Name of the cat, used as the sub-directory name.

        Returns:
            A dict with ``sample_id`` and ``sample_path``.
        """
        print(f"添加样本: {file_path}, 标签: {label}, 是否猫叫声: {is_cat_sound}")

        audio, _sr = self.audio_input.load_from_file(file_path)
        hybrid_features = self.feature_extractor.extract_hybrid_features(audio)

        samples_dir = os.path.join("samples", cat_name if cat_name else "default")
        os.makedirs(samples_dir, exist_ok=True)

        # Ids are second-resolution timestamps; bump the id while the target
        # file exists so samples added within one second do not overwrite
        # each other.
        timestamp = int(time.time())
        sample_id = timestamp
        sample_path = os.path.join(samples_dir, f"sample_{sample_id}.json")
        while os.path.exists(sample_path):
            sample_id += 1
            sample_path = os.path.join(samples_dir, f"sample_{sample_id}.json")

        sample_data = {
            "features": np.asarray(hybrid_features).tolist(),
            "label": label,
            "is_cat_sound": is_cat_sound,
            "cat_name": cat_name,
            "file_path": file_path,
            "timestamp": timestamp,
        }
        with open(sample_path, "w", encoding="utf-8") as f:
            json.dump(sample_data, f)

        print(f"样本已保存到 {sample_path}")

        return {
            "sample_id": sample_id,
            "sample_path": sample_path,
        }

    def train_detector(self,
                       model_type: str = "svm",
                       output_path: Optional[str] = None) -> Dict[str, Any]:
        """Train the species-sound detector on the collected samples.

        Args:
            model_type: Model type, one of "svm", "rf", "nn".
            output_path: Directory to save the model in. Defaults to
                ``./models`` when omitted, so that the default constructor
                path ``./models/cat_detector_svm.pkl`` finds an SVM model.

        Returns:
            The metrics dictionary returned by the detector's ``train``.
        """
        print(f"训练物种叫声检测器,模型类型: {model_type}")

        species_sounds_audio = {
            "cat_sounds": [],
            "dog_sounds": [],
            "non_sounds": [],
        }
        collector = SampleCollector()

        # Load every collected clip of each category at 16 kHz.
        for category, clips in species_sounds_audio.items():
            target_paths = [meta["target_path"]
                            for meta in collector.metadata[category].values()]
            for target_path in target_paths:
                clips.append(librosa.load(target_path, sr=16000)[0])

        sample_counts = collector.get_sample_counts()
        print(f"猫叫声样本数量: {sample_counts['cat_sounds']}")
        print(f"狗叫声样本数量: {sample_counts['dog_sounds']}")
        print(f"非物种叫声样本数量: {sample_counts['non_sounds']}")

        detector = CatSoundDetector(model_type=model_type)
        metrics = detector.train(species_sounds_audio, validation_split=0.2)

        print("\n评估指标:")
        print(f"训练集准确率: {metrics['train_accuracy']:.4f}")
        print(f"测试集准确率: {metrics['val_accuracy']:.4f}")
        print(f"测试集精确率: {metrics['val_precision']:.4f}")
        print(f"测试集召回率: {metrics['val_recall']:.4f}")
        print(f"测试集F1得分: {metrics['val_f1']:.4f}")

        # Joining with None would raise only after training has finished,
        # losing the trained model; fall back to the default directory.
        if not output_path:
            output_path = "./models"
        os.makedirs(output_path, exist_ok=True)
        model_path = os.path.join(output_path, f"cat_detector_{model_type}.pkl")
        detector.save_model(model_path)
        print(f"模型已保存到: {model_path}")

        return metrics

    def train_intent_classifier(self,
                                samples_dir: str,
                                feature_type: str = "hybrid",
                                output_path: Optional[str] = None) -> Dict[str, Any]:
        """Train the intent classifier from a directory of labelled audio.

        Each sub-directory of ``samples_dir`` is one intent class; its name
        is the label for every ``.wav`` / ``.mp3`` file inside it (extension
        matched case-insensitively).

        Args:
            samples_dir: Root directory of the samples.
            feature_type: One of "temporal_modulation", "mfcc", "yamnet",
                "hybrid".
            output_path: Where to save the trained model; the model is not
                saved when omitted.

        Returns:
            The metrics returned by the classifier's ``fit``, or a dict with
            a NaN ``train_accuracy`` and a ``message`` when there are no
            samples or fewer than two classes.
        """
        print(f"训练意图分类器,特征类型: {feature_type}")

        audio_files = []
        labels = []
        # Sorted listings keep the sample order reproducible across runs.
        for intent_dir in sorted(os.listdir(samples_dir)):
            intent_path = os.path.join(samples_dir, intent_dir)
            if not os.path.isdir(intent_path):
                continue
            for file_name in sorted(os.listdir(intent_path)):
                if not file_name.lower().endswith((".wav", ".mp3")):
                    continue
                audio_path = os.path.join(intent_path, file_name)
                audio, _sr = librosa.load(audio_path, sr=16000)
                if audio.size > 0:  # skip empty recordings
                    audio_files.append(audio)
                    labels.append(intent_dir)
                else:
                    print(f"警告: 音频文件 {audio_path} 为空,跳过。")

        print(f"加载了 {len(audio_files)} 个样本,共 {len(set(labels))} 个意图类别")

        # A classifier needs at least two classes to train.
        if not audio_files or len(set(labels)) < 2:
            print("错误: 训练意图分类器所需样本或类别不足,跳过训练。")
            return {"train_accuracy": float("nan"), "message": "样本或类别不足"}

        classifier = DAGHMMClassifierV2(feature_type=feature_type)
        metrics = classifier.fit(audio_files, labels)

        if output_path:
            classifier.save_model(output_path)
            print(f"模型已保存到 {output_path}")

        return metrics
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="优化后的猫咪翻译器V2")
|
||||
|
||||
# 子命令
|
||||
subparsers = parser.add_subparsers(dest="command", help="命令")
|
||||
|
||||
# 分析命令
|
||||
analyze_parser = subparsers.add_parser("analyze", help="分析音频文件")
|
||||
analyze_parser.add_argument("file", help="音频文件路径")
|
||||
analyze_parser.add_argument("--detector", help="猫叫声检测器模型路径", default="./models/cat_detector_svm.pkl")
|
||||
analyze_parser.add_argument("--intent-model", help="意图分类器模型路径", default="./models")
|
||||
analyze_parser.add_argument("--feature-type", default="hybrid",
|
||||
choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
|
||||
help="特征类型")
|
||||
analyze_parser.add_argument("--threshold", type=float, default=0.5, help="猫叫声检测阈值")
|
||||
|
||||
# 实时分析命令
|
||||
live_parser = subparsers.add_parser("live", help="实时分析麦克风输入")
|
||||
live_parser.add_argument("--detector", help="猫叫声检测器模型路径", default="./models/cat_detector_svm.pkl")
|
||||
live_parser.add_argument("--intent-model", help="意图分类器模型路径", default="./models")
|
||||
live_parser.add_argument("--feature-type", default="temporal_modulation",
|
||||
choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
|
||||
help="特征类型")
|
||||
live_parser.add_argument("--threshold", type=float, default=0.5, help="猫叫声检测阈值")
|
||||
live_parser.add_argument("--duration", type=float, default=3.0, help="每次录音持续时间(秒)")
|
||||
live_parser.add_argument("--interval", type=float, default=1.0, help="分析间隔时间(秒)")
|
||||
live_parser.add_argument("--device", type=int, help="录音设备ID")
|
||||
|
||||
# 添加样本命令
|
||||
add_sample_parser = subparsers.add_parser("add-sample", help="添加训练样本")
|
||||
add_sample_parser.add_argument("file", help="音频文件路径")
|
||||
add_sample_parser.add_argument("label", help="标签")
|
||||
add_sample_parser.add_argument("--is-cat-sound", action="store_true", help="是否为猫叫声")
|
||||
add_sample_parser.add_argument("--cat", help="猫咪名称")
|
||||
|
||||
# 训练检测器命令
|
||||
train_detector_parser = subparsers.add_parser("train-detector", help="训练猫叫声检测器")
|
||||
train_detector_parser.add_argument("--model-type", default="svm", choices=["svm", "rf", "nn"], help="模型类型")
|
||||
train_detector_parser.add_argument("--output", default="./models", help="输出路径")
|
||||
|
||||
# 训练意图分类器命令
|
||||
train_intent_parser = subparsers.add_parser("train-intent", help="训练意图分类器")
|
||||
train_intent_parser.add_argument("--samples", required=True, help="样本目录")
|
||||
train_intent_parser.add_argument("--feature-type", default="hybrid",
|
||||
choices=["temporal_modulation", "mfcc", "yamnet", "hybrid"],
|
||||
help="特征类型")
|
||||
train_intent_parser.add_argument("--output", help="输出路径")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "analyze":
|
||||
translator = OptimizedCatTranslator(
|
||||
detector_model_path=args.detector,
|
||||
intent_model_path=args.intent_model,
|
||||
feature_type=args.feature_type,
|
||||
detector_threshold=args.threshold
|
||||
)
|
||||
result = translator.analyze_file(args.file)
|
||||
|
||||
# 输出结果
|
||||
if result["is_species_sound"]:
|
||||
print(f"检测到 {result['species_labels']} 叫声! 置信度: {result['confidence']:.4f}")
|
||||
if result["intent_result"]:
|
||||
intent = result["intent_result"]
|
||||
if intent['winner']:
|
||||
print(f"意图: {intent['winner']} (置信度: {intent['confidence']:.4f})")
|
||||
else:
|
||||
print("⚠️特征学习中。。。")
|
||||
print(intent)
|
||||
|
||||
else:
|
||||
print(f"未检测到物种叫声。置信度: {result['confidence']:.4f}")
|
||||
|
||||
elif args.command == "live":
|
||||
translator = OptimizedCatTranslator(
|
||||
detector_model_path=args.detector,
|
||||
intent_model_path=args.intent_model,
|
||||
feature_type=args.feature_type,
|
||||
detector_threshold=args.threshold
|
||||
)
|
||||
translator.start_live_analysis(
|
||||
duration=args.duration,
|
||||
interval=args.interval,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
elif args.command == "add-sample":
|
||||
translator = OptimizedCatTranslator()
|
||||
result = translator.add_sample(
|
||||
file_path=args.file,
|
||||
label=args.label,
|
||||
is_cat_sound=args.is_cat_sound,
|
||||
cat_name=args.cat
|
||||
)
|
||||
print(f"样本已添加,ID: {result['sample_id']}")
|
||||
|
||||
elif args.command == "train-detector":
|
||||
translator = OptimizedCatTranslator()
|
||||
metrics = translator.train_detector(
|
||||
model_type=args.model_type,
|
||||
output_path=args.output
|
||||
)
|
||||
print(f"训练完成")
|
||||
|
||||
elif args.command == "train-intent":
|
||||
translator = OptimizedCatTranslator()
|
||||
metrics = translator.train_intent_classifier(
|
||||
samples_dir=args.samples,
|
||||
feature_type=args.feature_type,
|
||||
output_path=args.output
|
||||
)
|
||||
print(f"训练完成")
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user