From 9e58c7150919528ef6615232ae900f9ef9f9fc10 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 28 May 2024 16:07:21 -0700 Subject: [PATCH] refac --- .../examples/litellm_manifold_pipeline.py | 24 ++++++++-------- .../examples/ollama_manifold_pipeline.py | 25 ++++++++++------- pipelines/ollama_manifold_pipeline.py | 28 +++++++++++-------- 3 files changed, 44 insertions(+), 33 deletions(-) diff --git a/pipelines/examples/litellm_manifold_pipeline.py b/pipelines/examples/litellm_manifold_pipeline.py index 2edce36..9e9c2ed 100644 --- a/pipelines/examples/litellm_manifold_pipeline.py +++ b/pipelines/examples/litellm_manifold_pipeline.py @@ -28,6 +28,17 @@ class Pipeline: self.pipelines = [] pass + async def on_startup(self): + # This function is called when the server is started or after valves are updated. + print(f"on_startup:{__name__}") + self.pipelines = self.get_litellm_models() + pass + + async def on_shutdown(self): + # This function is called when the server is stopped or before valves are updated. + print(f"on_shutdown:{__name__}") + pass + def get_litellm_models(self): if self.valves.LITELLM_BASE_URL: try: @@ -43,22 +54,11 @@ class Pipeline: except Exception as e: print(f"Error: {e}") return [ - {"id": "litellm", "name": "LiteLLM: Please configure LiteLLM URL"}, + {"id": "litellm", "name": "Please configure LiteLLM URL"}, ] else: return [] - async def on_startup(self): - # This function is called when the server is started or after valves are updated. - print(f"on_startup:{__name__}") - self.pipelines = self.get_litellm_models() - pass - - async def on_shutdown(self): - # This function is called when the server is stopped or before valves are updated. - print(f"on_shutdown:{__name__}") - pass - def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: diff --git a/pipelines/examples/ollama_manifold_pipeline.py b/pipelines/examples/ollama_manifold_pipeline.py index e4262a0..a90b82f 100644 --- a/pipelines/examples/ollama_manifold_pipeline.py +++ b/pipelines/examples/ollama_manifold_pipeline.py @@ -1,5 +1,6 @@ from typing import List, Union, Generator, Iterator from schemas import OpenAIChatMessage +from pydantic import BaseModel import requests @@ -19,21 +20,17 @@ class Pipeline: # Optionally, you can set the name of the manifold pipeline. self.name = "Ollama: " - self.OLLAMA_BASE_URL = "http://localhost:11434" - self.pipelines = self.get_ollama_models() + class Valves(BaseModel): + OLLAMA_BASE_URL: str + + self.valves = Valves(**{"OLLAMA_BASE_URL": "http://localhost:11434"}) + self.pipelines = [] pass - def get_ollama_models(self): - r = requests.get(f"{self.OLLAMA_BASE_URL}/api/tags") - models = r.json() - - return [ - {"id": model["model"], "name": model["name"]} for model in models["models"] - ] - async def on_startup(self): # This function is called when the server is started or after valves are updated. print(f"on_startup:{__name__}") + self.pipelines = self.get_ollama_models() pass async def on_shutdown(self): @@ -41,6 +38,14 @@ class Pipeline: print(f"on_shutdown:{__name__}") pass + def get_ollama_models(self): + r = requests.get(f"{self.valves.OLLAMA_BASE_URL}/api/tags") + models = r.json() + + return [ + {"id": model["model"], "name": model["name"]} for model in models["models"] + ] + def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: diff --git a/pipelines/ollama_manifold_pipeline.py b/pipelines/ollama_manifold_pipeline.py index ce47e40..a90b82f 100644 --- a/pipelines/ollama_manifold_pipeline.py +++ b/pipelines/ollama_manifold_pipeline.py @@ -1,5 +1,6 @@ from typing import List, Union, Generator, Iterator from schemas import OpenAIChatMessage +from pydantic import BaseModel import requests @@ -10,29 +11,26 @@ class Pipeline: # Manifold pipelines can have multiple pipelines. self.type = "manifold" + # Optionally, you can set the id and name of the pipeline. # Assign a unique identifier to the pipeline. # The identifier must be unique across all pipelines. # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. self.id = "ollama_manifold" # Optionally, you can set the name of the manifold pipeline. - self.name = "Manifold: " + self.name = "Ollama: " - self.OLLAMA_BASE_URL = "http://localhost:11434" - self.pipelines = self.get_ollama_models() + class Valves(BaseModel): + OLLAMA_BASE_URL: str + + self.valves = Valves(**{"OLLAMA_BASE_URL": "http://localhost:11434"}) + self.pipelines = [] pass - def get_ollama_models(self): - r = requests.get(f"{self.OLLAMA_BASE_URL}/api/tags") - models = r.json() - - return [ - {"id": model["model"], "name": model["name"]} for model in models["models"] - ] - async def on_startup(self): # This function is called when the server is started or after valves are updated. print(f"on_startup:{__name__}") + self.pipelines = self.get_ollama_models() pass async def on_shutdown(self): @@ -40,6 +38,14 @@ class Pipeline: print(f"on_shutdown:{__name__}") pass + def get_ollama_models(self): + r = requests.get(f"{self.valves.OLLAMA_BASE_URL}/api/tags") + models = r.json() + + return [ + {"id": model["model"], "name": model["name"]} for model in models["models"] + ] + def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: