feat: ollama pipeline

Timothy J. Baek 2024-05-21 22:03:54 -07:00
parent f1bcd5be0f
commit b3bb653f46
3 changed files with 113 additions and 7 deletions

main.py

@@ -18,6 +18,7 @@ from schemas import OpenAIChatCompletionForm
import os
import importlib.util
import logging
+from concurrent.futures import ThreadPoolExecutor
@@ -37,7 +38,7 @@ def load_modules_from_directory(directory):
for loaded_module in load_modules_from_directory("./pipelines"):
# Do something with the loaded module
print("Loaded:", loaded_module.__name__)
logging.info("Loaded:", loaded_module.__name__)
pipeline = loaded_module.Pipeline()
@@ -105,6 +106,7 @@ async def get_models():
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
"pipeline": True,
}
for pipeline in PIPELINES.values()
]
@@ -123,7 +125,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
)
def job():
-print(form_data.model)
+logging.info(form_data.model)
get_response = app.state.PIPELINES[form_data.model]["module"].get_response
if form_data.stream:
@@ -135,11 +137,11 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
body=form_data.model_dump(),
)
print(f"stream:true:{res}")
logging.info(f"stream:true:{res}")
if isinstance(res, str):
message = stream_message_template(form_data.model, res)
print(f"stream_content:str:{message}")
logging.info(f"stream_content:str:{message}")
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
@@ -149,7 +151,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
except:
pass
print(f"stream_content:Generator:{line}")
logging.info(f"stream_content:Generator:{line}")
if line.startswith("data:"):
yield f"{line}\n\n"
@@ -183,7 +185,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
messages=form_data.messages,
body=form_data.model_dump(),
)
print(f"stream:false:{res}")
logging.info(f"stream:false:{res}")
if isinstance(res, dict):
return res
@@ -197,7 +199,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
for stream in res:
message = f"{message}{stream}"
print(f"stream:false:{message}")
logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",


@@ -0,0 +1,52 @@
from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
import requests


class Pipeline:
    def __init__(self):
        # Optionally, you can set the id and name of the pipeline.
        self.id = "ollama_pipeline"
        self.name = "Ollama Pipeline"
        pass

    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")
        pass

    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")
        pass

    def get_response(
        self, user_message: str, messages: List[OpenAIChatMessage], body: dict
    ) -> Union[str, Generator, Iterator]:
        # This is where you can add your custom pipelines like RAG.
        print(f"get_response:{__name__}")

        OLLAMA_BASE_URL = "http://localhost:11434"
        MODEL = "llama3"

        if "user" in body:
            print("######################################")
            print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
            print(f"# Message: {user_message}")
            print("######################################")

        try:
            r = requests.post(
                url=f"{OLLAMA_BASE_URL}/v1/chat/completions",
                json={**body, "model": MODEL},
                stream=True,
            )

            r.raise_for_status()

            if body["stream"]:
                return r.iter_lines()
            else:
                return r.json()
        except Exception as e:
            return f"Error: {e}"


@@ -0,0 +1,52 @@
from typing import List, Union, Generator, Iterator
from schemas import OpenAIChatMessage
import requests


class Pipeline:
    def __init__(self):
        # Optionally, you can set the id and name of the pipeline.
        self.id = "ollama_pipeline"
        self.name = "Ollama Pipeline"
        pass

    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")
        pass

    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")
        pass

    def get_response(
        self, user_message: str, messages: List[OpenAIChatMessage], body: dict
    ) -> Union[str, Generator, Iterator]:
        # This is where you can add your custom pipelines like RAG.
        print(f"get_response:{__name__}")

        OLLAMA_BASE_URL = "http://localhost:11434"
        MODEL = "llama3"

        if "user" in body:
            print("######################################")
            print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})')
            print(f"# Message: {user_message}")
            print("######################################")

        try:
            r = requests.post(
                url=f"{OLLAMA_BASE_URL}/v1/chat/completions",
                json={**body, "model": MODEL},
                stream=True,
            )

            r.raise_for_status()

            if body["stream"]:
                return r.iter_lines()
            else:
                return r.json()
        except Exception as e:
            return f"Error: {e}"