diff --git a/.gitignore b/.gitignore
index ed8ebf5..112f331 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-__pycache__
\ No newline at end of file
+__pycache__
+.env
diff --git a/config.py b/config.py
index 11fa8ad..c74f8a2 100644
--- a/config.py
+++ b/config.py
@@ -1,2 +1,16 @@
-MODEL_ID = "rag-api"
-MODEL_NAME = "RAG Model"
+import os
+
+####################################
+# Load .env file
+####################################
+
+try:
+    from dotenv import load_dotenv, find_dotenv
+
+    load_dotenv(find_dotenv("./.env"))
+except ImportError:
+    print("dotenv not installed, skipping...")
+
+
+MODEL_ID = os.environ.get("MODEL_ID", "plugin-id")
+MODEL_NAME = os.environ.get("MODEL_NAME", "Plugin Model")
diff --git a/main.py b/main.py
index f9e7ab1..b7eda2f 100644
--- a/main.py
+++ b/main.py
@@ -14,9 +14,23 @@
 from utils import get_last_user_message, stream_message_template
 from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
 
-from pipelines.pipeline import get_response
+from pipelines.examples.llamaindex_ollama_github_pipeline import (
+    get_response,
+    on_startup,
+    on_shutdown,
+)
 
-app = FastAPI(docs_url="/docs", redoc_url=None)
+from contextlib import asynccontextmanager
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    await on_startup()
+    yield
+    await on_shutdown()
+
+
+app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
 
 origins = ["*"]
diff --git a/pipelines/examples/haystack_pipeline.py b/pipelines/examples/haystack_pipeline.py
index 52ef777..5a87486 100644
--- a/pipelines/examples/haystack_pipeline.py
+++ b/pipelines/examples/haystack_pipeline.py
@@ -89,3 +89,13 @@ def get_response(
     )
 
     return response["llm"]["replies"][0]
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    pass
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
diff --git a/pipelines/examples/llamaindex_ollama_github_pipeline.py b/pipelines/examples/llamaindex_ollama_github_pipeline.py
index 7780bf2..d343080 100644
--- a/pipelines/examples/llamaindex_ollama_github_pipeline.py
+++ b/pipelines/examples/llamaindex_ollama_github_pipeline.py
@@ -1,5 +1,7 @@
 from typing import List, Union, Generator
 from schemas import OpenAIChatMessage
+import os
+import asyncio
 
 from llama_index.embeddings.ollama import OllamaEmbedding
 from llama_index.llms.ollama import Ollama
@@ -12,41 +14,9 @@ Settings.embed_model = OllamaEmbedding(
 )
 Settings.llm = Ollama(model="llama3")
 
 
-import os
-github_token = os.environ.get("GITHUB_TOKEN")
-owner = "open-webui"
-repo = "open-webui"
-branch = "main"
-
-github_client = GithubClient(github_token=github_token, verbose=True)
-
-documents = GithubRepositoryReader(
-    github_client=github_client,
-    owner=owner,
-    repo=repo,
-    use_parser=False,
-    verbose=False,
-    filter_directories=(
-        ["docs"],
-        GithubRepositoryReader.FilterType.INCLUDE,
-    ),
-    filter_file_extensions=(
-        [
-            ".png",
-            ".jpg",
-            ".jpeg",
-            ".gif",
-            ".svg",
-            ".ico",
-            "json",
-            ".ipynb",
-        ],
-        GithubRepositoryReader.FilterType.EXCLUDE,
-    ),
-).load_data(branch=branch)
-
-index = VectorStoreIndex.from_documents(documents)
+index = None
+documents = None
 
 
 def get_response(
@@ -62,3 +32,54 @@
     response = query_engine.query(user_message)
 
     return response.response_gen
+
+
+async def on_startup():
+    global index, documents
+
+    github_token = os.environ.get("GITHUB_TOKEN")
+    owner = "open-webui"
+    repo = "plugin-server"
+    branch = "main"
+
+    github_client = GithubClient(github_token=github_token, verbose=True)
+
+    reader = GithubRepositoryReader(
+        github_client=github_client,
+        owner=owner,
+        repo=repo,
+        use_parser=False,
+        verbose=False,
+        filter_file_extensions=(
+            [
+                ".png",
+                ".jpg",
+                ".jpeg",
+                ".gif",
+                ".svg",
+                ".ico",
+                "json",
+                ".ipynb",
+            ],
+            GithubRepositoryReader.FilterType.EXCLUDE,
+        ),
+    )
+
+    loop = asyncio.new_event_loop()
+
+    reader._loop = loop
+
+    try:
+        # Load data from the branch
+        documents = await asyncio.to_thread(reader.load_data, branch=branch)
+        index = VectorStoreIndex.from_documents(documents)
+    finally:
+        loop.close()
+
+    print(documents)
+    print(index)
+
+
+async def on_shutdown():
+    # This function is called when the pipeline is stopped.
+    pass
diff --git a/pipelines/examples/llamaindex_ollama_pipeline.py b/pipelines/examples/llamaindex_ollama_pipeline.py
index 93c1725..c20f41c 100644
--- a/pipelines/examples/llamaindex_ollama_pipeline.py
+++ b/pipelines/examples/llamaindex_ollama_pipeline.py
@@ -30,3 +30,13 @@ def get_response(
     response = query_engine.query(user_message)
 
     return response.response_gen
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    pass
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
diff --git a/pipelines/examples/llamaindex_pipeline.py b/pipelines/examples/llamaindex_pipeline.py
index de68ec2..e709bd0 100644
--- a/pipelines/examples/llamaindex_pipeline.py
+++ b/pipelines/examples/llamaindex_pipeline.py
@@ -25,3 +25,13 @@
     response = query_engine.query(user_message)
 
     return response.response_gen
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    pass
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
diff --git a/pipelines/pipeline.py b/pipelines/pipeline.py
index aeb2a85..daa701f 100644
--- a/pipelines/pipeline.py
+++ b/pipelines/pipeline.py
@@ -11,3 +11,13 @@ def get_response(
     print(user_message)
 
     return f"rag response to: {user_message}"
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    pass
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
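For reference, a pipeline module only needs to expose the three callables that main.py now imports and runs through the FastAPI lifespan: get_response, on_startup, and on_shutdown. The sketch below is illustrative only and is not part of the patch; the module name pipelines/my_pipeline.py and the exact get_response signature (the messages parameter and its typing) are assumptions based on the imports shown in the example pipelines.

# pipelines/my_pipeline.py -- illustrative sketch, not part of the patch above.
# Implements the interface this patch introduces: get_response plus the async
# on_startup/on_shutdown hooks that main.py awaits via the FastAPI lifespan.
from typing import List, Union, Generator

from schemas import OpenAIChatMessage  # same schema module used by the example pipelines


def get_response(
    user_message: str, messages: List[OpenAIChatMessage]
) -> Union[str, Generator]:
    # Return a plain string, or a generator to stream tokens back to the client.
    return f"custom pipeline response to: {user_message}"


async def on_startup():
    # Runs once when the server starts (e.g. load documents or build an index).
    pass


async def on_shutdown():
    # Runs once when the server stops (e.g. release clients or connections).
    pass

To try it, point the pipeline import in main.py at this module instead of pipelines.examples.llamaindex_ollama_github_pipeline, and (per the config.py change) set MODEL_ID and MODEL_NAME, plus GITHUB_TOKEN for the GitHub example, in a local .env file, which the updated .gitignore now keeps out of version control.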