From b3bb653f469a309aff31d1c5f9a469e66e71bbf9 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Tue, 21 May 2024 22:03:54 -0700
Subject: [PATCH] feat: ollama pipeline

---
 main.py                               | 16 +++++----
 pipelines/examples/ollama_pipeline.py | 52 +++++++++++++++++++++++++++
 pipelines/ollama_pipeline.py          | 52 +++++++++++++++++++++++++++
 3 files changed, 113 insertions(+), 7 deletions(-)
 create mode 100644 pipelines/examples/ollama_pipeline.py
 create mode 100644 pipelines/ollama_pipeline.py

diff --git a/main.py b/main.py
index f022817..f735241 100644
--- a/main.py
+++ b/main.py
@@ -18,6 +18,7 @@ from schemas import OpenAIChatCompletionForm
 
 import os
 import importlib.util
+import logging
 
 from concurrent.futures import ThreadPoolExecutor
 
@@ -37,7 +38,7 @@ def load_modules_from_directory(directory):
 
 for loaded_module in load_modules_from_directory("./pipelines"):
     # Do something with the loaded module
-    print("Loaded:", loaded_module.__name__)
+    logging.info("Loaded: %s", loaded_module.__name__)
 
     pipeline = loaded_module.Pipeline()
 
@@ -105,6 +106,7 @@ async def get_models():
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
+                "pipeline": True,
             }
             for pipeline in PIPELINES.values()
         ]
@@ -123,7 +125,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
         )
 
     def job():
-        print(form_data.model)
+        logging.info(form_data.model)
         get_response = app.state.PIPELINES[form_data.model]["module"].get_response
 
         if form_data.stream:
@@ -135,11 +137,11 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     body=form_data.model_dump(),
                 )
 
-                print(f"stream:true:{res}")
+                logging.info(f"stream:true:{res}")
 
                 if isinstance(res, str):
                     message = stream_message_template(form_data.model, res)
-                    print(f"stream_content:str:{message}")
+                    logging.info(f"stream_content:str:{message}")
                     yield f"data: {json.dumps(message)}\n\n"
 
                 if isinstance(res, Iterator):
@@ -149,7 +151,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                         except:
                             pass
 
-                        print(f"stream_content:Generator:{line}")
+                        logging.info(f"stream_content:Generator:{line}")
 
                         if line.startswith("data:"):
                             yield f"{line}\n\n"
@@ -183,7 +185,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                messages=form_data.messages,
                body=form_data.model_dump(),
            )
-           print(f"stream:false:{res}")
+           logging.info(f"stream:false:{res}")
 
            if isinstance(res, dict):
                return res
@@ -197,7 +199,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
 
                for stream in res:
                    message = f"{message}{stream}"
-                   print(f"stream:false:{message}")
+                   logging.info(f"stream:false:{message}")
 
                return {
                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
diff --git a/pipelines/examples/ollama_pipeline.py b/pipelines/examples/ollama_pipeline.py
new file mode 100644
index 0000000..437461d
--- /dev/null
+++ b/pipelines/examples/ollama_pipeline.py
@@ -0,0 +1,52 @@
+from typing import List, Union, Generator, Iterator
+from schemas import OpenAIChatMessage
+import requests
+
+
+class Pipeline:
+    def __init__(self):
+        # Optionally, you can set the id and name of the pipeline.
+        self.id = "ollama_pipeline"
+        self.name = "Ollama Pipeline"
+        pass
+
+    async def on_startup(self):
+        # This function is called when the server is started.
+        print(f"on_startup:{__name__}")
+        pass
+
+    async def on_shutdown(self):
+        # This function is called when the server is stopped.
+ print(f"on_shutdown:{__name__}") + pass + + def get_response( + self, user_message: str, messages: List[OpenAIChatMessage], body: dict + ) -> Union[str, Generator, Iterator]: + # This is where you can add your custom pipelines like RAG.' + print(f"get_response:{__name__}") + + OLLAMA_BASE_URL = "http://localhost:11434" + MODEL = "llama3" + + if "user" in body: + print("######################################") + print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})') + print(f"# Message: {user_message}") + print("######################################") + + try: + r = requests.post( + url=f"{OLLAMA_BASE_URL}/v1/chat/completions", + json={**body, "model": MODEL}, + stream=True, + ) + + r.raise_for_status() + + if body["stream"]: + return r.iter_lines() + else: + return r.json() + except Exception as e: + return f"Error: {e}" diff --git a/pipelines/ollama_pipeline.py b/pipelines/ollama_pipeline.py new file mode 100644 index 0000000..437461d --- /dev/null +++ b/pipelines/ollama_pipeline.py @@ -0,0 +1,52 @@ +from typing import List, Union, Generator, Iterator +from schemas import OpenAIChatMessage +import requests + + +class Pipeline: + def __init__(self): + # Optionally, you can set the id and name of the pipeline. + self.id = "ollama_pipeline" + self.name = "Ollama Pipeline" + pass + + async def on_startup(self): + # This function is called when the server is started. + print(f"on_startup:{__name__}") + pass + + async def on_shutdown(self): + # This function is called when the server is stopped. + print(f"on_shutdown:{__name__}") + pass + + def get_response( + self, user_message: str, messages: List[OpenAIChatMessage], body: dict + ) -> Union[str, Generator, Iterator]: + # This is where you can add your custom pipelines like RAG.' + print(f"get_response:{__name__}") + + OLLAMA_BASE_URL = "http://localhost:11434" + MODEL = "llama3" + + if "user" in body: + print("######################################") + print(f'# User: {body["user"]["name"]} ({body["user"]["id"]})') + print(f"# Message: {user_message}") + print("######################################") + + try: + r = requests.post( + url=f"{OLLAMA_BASE_URL}/v1/chat/completions", + json={**body, "model": MODEL}, + stream=True, + ) + + r.raise_for_status() + + if body["stream"]: + return r.iter_lines() + else: + return r.json() + except Exception as e: + return f"Error: {e}"