diff --git a/main.py b/main.py
index 9e21e5d..43597e2 100644
--- a/main.py
+++ b/main.py
@@ -109,7 +109,7 @@ async def get_models():
 
 @app.post("/chat/completions")
 @app.post("/v1/chat/completions")
-def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
+async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(form_data.messages)
 
     if form_data.model not in PIPELINES:
@@ -118,7 +118,7 @@ def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             detail=f"Model {form_data.model} not found",
         )
 
-    def job():
+    async def job():
         get_response = PIPELINES[form_data.model]["module"].get_response
 
         if form_data.stream:
@@ -184,7 +184,7 @@ def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             ],
         }
 
-    return job()
+    return await job()
 
 
 @app.get("/")
diff --git a/pipelines/examples/haystack_pipeline.py b/pipelines/examples/haystack_pipeline.py
index b5dc31f..ff76e4e 100644
--- a/pipelines/examples/haystack_pipeline.py
+++ b/pipelines/examples/haystack_pipeline.py
@@ -2,7 +2,7 @@ from typing import List, Union, Generator
 from schemas import OpenAIChatMessage
 import os
 
-global basic_rag_pipeline
+basic_rag_pipeline = None
 
 
 def get_response(
@@ -23,6 +23,7 @@ def get_response(
 
 
 async def on_startup():
+    global basic_rag_pipeline
     os.environ["OPENAI_API_KEY"] = "your_openai_api_key_here"
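
For reference, a minimal, self-contained sketch of the async pattern main.py adopts: the endpoint and the nested job() are both coroutines, and the endpoint returns await job() so FastAPI receives the resolved value rather than an un-awaited coroutine object. The app, route, and form fields below are illustrative stand-ins, not the real main.py definitions.

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class CompletionForm(BaseModel):
    # Illustrative fields only; the real form is OpenAIChatCompletionForm.
    model: str
    prompt: str


@app.post("/demo/completions")
async def demo_completion(form_data: CompletionForm):
    # The nested worker mirrors job() in main.py: declared async so it can
    # await pipeline code, and awaited before returning.
    async def job():
        return {"model": form_data.model, "echo": form_data.prompt}

    # Returning job() without await would hand FastAPI a coroutine object,
    # which it cannot serialize into a JSON response.
    return await job()

And a sketch of the module-global initialization pattern the haystack example moves to: the name is bound to None at import time, and the global declaration lives inside the coroutine that assigns it (a bare global statement at module level has no effect). build_pipeline is a hypothetical stand-in for the Haystack setup that the real example performs in on_startup().

import asyncio

# Bound at import time so other functions can check it before startup runs.
basic_rag_pipeline = None


def build_pipeline():
    # Hypothetical stand-in for the Haystack pipeline construction that the
    # real example performs inside on_startup().
    return object()


async def on_startup():
    # Declaring global inside the assigning function is what rebinds the
    # module-level name; a module-level "global" statement does nothing.
    global basic_rag_pipeline
    basic_rag_pipeline = build_pipeline()


if __name__ == "__main__":
    asyncio.run(on_startup())
    assert basic_rag_pipeline is not None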