"""FastAPI service exposing the species (cat-sound) translator.

Endpoints:
    POST /analyze/audio  - analyze raw audio bytes (detection + intent)
    POST /samples/add    - add a labeled training sample
    POST /train/detector - (background) retrain the cat-sound detector
    POST /train/intent   - (background) retrain the intent classifier
    POST /live/start     - start a live-analysis SSE stream
    POST /live/stop      - stop the live-analysis stream
"""
import os
import json
import tempfile
import time  # moved to module level: /live/start uses time.sleep even when the
             # app is launched by an external `uvicorn` command (not __main__)

import uvicorn
import librosa
from io import BytesIO  # NOTE(review): unused here, kept to avoid breaking importers
from optimized_main import OptimizedCatTranslator
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse

app = FastAPI(title="物种翻译器API服务")

# Global translator instance shared by all endpoints.
translator = OptimizedCatTranslator(
    detector_model_path="./models/cat_detector_svm.pkl",
    intent_model_path="./models",
    feature_type="hybrid",
    detector_threshold=0.5,
)

# State flag for the live-analysis SSE stream (single-process only; a
# multi-worker deployment would need shared state instead).
live_analysis_running = False


@app.post("/analyze/audio", summary="分析原始音频数据")
async def analyze_audio(
    audio_data: bytes = File(...),
    sr: int = Form(16000),
):
    """Analyze raw audio bytes; return species detection and intent results.

    Args:
        audio_data: Raw audio file bytes (required, uploaded via File).
        sr: Sample rate the audio should be decoded at (default 16000, via Form).

    Returns:
        JSONResponse with the detection/intent analysis result.

    Raises:
        HTTPException:
            - 400: unsupported audio format or truncated file.
            - 500: any other internal processing error.
    """
    temp_file_path = None
    try:
        # The temp dir may not exist on a fresh deployment; NamedTemporaryFile
        # would raise FileNotFoundError without this.
        os.makedirs("./tmp", exist_ok=True)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir="./tmp") as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        # Decode at the caller-declared rate so the array and the `sr` passed
        # to the analyzer agree (previously librosa resampled to its 22050
        # default while `sr` was forwarded unchanged).
        audio, _ = librosa.load(temp_file_path, sr=sr)

        result = translator.analyze_audio(audio, sr)
        return JSONResponse(content=result)
    except Exception as e:
        error_msg = str(e).lower()
        # Map decoder failures to a client error instead of a 500.
        if "could not open" in error_msg or "format not recognised" in error_msg:
            raise HTTPException(
                status_code=400,
                detail="音频格式不支持,请上传WAV/MP3/FLAC等常见格式,或检查文件完整性"
            )
        raise HTTPException(status_code=500, detail=error_msg)
    finally:
        # Always remove the temp file, including on analysis failure.
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


