This commit is contained in:
Timothy J. Baek 2024-05-21 19:45:02 -07:00
parent ee4544d4f9
commit c37d3f726b
7 changed files with 89 additions and 74 deletions

main.py (115 changed lines)
View File

@@ -5,7 +5,7 @@ from fastapi.concurrency import run_in_threadpool
 from starlette.responses import StreamingResponse, Response
 from pydantic import BaseModel, ConfigDict
-from typing import List, Union, Generator
+from typing import List, Union, Generator, Iterator
 import time
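The added import is typing.Iterator, which, like the typing.Generator already used here, supports runtime isinstance checks by delegating to collections.abc. One subtlety worth keeping in mind when reading the dispatch below: every generator object is itself an Iterator, so the order of the isinstance checks decides which branch a generator-returning pipeline lands in. A quick standalone demonstration:

from typing import Generator, Iterator

def gen() -> Generator:
    yield "chunk"

g = gen()
print(isinstance(g, Iterator))              # True: generators satisfy the Iterator ABC
print(isinstance(g, Generator))             # True
print(isinstance(iter([b"x"]), Generator))  # False: a plain iterator is not a generator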
@@ -132,77 +132,88 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                 res = get_response(
                     user_message,
                     messages=form_data.messages,
-                    body=form_data.model_dump_json(),
+                    body=form_data.model_dump(),
                 )

                 print(f"stream:true:{res}")
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    print(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-
-                elif isinstance(res, Generator):
-                    for message in res:
-                        print(f"stream_content:Generator:{message}")
-                        message = stream_message_template(form_data.model, message)
-                        yield f"data: {json.dumps(message)}\n\n"
-
-                finish_message = {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion.chunk",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {},
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield f"data: [DONE]"
+                if isinstance(res, Iterator):
+                    for line in res:
+                        if line:
+                            # Decode the JSON data
+                            decoded_line = line.decode("utf-8")
+                            print(f"stream_content:Iterator:{decoded_line}")
+                            yield f"{decoded_line}\n\n"
+                else:
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data.model, res)
+                        print(f"stream_content:str:{message}")
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    elif isinstance(res, Generator):
+                        for message in res:
+                            print(f"stream_content:Generator:{message}")
+                            message = stream_message_template(form_data.model, message)
+                            yield f"data: {json.dumps(message)}\n\n"
+
+                    finish_message = {
+                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": form_data.model,
+                        "choices": [
+                            {
+                                "index": 0,
+                                "delta": {},
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield f"data: [DONE]"

             return StreamingResponse(stream_content(), media_type="text/event-stream")
         else:
             res = get_response(
                 user_message,
                 messages=form_data.messages,
-                body=form_data.model_dump_json(),
+                body=form_data.model_dump(),
             )
             print(f"stream:false:{res}")

-            message = ""
-
-            if isinstance(res, str):
-                message = res
-            elif isinstance(res, Generator):
-                for stream in res:
-                    message = f"{message}{stream}"
-
-            print(f"stream:false:{message}")
-            return {
-                "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                "object": "chat.completion",
-                "created": int(time.time()),
-                "model": form_data.model,
-                "choices": [
-                    {
-                        "index": 0,
-                        "message": {
-                            "role": "assistant",
-                            "content": message,
-                        },
-                        "logprobs": None,
-                        "finish_reason": "stop",
-                    }
-                ],
-            }
+            if isinstance(res, dict):
+                return res
+            else:
+                message = ""
+
+                if isinstance(res, str):
+                    message = res
+                elif isinstance(res, Generator):
+                    for stream in res:
+                        message = f"{message}{stream}"
+
+                print(f"stream:false:{message}")
+                return {
+                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": form_data.model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "message": {
+                                "role": "assistant",
+                                "content": message,
+                            },
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }

    return await run_in_threadpool(job)
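For context on the new Iterator branch: a pipeline can now hand back the raw line iterator of a streaming HTTP response, and stream_content forwards each non-empty line decoded but otherwise untouched, rather than wrapping it with stream_message_template. A minimal sketch of that branch with a stubbed upstream (fake_upstream is hypothetical, standing in for requests.Response.iter_lines()):

from typing import Iterator

def fake_upstream() -> Iterator[bytes]:
    # Stand-in for r.iter_lines(): lines of bytes that are already
    # SSE-formatted "data: ..." chunks from the upstream API.
    yield b'data: {"choices": [{"delta": {"content": "Hi"}}]}'
    yield b""  # blank keep-alive lines are skipped by the `if line:` guard
    yield b"data: [DONE]"

for line in fake_upstream():
    if line:
        decoded_line = line.decode("utf-8")
        print(f"{decoded_line}\n\n")  # exactly what stream_content yields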

View File

@@ -80,7 +80,7 @@ class Pipeline:
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
-    ) -> Union[str, Generator]:
+    ) -> Union[str, Generator, Iterator]:
         # This is where you can add your custom RAG pipeline.
         # Typically, you would retrieve relevant information from your knowledge base and synthesize it to generate a response.
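The same one-line widening of the return annotation repeats in each example pipeline below. In practice it means get_response may return fully materialized text, a generator of chunks it produces itself, or a pass-through iterator borrowed from an upstream streaming response. A sketch of a pipeline exercising all three shapes (illustrative only, not code from this repo; List[dict] stands in for the OpenAIChatMessage schema to keep it self-contained):

from typing import Generator, Iterator, List, Union

class ExamplePipeline:
    def get_response(
        self, user_message: str, messages: List[dict], body: dict
    ) -> Union[str, Generator, Iterator]:
        # str: the whole answer at once, for non-streaming requests.
        if not body.get("stream", False):
            return f"echo: {user_message}"

        # Generator: chunks this pipeline produces itself; an Iterator
        # return would instead be e.g. requests' r.iter_lines().
        def chunks() -> Generator:
            for word in user_message.split():
                yield word + " "

        return chunks()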

View File

@@ -71,7 +71,7 @@ class Pipeline:
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
-    ) -> Union[str, Generator]:
+    ) -> Union[str, Generator, Iterator]:
         # This is where you can add your custom RAG pipeline.
         # Typically, you would retrieve relevant information from your knowledge base and synthesize it to generate a response.

View File

@@ -31,7 +31,7 @@ class Pipeline:
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
-    ) -> Union[str, Generator]:
+    ) -> Union[str, Generator, Iterator]:
         # This is where you can add your custom RAG pipeline.
         # Typically, you would retrieve relevant information from your knowledge base and synthesize it to generate a response.

View File

@@ -26,7 +26,7 @@ class Pipeline:
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
-    ) -> Union[str, Generator]:
+    ) -> Union[str, Generator, Iterator]:
         # This is where you can add your custom RAG pipeline.
         # Typically, you would retrieve relevant information from your knowledge base and synthesize it to generate a response.

View File

@@ -1,4 +1,4 @@
-from typing import List, Union, Generator
+from typing import List, Union, Generator, Iterator
 from schemas import OpenAIChatMessage
 import requests
@@ -19,31 +19,35 @@ class Pipeline:
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
-    ) -> Union[str, Generator]:
+    ) -> Union[str, Generator, Iterator]:
         # This is where you can add your custom pipelines like RAG.'
         print(f"get_response:{__name__}")

         print(messages)
         print(user_message)

-        OPENAI_API_KEY = "your-api-key-here"
+        OPENAI_API_KEY = "your-openai-api-key-here"
         headers = {}
         headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
         headers["Content-Type"] = "application/json"

-        r = requests.request(
-            method="POST",
-            url="https://api.openai.com/v1",
-            data=body,
-            headers=headers,
-            stream=True,
-        )
-
-        r.raise_for_status()
-
-        # Check if response is SSE
-        if "text/event-stream" in r.headers.get("Content-Type", ""):
-            return r.iter_content(chunk_size=8192)
-        else:
-            response_data = r.json()
-            return f"{response_data['choices'][0]['text']}"
+        data = {**body, "model": "gpt-3.5-turbo"}
+
+        print(data)
+
+        try:
+            r = requests.post(
+                url="https://api.openai.com/v1/chat/completions",
+                json={**body, "model": "gpt-3.5-turbo"},
+                headers=headers,
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            if data["stream"]:
+                return r.iter_lines()
+            else:
+                return r.json()
+        except Exception as e:
+            return f"Error: {e}"
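Two things this rewrite changes at once: the request now targets the concrete /v1/chat/completions endpoint with a json= body rather than posting data= to the bare API root, and the non-streaming path returns r.json() as a dict, which is what the new isinstance(res, dict) short-circuit in main.py forwards to the client untouched. One caveat: data["stream"] raises KeyError when a client omits the key; a defensive variant using .get (an assumption on our part, not what the commit does) would look like:

import requests
from typing import Iterator, Union

def call_openai(api_key: str, body: dict) -> Union[dict, Iterator, str]:
    # Same flow as the pipeline above, but tolerant of a missing
    # "stream" key (hypothetical hardening, not part of the commit).
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        r = requests.post(
            url="https://api.openai.com/v1/chat/completions",
            json={**body, "model": "gpt-3.5-turbo"},
            headers=headers,
            stream=True,
        )
        r.raise_for_status()
        if body.get("stream", False):
            return r.iter_lines()  # Iterator of bytes for main.py to decode
        return r.json()            # dict, forwarded untouched by main.py
    except Exception as e:
        return f"Error: {e}"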

View File

@@ -18,7 +18,7 @@ class Pipeline:
     def get_response(
         self, user_message: str, messages: List[OpenAIChatMessage], body: dict
-    ) -> Union[str, Generator]:
+    ) -> Union[str, Generator, Iterator]:
         # This is where you can add your custom pipelines like RAG.'
         print(f"get_response:{__name__}")