feat: first commit
This commit is contained in:
223
api.py
Normal file
223
api.py
Normal file
@@ -0,0 +1,223 @@
|
||||
|
||||
import uvicorn
|
||||
import os
|
||||
import tempfile
|
||||
import json
|
||||
import librosa
|
||||
from io import BytesIO
|
||||
from optimized_main import OptimizedCatTranslator
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
app = FastAPI(title="物种翻译器API服务")
|
||||
|
||||
# 全局翻译器实例
|
||||
translator = OptimizedCatTranslator(
|
||||
detector_model_path = "./models/cat_detector_svm.pkl",
|
||||
intent_model_path = "./models",
|
||||
feature_type = "hybrid",
|
||||
detector_threshold = 0.5
|
||||
)
|
||||
|
||||
# 实时分析的状态管理
|
||||
live_analysis_running = False
|
||||
|
||||
|
||||
@app.post("/analyze/audio", summary="分析原始音频数据")
|
||||
async def analyze_audio(
|
||||
audio_data: bytes = File(...),
|
||||
sr: int = Form(16000),
|
||||
):
|
||||
"""
|
||||
分析原始音频数据,返回物种叫声检测和意图分析结果
|
||||
|
||||
接口路径: `/analyze/audio`
|
||||
请求方法: POST
|
||||
|
||||
Args:
|
||||
audio_data: 原始音频字节数据(必填,通过File上传)
|
||||
sr: 音频采样率,默认16000(可选,通过Form传递)
|
||||
|
||||
Returns:
|
||||
JSONResponse: 包含物种检测和意图分析结果的JSON数据
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 400: 音频格式不支持或文件不完整(提示"音频格式不支持,请上传WAV/MP3/FLAC等常见格式...")
|
||||
- 500: 服务器内部处理异常(返回具体错误信息)
|
||||
"""
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir="./data/tmp") as temp_file:
|
||||
temp_file.write(audio_data)
|
||||
temp_file_path = temp_file.name
|
||||
# 将字节数据转换为numpy数组
|
||||
audio, _ = librosa.load(temp_file_path)
|
||||
|
||||
# 分析音频
|
||||
result = translator.analyze_audio(audio, sr)
|
||||
|
||||
# 删除临时文件
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
return JSONResponse(content=result)
|
||||
except Exception as e:
|
||||
|
||||
error_msg = str(e).lower()
|
||||
if "could not open" in error_msg or "format not recognised" in error_msg:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="音频格式不支持,请上传WAV/MP3/FLAC等常见格式,或检查文件完整性"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
|
||||
@app.post("/samples/add", summary="添加训练样本")
|
||||
async def add_sample(
|
||||
file: UploadFile = File(...),
|
||||
label: str = Form(...),
|
||||
is_cat_sound: bool = Form(True),
|
||||
cat_name: str = Form(None)
|
||||
):
|
||||
"""添加音频样本到训练集"""
|
||||
try:
|
||||
# 创建临时文件保存上传的音频
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
|
||||
temp_file.write(await file.read())
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
# 添加样本
|
||||
result = translator.add_sample(
|
||||
file_path=temp_file_path,
|
||||
label=label,
|
||||
is_cat_sound=is_cat_sound,
|
||||
cat_name=cat_name
|
||||
)
|
||||
|
||||
# 删除临时文件
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
return JSONResponse(content=result)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/train/detector", summary="训练猫叫声检测器")
|
||||
async def train_detector(
|
||||
background_tasks: BackgroundTasks,
|
||||
model_type: str = Form("svm"),
|
||||
output_path: str = Form("./models")
|
||||
):
|
||||
"""训练新的猫叫声检测器模型(后台任务)"""
|
||||
try:
|
||||
# 使用后台任务处理长时间运行的训练过程
|
||||
result = {"status": "training started", "model_type": model_type, "output_path": output_path}
|
||||
|
||||
def train_task():
|
||||
try:
|
||||
metrics = translator.train_detector(
|
||||
model_type=model_type,
|
||||
output_path=output_path
|
||||
)
|
||||
# 可以将结果保存到文件或数据库
|
||||
with open(os.path.join(output_path, "training_metrics.json"), "w") as f:
|
||||
json.dump(metrics, f)
|
||||
except Exception as e:
|
||||
print(f"Training error: {str(e)}")
|
||||
|
||||
background_tasks.add_task(train_task)
|
||||
return JSONResponse(content=result)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/train/intent", summary="训练意图分类器")
|
||||
async def train_intent_classifier(
|
||||
background_tasks: BackgroundTasks,
|
||||
samples_dir: str = Form(...),
|
||||
feature_type: str = Form("hybrid"),
|
||||
output_path: str = Form(None)
|
||||
):
|
||||
"""训练新的意图分类器模型(后台任务)"""
|
||||
try:
|
||||
# 使用后台任务处理长时间运行的训练过程
|
||||
result = {"status": "training started", "feature_type": feature_type, "samples_dir": samples_dir}
|
||||
|
||||
def train_task():
|
||||
try:
|
||||
metrics = translator.train_intent_classifier(
|
||||
samples_dir=samples_dir,
|
||||
feature_type=feature_type,
|
||||
output_path=output_path
|
||||
)
|
||||
# 可以将结果保存到文件或数据库
|
||||
if output_path:
|
||||
with open(os.path.join(output_path, "intent_training_metrics.json"), "w") as f:
|
||||
json.dump(metrics, f)
|
||||
except Exception as e:
|
||||
print(f"Intent training error: {str(e)}")
|
||||
|
||||
background_tasks.add_task(train_task)
|
||||
return JSONResponse(content=result)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/live/start", summary="开始实时分析")
|
||||
async def start_live_analysis(
|
||||
duration: float = Form(3.0),
|
||||
interval: float = Form(1.0),
|
||||
device: int = Form(None),
|
||||
detector_model_path: str = Form("./models/cat_detector_svm.pkl"),
|
||||
intent_model_path: str = Form("./models"),
|
||||
feature_type: str = Form("temporal_modulation"),
|
||||
detector_threshold: float = Form(0.5)
|
||||
):
|
||||
"""开始实时音频分析"""
|
||||
global live_analysis_running
|
||||
|
||||
if live_analysis_running:
|
||||
raise HTTPException(status_code=400, detail="实时分析已在运行中")
|
||||
|
||||
try:
|
||||
live_analysis_running = True
|
||||
|
||||
# 这里简化处理,实际应用可能需要使用WebSocket
|
||||
def generate():
|
||||
global live_analysis_running
|
||||
try:
|
||||
while live_analysis_running:
|
||||
# 录音
|
||||
audio = translator.audio_input.record_audio(duration=duration, device=device)
|
||||
|
||||
# 分析
|
||||
result = translator.analyze_audio(audio)
|
||||
|
||||
# 发送结果
|
||||
yield f"data: {json.dumps(result)}\n\n"
|
||||
|
||||
# 等待
|
||||
time.sleep(interval)
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
finally:
|
||||
live_analysis_running = False
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
except Exception as e:
|
||||
live_analysis_running = False
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/live/stop", summary="停止实时分析")
|
||||
async def stop_live_analysis():
|
||||
"""停止实时音频分析"""
|
||||
global live_analysis_running
|
||||
live_analysis_running = False
|
||||
return JSONResponse(content={"status": "live analysis stopped"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time # 导入time模块,用于实时分析中的延迟
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user