@app.post("/samples/add", summary="添加训练样本")
async def add_sample(
    file: UploadFile = File(...),
    label: str = Form(...),
    is_cat_sound: bool = Form(True),
    cat_name: str = Form(None)
):
    """Add an uploaded audio sample to the training set.

    Args:
        file: Uploaded audio file.
        label: Intent/class label for the sample.
        is_cat_sound: Whether the sample is a cat sound (default True).
        cat_name: Optional individual-cat identifier.

    Returns:
        JSONResponse with the translator's add-sample result.

    Raises:
        HTTPException: 500 on any processing error.
    """
    temp_file_path = None
    try:
        # Preserve the original extension so downstream decoding picks the
        # right format.
        suffix = os.path.splitext(file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name

        result = translator.add_sample(
            file_path=temp_file_path,
            label=label,
            is_cat_sound=is_cat_sound,
            cat_name=cat_name
        )
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp file, including when add_sample raises.
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


@app.post("/train/detector", summary="训练猫叫声检测器")
async def train_detector(
    background_tasks: BackgroundTasks,
    model_type: str = Form("svm"),
    output_path: str = Form("./models")
):
    """Train a new cat-sound detector model as a background task.

    Args:
        background_tasks: FastAPI background-task queue.
        model_type: Detector model type (default "svm").
        output_path: Directory for the model and its metrics file.

    Returns:
        JSONResponse acknowledging that training has started.

    Raises:
        HTTPException: 500 if scheduling the task fails.
    """
    try:
        result = {
            "status": "training started",
            "model_type": model_type,
            "output_path": output_path,
        }

        def train_task():
            # Runs after the response is sent; errors cannot reach the client,
            # so they are reported out-of-band.
            try:
                metrics = translator.train_detector(
                    model_type=model_type,
                    output_path=output_path
                )
                # Persist metrics so they can be inspected later.
                with open(os.path.join(output_path, "training_metrics.json"), "w") as f:
                    json.dump(metrics, f)
            except Exception as e:
                print(f"Training error: {str(e)}")

        background_tasks.add_task(train_task)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/train/intent", summary="训练意图分类器")
async def train_intent_classifier(
    background_tasks: BackgroundTasks,
    samples_dir: str = Form(...),
    feature_type: str = Form("hybrid"),
    output_path: str = Form(None)
):
    """Train a new intent-classifier model as a background task.

    Args:
        background_tasks: FastAPI background-task queue.
        samples_dir: Directory containing the training samples (required).
        feature_type: Feature extraction type (default "hybrid").
        output_path: Optional directory for the model and metrics file.

    Returns:
        JSONResponse acknowledging that training has started.

    Raises:
        HTTPException: 500 if scheduling the task fails.
    """
    try:
        result = {
            "status": "training started",
            "feature_type": feature_type,
            "samples_dir": samples_dir,
        }

        def train_task():
            try:
                metrics = translator.train_intent_classifier(
                    samples_dir=samples_dir,
                    feature_type=feature_type,
                    output_path=output_path
                )
                # Metrics are only persisted when an output path was given.
                if output_path:
                    with open(os.path.join(output_path, "intent_training_metrics.json"), "w") as f:
                        json.dump(metrics, f)
            except Exception as e:
                print(f"Intent training error: {str(e)}")

        background_tasks.add_task(train_task)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/live/start", summary="开始实时分析")
async def start_live_analysis(
    duration: float = Form(3.0),
    interval: float = Form(1.0),
    device: int = Form(None),
    detector_model_path: str = Form("./models/cat_detector_svm.pkl"),
    intent_model_path: str = Form("./models"),
    feature_type: str = Form("temporal_modulation"),
    detector_threshold: float = Form(0.5)
):
    """Start live audio analysis and stream results as server-sent events.

    NOTE(review): the model-path/feature/threshold parameters are accepted but
    never applied — the shared `translator` instance is used as configured at
    startup. Kept for interface compatibility; confirm intended behavior.

    Args:
        duration: Seconds of audio recorded per analysis window.
        interval: Seconds to wait between windows.
        device: Optional audio input device index.

    Returns:
        StreamingResponse emitting `text/event-stream` data frames.

    Raises:
        HTTPException:
            - 400: live analysis is already running.
            - 500: failure while setting up the stream.
    """
    global live_analysis_running
    if live_analysis_running:
        raise HTTPException(status_code=400, detail="实时分析已在运行中")

    try:
        live_analysis_running = True

        # Simplified polling loop; a production version would likely use a
        # WebSocket instead of SSE.
        def generate():
            global live_analysis_running
            try:
                while live_analysis_running:
                    # Record one window of audio.
                    audio = translator.audio_input.record_audio(duration=duration, device=device)
                    # Analyze it.
                    result = translator.analyze_audio(audio)
                    # Emit one SSE data frame.
                    yield f"data: {json.dumps(result)}\n\n"
                    # Pause before the next window.
                    time.sleep(interval)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                live_analysis_running = False

        return StreamingResponse(generate(), media_type="text/event-stream")
    except Exception as e:
        live_analysis_running = False
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/live/stop", summary="停止实时分析")
async def stop_live_analysis():
    """Stop the live audio analysis stream by clearing the running flag."""
    global live_analysis_running
    live_analysis_running = False
    return JSONResponse(content={"status": "live analysis stopped"})


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)