50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
import uvicorn
|
|
from fastapi import FastAPI
|
|
from contextlib import asynccontextmanager
|
|
import argparse
|
|
|
|
from pt_gen.core import config
|
|
from pt_gen.api import endpoints
|
|
from pt_gen.api.endpoints import get_orchestrator
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup work, serve, then run shutdown work.

    Startup: calls ``get_orchestrator()`` once so the orchestrator instance is
    created before the first request (the intent, per the original comment, is
    to also establish its Redis connection up front).

    Shutdown: awaits ``orchestrator.cache.close()`` on the same instance that
    was created at startup. The shutdown step is in a ``finally`` clause so it
    also runs when the serving phase ends with an exception or cancellation
    thrown into this context manager.

    Args:
        app: The FastAPI application; not used directly here.
    """
    # Application startup
    print("应用启动...")
    # Warm up the dependency: the first get_orchestrator() call creates the
    # instance (and, per the original design, its Redis connection).
    orchestrator = get_orchestrator()
    print("Orchestrator 实例已创建。")

    try:
        yield
    finally:
        # Application shutdown: close the connection on the instance we
        # created above, rather than asking the factory again (which would
        # close a different object if get_orchestrator() is not cached).
        print("应用关闭,关闭 Redis 连接...")
        await orchestrator.cache.close()
        print("Redis 连接已关闭。")
|
|
|
|
# ASGI application; `lifespan` performs the startup warm-up and shutdown cleanup.
app = FastAPI(lifespan=lifespan, title="PT-Gen API")

# Mount the generator endpoints under the /api prefix.
app.include_router(endpoints.router, tags=["Generator"], prefix="/api")
|
|
|
|
# Root route: answers GET / with a fixed welcome message.
# (Kept as comments rather than a docstring: FastAPI would copy a docstring
# into the OpenAPI description of this route.)
@app.get("/", tags=["Root"])
def read_root():
    greeting = "Welcome to PT-Gen API"
    # FastAPI serializes the returned dict as the JSON response body.
    return dict(message=greeting)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 解析命令行参数来获取 host 和 port
|
|
parser = argparse.ArgumentParser(description="Run PT-Gen API server.")
|
|
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind")
|
|
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
|
|
parser.add_argument("--config", type=str, default="configs/config.yaml", help="Path to the config file")
|
|
|
|
cli_args = parser.parse_args()
|
|
|
|
uvicorn.run(
|
|
"main:app",
|
|
host=cli_args.host,
|
|
port=cli_args.port,
|
|
reload=True # 开发时开启,生产环境应为 False
|
|
)
|