feat: first commit

This commit is contained in:
2025-10-08 20:39:09 +08:00
commit 80f0e7f8d7
82 changed files with 12216 additions and 0 deletions

223
api.py Normal file
View File

@@ -0,0 +1,223 @@
import uvicorn
import os
import tempfile
import json
import librosa
from io import BytesIO
from optimized_main import OptimizedCatTranslator
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
app = FastAPI(title="物种翻译器API服务")
# 全局翻译器实例
translator = OptimizedCatTranslator(
detector_model_path = "./models/cat_detector_svm.pkl",
intent_model_path = "./models",
feature_type = "hybrid",
detector_threshold = 0.5
)
# 实时分析的状态管理
live_analysis_running = False
@app.post("/analyze/audio", summary="分析原始音频数据")
async def analyze_audio(
audio_data: bytes = File(...),
sr: int = Form(16000),
):
"""
分析原始音频数据,返回物种叫声检测和意图分析结果
接口路径: `/analyze/audio`
请求方法: POST
Args:
audio_data: 原始音频字节数据必填通过File上传
sr: 音频采样率默认16000可选通过Form传递
Returns:
JSONResponse: 包含物种检测和意图分析结果的JSON数据
Raises:
HTTPException:
- 400: 音频格式不支持或文件不完整(提示"音频格式不支持请上传WAV/MP3/FLAC等常见格式..."
- 500: 服务器内部处理异常(返回具体错误信息)
"""
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir="./data/tmp") as temp_file:
temp_file.write(audio_data)
temp_file_path = temp_file.name
# 将字节数据转换为numpy数组
audio, _ = librosa.load(temp_file_path)
# 分析音频
result = translator.analyze_audio(audio, sr)
# 删除临时文件
os.unlink(temp_file_path)
return JSONResponse(content=result)
except Exception as e:
error_msg = str(e).lower()
if "could not open" in error_msg or "format not recognised" in error_msg:
raise HTTPException(
status_code=400,
detail="音频格式不支持请上传WAV/MP3/FLAC等常见格式或检查文件完整性"
)
raise HTTPException(status_code=500, detail=error_msg)
@app.post("/samples/add", summary="添加训练样本")
async def add_sample(
file: UploadFile = File(...),
label: str = Form(...),
is_cat_sound: bool = Form(True),
cat_name: str = Form(None)
):
"""添加音频样本到训练集"""
try:
# 创建临时文件保存上传的音频
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
temp_file.write(await file.read())
temp_file_path = temp_file.name
# 添加样本
result = translator.add_sample(
file_path=temp_file_path,
label=label,
is_cat_sound=is_cat_sound,
cat_name=cat_name
)
# 删除临时文件
os.unlink(temp_file_path)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/train/detector", summary="训练猫叫声检测器")
async def train_detector(
background_tasks: BackgroundTasks,
model_type: str = Form("svm"),
output_path: str = Form("./models")
):
"""训练新的猫叫声检测器模型(后台任务)"""
try:
# 使用后台任务处理长时间运行的训练过程
result = {"status": "training started", "model_type": model_type, "output_path": output_path}
def train_task():
try:
metrics = translator.train_detector(
model_type=model_type,
output_path=output_path
)
# 可以将结果保存到文件或数据库
with open(os.path.join(output_path, "training_metrics.json"), "w") as f:
json.dump(metrics, f)
except Exception as e:
print(f"Training error: {str(e)}")
background_tasks.add_task(train_task)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/train/intent", summary="训练意图分类器")
async def train_intent_classifier(
background_tasks: BackgroundTasks,
samples_dir: str = Form(...),
feature_type: str = Form("hybrid"),
output_path: str = Form(None)
):
"""训练新的意图分类器模型(后台任务)"""
try:
# 使用后台任务处理长时间运行的训练过程
result = {"status": "training started", "feature_type": feature_type, "samples_dir": samples_dir}
def train_task():
try:
metrics = translator.train_intent_classifier(
samples_dir=samples_dir,
feature_type=feature_type,
output_path=output_path
)
# 可以将结果保存到文件或数据库
if output_path:
with open(os.path.join(output_path, "intent_training_metrics.json"), "w") as f:
json.dump(metrics, f)
except Exception as e:
print(f"Intent training error: {str(e)}")
background_tasks.add_task(train_task)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/live/start", summary="开始实时分析")
async def start_live_analysis(
duration: float = Form(3.0),
interval: float = Form(1.0),
device: int = Form(None),
detector_model_path: str = Form("./models/cat_detector_svm.pkl"),
intent_model_path: str = Form("./models"),
feature_type: str = Form("temporal_modulation"),
detector_threshold: float = Form(0.5)
):
"""开始实时音频分析"""
global live_analysis_running
if live_analysis_running:
raise HTTPException(status_code=400, detail="实时分析已在运行中")
try:
live_analysis_running = True
# 这里简化处理实际应用可能需要使用WebSocket
def generate():
global live_analysis_running
try:
while live_analysis_running:
# 录音
audio = translator.audio_input.record_audio(duration=duration, device=device)
# 分析
result = translator.analyze_audio(audio)
# 发送结果
yield f"data: {json.dumps(result)}\n\n"
# 等待
time.sleep(interval)
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally:
live_analysis_running = False
return StreamingResponse(generate(), media_type="text/event-stream")
except Exception as e:
live_analysis_running = False
raise HTTPException(status_code=500, detail=str(e))
@app.post("/live/stop", summary="停止实时分析")
async def stop_live_analysis():
"""停止实时音频分析"""
global live_analysis_running
live_analysis_running = False
return JSONResponse(content={"status": "live analysis stopped"})
if __name__ == "__main__":
import time # 导入time模块用于实时分析中的延迟
uvicorn.run(app, host="0.0.0.0", port=8000)