Files
petshy/api.py
2025-10-08 20:39:09 +08:00

223 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import tempfile
import time
from io import BytesIO

import librosa
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse

from optimized_main import OptimizedCatTranslator

app = FastAPI(title="物种翻译器API服务")

# One translator shared by every endpoint; its models are loaded once, here,
# when the module is imported.
translator = OptimizedCatTranslator(
    detector_model_path="./models/cat_detector_svm.pkl",
    intent_model_path="./models",
    feature_type="hybrid",
    detector_threshold=0.5,
)

# Live-analysis state: set True by /live/start, cleared by /live/stop or when
# the streaming generator finishes.
live_analysis_running = False
@app.post("/analyze/audio", summary="分析原始音频数据")
async def analyze_audio(
    audio_data: bytes = File(...),
    sr: int = Form(16000),
):
    """Analyze raw uploaded audio bytes for species vocalizations and intent.

    Endpoint: ``POST /analyze/audio``

    Args:
        audio_data: Raw bytes of an audio file (required, sent as a file field).
        sr: Sample rate in Hz (optional form field, default 16000). The audio
            is decoded/resampled to this rate, and the same value is passed
            to the translator, so samples and rate always agree.

    Returns:
        JSONResponse: the translator's analysis result serialized as JSON.

    Raises:
        HTTPException:
            - 400 when the decoder reports an unreadable/unsupported format.
            - 500 for any other processing error (detail is the error text).
    """
    temp_file_path = None
    try:
        # NamedTemporaryFile does not create its directory; make sure it exists.
        os.makedirs("./data/tmp", exist_ok=True)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir="./data/tmp") as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name
        # Pass sr explicitly: without it librosa.load resamples to its own
        # default (22050 Hz) and the array would not match the rate reported
        # to the translator below.
        audio, _ = librosa.load(temp_file_path, sr=sr)
        result = translator.analyze_audio(audio, sr)
        return JSONResponse(content=result)
    except Exception as e:
        error_msg = str(e).lower()
        if "could not open" in error_msg or "format not recognised" in error_msg:
            raise HTTPException(
                status_code=400,
                detail="音频格式不支持请上传WAV/MP3/FLAC等常见格式或检查文件完整性"
            ) from e
        raise HTTPException(status_code=500, detail=error_msg) from e
    finally:
        # Remove the temp file on every path, including decode/analysis failures.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@app.post("/samples/add", summary="添加训练样本")
async def add_sample(
    file: UploadFile = File(...),
    label: str = Form(...),
    is_cat_sound: bool = Form(True),
    cat_name: str = Form(None)
):
    """Add an uploaded audio clip to the training set.

    The upload is written to a temporary file that keeps the original file
    extension, handed to ``translator.add_sample``, and always deleted
    afterwards.

    Args:
        file: Uploaded audio file.
        label: Label attached to the sample.
        is_cat_sound: Whether the clip is a cat vocalization (default True).
        cat_name: Optional cat name forwarded to the translator.

    Returns:
        JSONResponse: the value returned by ``translator.add_sample``, as JSON.

    Raises:
        HTTPException: 500 with the error text if anything fails.
    """
    temp_file_path = None
    try:
        # ``filename`` may be None for some clients; fall back to no suffix
        # rather than crashing inside os.path.splitext.
        suffix = os.path.splitext(file.filename or "")[1]
        content = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(content)
            temp_file_path = temp_file.name
        result = translator.add_sample(
            file_path=temp_file_path,
            label=label,
            is_cat_sound=is_cat_sound,
            cat_name=cat_name
        )
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up even when add_sample raises so failed uploads don't leak files.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@app.post("/train/detector", summary="训练猫叫声检测器")
async def train_detector(
    background_tasks: BackgroundTasks,
    model_type: str = Form("svm"),
    output_path: str = Form("./models")
):
    """Schedule training of a new cat-sound detector as a background task.

    Responds immediately with ``{"status": "training started", ...}``.
    Training runs afterwards via ``background_tasks``. On success, the metrics
    returned by ``translator.train_detector`` are written to
    ``<output_path>/training_metrics.json``. Failures cannot be reported to
    the already-answered client, so they are logged with a full traceback.

    NOTE(review): ``output_path`` comes straight from the client and decides
    where files are written on the server; restrict/validate it before
    exposing this endpoint to untrusted callers.

    Args:
        background_tasks: FastAPI-injected background task queue.
        model_type: Detector model type forwarded to the translator (default "svm").
        output_path: Directory for the trained model and metrics (default "./models").

    Returns:
        JSONResponse: the "training started" status payload.

    Raises:
        HTTPException: 500 if the task cannot be scheduled.
    """
    result = {"status": "training started", "model_type": model_type, "output_path": output_path}

    def train_task():
        try:
            metrics = translator.train_detector(
                model_type=model_type,
                output_path=output_path
            )
            # Ensure the metrics directory exists; "" means the current directory.
            if output_path:
                os.makedirs(output_path, exist_ok=True)
            with open(os.path.join(output_path, "training_metrics.json"), "w") as f:
                json.dump(metrics, f)
        except Exception as e:
            # Log with traceback instead of a bare print that loses the stack.
            logging.getLogger(__name__).exception("Training error: %s", e)

    try:
        background_tasks.add_task(train_task)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/train/intent", summary="训练意图分类器")
async def train_intent_classifier(
    background_tasks: BackgroundTasks,
    samples_dir: str = Form(...),
    feature_type: str = Form("hybrid"),
    output_path: str = Form(None)
):
    """Schedule training of a new intent classifier as a background task.

    Responds immediately with ``{"status": "training started", ...}``.
    Training runs afterwards via ``background_tasks``. When ``output_path`` is
    given, the metrics returned by ``translator.train_intent_classifier`` are
    written to ``<output_path>/intent_training_metrics.json``. Failures are
    logged with a full traceback.

    NOTE(review): ``samples_dir`` and ``output_path`` are client-supplied
    server filesystem paths; validate them before exposing this endpoint to
    untrusted callers.

    Args:
        background_tasks: FastAPI-injected background task queue.
        samples_dir: Directory of training samples (required).
        feature_type: Feature extraction type forwarded to the translator (default "hybrid").
        output_path: Optional output directory; metrics are only saved when set.

    Returns:
        JSONResponse: the "training started" status payload.

    Raises:
        HTTPException: 500 if the task cannot be scheduled.
    """
    result = {"status": "training started", "feature_type": feature_type, "samples_dir": samples_dir}

    def train_task():
        try:
            metrics = translator.train_intent_classifier(
                samples_dir=samples_dir,
                feature_type=feature_type,
                output_path=output_path
            )
            # Metrics are persisted only when the caller chose an output directory.
            if output_path:
                os.makedirs(output_path, exist_ok=True)
                with open(os.path.join(output_path, "intent_training_metrics.json"), "w") as f:
                    json.dump(metrics, f)
        except Exception as e:
            # Log with traceback instead of a bare print that loses the stack.
            logging.getLogger(__name__).exception("Intent training error: %s", e)

    try:
        background_tasks.add_task(train_task)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/live/start", summary="开始实时分析")
async def start_live_analysis(
    duration: float = Form(3.0),
    interval: float = Form(1.0),
    device: int = Form(None),
    detector_model_path: str = Form("./models/cat_detector_svm.pkl"),
    intent_model_path: str = Form("./models"),
    feature_type: str = Form("temporal_modulation"),
    detector_threshold: float = Form(0.5)
):
    """Start a live record-and-analyze loop, streamed as Server-Sent Events.

    Each iteration records ``duration`` seconds from ``device`` through
    ``translator.audio_input.record_audio``, analyzes it, emits one
    ``data: <json>`` event, then pauses ``interval`` seconds. The loop ends
    when /live/stop clears the running flag or on the first error. An error
    is emitted as a ``{"error": ...}`` event before the stream closes.

    NOTE(review): ``detector_model_path``, ``intent_model_path``,
    ``feature_type`` and ``detector_threshold`` are accepted but currently
    unused; the loop always uses the module-level ``translator``. They are
    kept for API compatibility.

    Raises:
        HTTPException: 400 if a live session is already running; 500 if the
            stream cannot be set up.
    """
    global live_analysis_running, _live_session_token
    if live_analysis_running:
        raise HTTPException(status_code=400, detail="实时分析已在运行中")
    # A negative pause would make time.sleep raise ValueError and kill the stream.
    pause = max(0.0, interval)
    # Unique token for this session. After a stop followed by a quick restart,
    # a stale generator from the old session must neither resume (the shared
    # flag is True again) nor clear the new session's flag when it exits.
    session = object()
    try:
        live_analysis_running = True
        _live_session_token = session

        def generate():
            global live_analysis_running
            try:
                while live_analysis_running and _live_session_token is session:
                    audio = translator.audio_input.record_audio(duration=duration, device=device)
                    result = translator.analyze_audio(audio)
                    yield f"data: {json.dumps(result)}\n\n"
                    time.sleep(pause)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                # Only the current session may reset the shared flag.
                if _live_session_token is session:
                    live_analysis_running = False

        return StreamingResponse(generate(), media_type="text/event-stream")
    except Exception as e:
        live_analysis_running = False
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/live/stop", summary="停止实时分析")
async def stop_live_analysis():
    """Ask the live-analysis stream to stop.

    Clears the module-level ``live_analysis_running`` flag; the streaming loop
    checks that flag before each recording and exits once it is False. The
    response is always a success payload, even if nothing was running.
    """
    global live_analysis_running
    live_analysis_running = False
    payload = {"status": "live analysis stopped"}
    return JSONResponse(content=payload)
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    # `time` is no longer imported here. A __main__-only import left it unbound
    # when the app was served via `uvicorn api:app`, which broke the live stream.
    uvicorn.run(app, host="0.0.0.0", port=8000)