223 lines
7.5 KiB
Python
223 lines
7.5 KiB
Python
|
||
import uvicorn
|
||
import os
|
||
import tempfile
|
||
import json
|
||
import librosa
|
||
from io import BytesIO
|
||
from optimized_main import OptimizedCatTranslator
|
||
|
||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
app = FastAPI(title="物种翻译器API服务")

# Module-wide translator instance shared by every endpoint below.
translator = OptimizedCatTranslator(
    detector_model_path="./models/cat_detector_svm.pkl",
    intent_model_path="./models",
    feature_type="hybrid",
    detector_threshold=0.5,
)

# State flag for live analysis: True while a live-analysis stream is running.
live_analysis_running = False
|
||
|
||
|
||
@app.post("/analyze/audio", summary="分析原始音频数据")
async def analyze_audio(
    audio_data: bytes = File(...),
    sr: int = Form(16000),
):
    """
    Analyze raw audio bytes and return the detection / intent-analysis result.

    Endpoint: `POST /analyze/audio`

    Args:
        audio_data: Raw bytes of the audio file (required, uploaded via File).
        sr: Sample rate in Hz, default 16000 (optional, sent as a form field).
            The upload is decoded at this rate so the samples handed to the
            translator match the rate it is told about.

    Returns:
        JSONResponse: the value returned by `translator.analyze_audio`,
        serialized as JSON.

    Raises:
        HTTPException:
            - 400: the audio could not be opened or its format was not
              recognised (unsupported format or truncated file).
            - 500: any other processing error; the detail is the lower-cased
              exception message.
    """
    temp_file_path = None
    try:
        # The scratch directory may not exist on a fresh deployment.
        os.makedirs("./tmp", exist_ok=True)

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir="./tmp") as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        # Decode the file into a numpy array. Passing sr explicitly avoids
        # librosa's default resampling to 22050 Hz, which would disagree with
        # the rate given to the translator below.
        audio, _ = librosa.load(temp_file_path, sr=sr)

        # Run detection and intent analysis.
        result = translator.analyze_audio(audio, sr)

        return JSONResponse(content=result)
    except Exception as e:
        error_msg = str(e).lower()
        if "could not open" in error_msg or "format not recognised" in error_msg:
            raise HTTPException(
                status_code=400,
                detail="音频格式不支持,请上传WAV/MP3/FLAC等常见格式,或检查文件完整性"
            ) from e

        raise HTTPException(status_code=500, detail=error_msg) from e
    finally:
        # Always remove the temporary file, including on failure paths.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
|
||
|
||
|
||
@app.post("/samples/add", summary="添加训练样本")
async def add_sample(
    file: UploadFile = File(...),
    label: str = Form(...),
    is_cat_sound: bool = Form(True),
    cat_name: str = Form(None)
):
    """
    Add an uploaded audio sample to the training set.

    The upload is written to a temporary file (keeping the original file
    extension when one is available), passed to `translator.add_sample`, and
    the temporary file is removed afterwards whether or not the call succeeds.

    Args:
        file: Uploaded audio file (required).
        label: Label for the sample (required form field).
        is_cat_sound: Form flag forwarded to the translator, default True.
        cat_name: Optional form field forwarded to the translator.

    Returns:
        JSONResponse: the value returned by `translator.add_sample`.

    Raises:
        HTTPException: 500 with the exception message on any failure.
    """
    temp_file_path = None
    try:
        # UploadFile.filename can be None; fall back to no suffix in that case.
        suffix = os.path.splitext(file.filename or "")[1]

        # Persist the upload so the translator can read it from a path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name

        result = translator.add_sample(
            file_path=temp_file_path,
            label=label,
            is_cat_sound=is_cat_sound,
            cat_name=cat_name
        )

        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary file, including on failure paths.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
|
||
|
||
|
||
@app.post("/train/detector", summary="训练猫叫声检测器")
async def train_detector(
    background_tasks: BackgroundTasks,
    model_type: str = Form("svm"),
    output_path: str = Form("./models")
):
    """
    Schedule training of a new cat-sound detector model as a background task.

    The response is returned immediately; training runs afterwards. On
    success the metrics returned by `translator.train_detector` are written
    to `training_metrics.json` inside `output_path`. Errors raised during
    training are printed, not reported to the client.

    Args:
        background_tasks: FastAPI background-task queue.
        model_type: Model type forwarded to the translator, default "svm".
        output_path: Directory forwarded to the translator and used for the
            metrics file, default "./models".

    Returns:
        JSONResponse: a "training started" acknowledgement echoing the inputs.

    Raises:
        HTTPException: 500 if scheduling the task fails.
    """

    def _run_training():
        # Long-running work executed after the response has been sent.
        try:
            metrics = translator.train_detector(
                model_type=model_type,
                output_path=output_path,
            )
            metrics_file = os.path.join(output_path, "training_metrics.json")
            with open(metrics_file, "w") as fh:
                json.dump(metrics, fh)
        except Exception as exc:
            print(f"Training error: {str(exc)}")

    try:
        acknowledgement = {
            "status": "training started",
            "model_type": model_type,
            "output_path": output_path,
        }
        background_tasks.add_task(_run_training)
        return JSONResponse(content=acknowledgement)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||
|
||
|
||
@app.post("/train/intent", summary="训练意图分类器")
async def train_intent_classifier(
    background_tasks: BackgroundTasks,
    samples_dir: str = Form(...),
    feature_type: str = Form("hybrid"),
    output_path: str = Form(None)
):
    """
    Schedule training of a new intent classifier as a background task.

    The response is returned immediately; training runs afterwards. When
    `output_path` is provided, the metrics returned by
    `translator.train_intent_classifier` are written to
    `intent_training_metrics.json` inside it. Errors raised during training
    are printed, not reported to the client.

    Args:
        background_tasks: FastAPI background-task queue.
        samples_dir: Samples directory forwarded to the translator (required).
        feature_type: Feature type forwarded to the translator, default "hybrid".
        output_path: Optional output directory forwarded to the translator.

    Returns:
        JSONResponse: a "training started" acknowledgement echoing the inputs.

    Raises:
        HTTPException: 500 if scheduling the task fails.
    """

    def _run_training():
        # Long-running work executed after the response has been sent.
        try:
            metrics = translator.train_intent_classifier(
                samples_dir=samples_dir,
                feature_type=feature_type,
                output_path=output_path,
            )
            if not output_path:
                return
            metrics_file = os.path.join(output_path, "intent_training_metrics.json")
            with open(metrics_file, "w") as fh:
                json.dump(metrics, fh)
        except Exception as exc:
            print(f"Intent training error: {str(exc)}")

    try:
        acknowledgement = {
            "status": "training started",
            "feature_type": feature_type,
            "samples_dir": samples_dir,
        }
        background_tasks.add_task(_run_training)
        return JSONResponse(content=acknowledgement)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||
|
||
|
||
@app.post("/live/start", summary="开始实时分析")
async def start_live_analysis(
    duration: float = Form(3.0),
    interval: float = Form(1.0),
    device: int = Form(None),
    detector_model_path: str = Form("./models/cat_detector_svm.pkl"),
    intent_model_path: str = Form("./models"),
    feature_type: str = Form("temporal_modulation"),
    detector_threshold: float = Form(0.5)
):
    """
    Start live audio analysis and stream results as server-sent events.

    While the module-level `live_analysis_running` flag is set, the stream
    repeatedly records `duration` seconds of audio, analyzes it with the
    shared translator, emits the result as a `data:` event, then sleeps for
    `interval` seconds. The flag is cleared when the stream ends for any
    reason, or when `/live/stop` is called.

    Args:
        duration: Length of each recording, forwarded to `record_audio`.
        interval: Pause in seconds between analysis rounds.
        device: Optional input device, forwarded to `record_audio`.
        detector_model_path: Accepted for API compatibility; not used here.
        intent_model_path: Accepted for API compatibility; not used here.
        feature_type: Accepted for API compatibility; not used here.
        detector_threshold: Accepted for API compatibility; not used here.

    Returns:
        StreamingResponse: a `text/event-stream` of JSON-encoded results; an
        error during streaming is emitted as a final `{"error": ...}` event.

    Raises:
        HTTPException:
            - 400: live analysis is already running.
            - 500: the stream could not be set up.
    """
    # Imported locally because the module only imports `time` under the
    # `__main__` guard, which never runs when the app is served as
    # `uvicorn module:app`; without this the stream fails with NameError.
    import time

    global live_analysis_running

    if live_analysis_running:
        raise HTTPException(status_code=400, detail="实时分析已在运行中")

    try:
        live_analysis_running = True

        # Simplified approach; a real deployment may want WebSocket instead.
        def generate():
            global live_analysis_running
            try:
                while live_analysis_running:
                    # Record
                    audio = translator.audio_input.record_audio(duration=duration, device=device)

                    # Analyze
                    result = translator.analyze_audio(audio)

                    # Emit the result as one server-sent event
                    yield f"data: {json.dumps(result)}\n\n"

                    # Wait before the next round
                    time.sleep(interval)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                # Release the flag however the stream ended.
                live_analysis_running = False

        return StreamingResponse(generate(), media_type="text/event-stream")
    except Exception as e:
        live_analysis_running = False
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.post("/live/stop", summary="停止实时分析")
async def stop_live_analysis():
    """
    Stop live audio analysis.

    Clears the module-level `live_analysis_running` flag, which the streaming
    loop checks before each round, and returns a fixed acknowledgement.
    """
    global live_analysis_running

    acknowledgement = {"status": "live analysis stopped"}
    live_analysis_running = False
    return JSONResponse(content=acknowledgement)
|
||
|
||
|
||
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    # Import the time module, used for the delay between live-analysis rounds.
    # NOTE(review): this binds `time` as a module global only when the file is
    # run as a script; it is not imported when the module is loaded by an
    # external ASGI server.
    import time

    uvicorn.run(app, host="0.0.0.0", port=8000)