refac: concurrency

This commit is contained in:
Timothy J. Baek 2024-05-21 17:29:31 -07:00
parent eaa4112f46
commit 6b4fba3309

110
main.py
View File

@ -1,5 +1,7 @@
from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
@ -116,65 +118,81 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
detail=f"Model {form_data.model} not found",
)
get_response = PIPELINES[form_data.model]["module"].get_response
def job():
if form_data.stream:
get_response = PIPELINES[form_data.model]["module"].get_response
def stream_content():
if form_data.stream:
res = get_response(user_message, messages=form_data.messages)
def stream_content():
res = get_response(user_message, messages=form_data.messages)
if isinstance(res, str):
message = stream_message_template(res)
yield f"data: {json.dumps(message)}\n\n"
elif isinstance(res, Generator):
for message in res:
message = stream_message_template(message)
if isinstance(res, str):
message = stream_message_template(res)
yield f"data: {json.dumps(message)}\n\n"
finish_message = {
elif isinstance(res, Generator):
for message in res:
print(message)
message = stream_message_template(message)
yield f"data: {json.dumps(message)}\n\n"
finish_message = {
"id": f"rag-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = get_response(user_message, messages=form_data.messages)
message = ""
if isinstance(res, str):
message = res
elif isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"rag-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"object": "chat.completion",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [
{"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = get_response(user_message, messages=form_data.messages)
message = ""
if isinstance(res, str):
message = res
elif isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"rag-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
try:
return await run_in_threadpool(job)
except Exception as e:
print(e)
raise HTTPException(
status_code=500,
detail="{e}",
)
@app.get("/")