This commit is contained in:
Timothy J. Baek 2024-05-21 18:18:02 -07:00
parent f508caaea7
commit 24e02a9017
2 changed files with 6 additions and 4 deletions

10
main.py
View File

@ -64,6 +64,8 @@ async def lifespan(app: FastAPI):
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan) app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
app.state.PIPELINES = PIPELINES
origins = ["*"] origins = ["*"]
@ -112,14 +114,14 @@ async def get_models():
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm): async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
user_message = get_last_user_message(form_data.messages) user_message = get_last_user_message(form_data.messages)
if form_data.model not in PIPELINES: if form_data.model not in app.state.PIPELINES:
return HTTPException( return HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {form_data.model} not found", detail=f"Model {form_data.model} not found",
) )
async def job(): def job():
get_response = PIPELINES[form_data.model]["module"].get_response get_response = app.state.PIPELINES[form_data.model]["module"].get_response
if form_data.stream: if form_data.stream:
@ -184,7 +186,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
], ],
} }
return await job() return await run_in_threadpool(job)
@app.get("/") @app.get("/")