mirror of
https://github.com/open-webui/pipelines
synced 2025-05-12 16:40:45 +00:00
feat: lifecycle
This commit is contained in:
parent
532de7cbe3
commit
68cc31009e
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
|
.env
|
||||||
|
18
config.py
18
config.py
@ -1,2 +1,16 @@
|
|||||||
MODEL_ID = "rag-api"
|
import os
|
||||||
MODEL_NAME = "RAG Model"
|
|
||||||
|
####################################
|
||||||
|
# Load .env file
|
||||||
|
####################################
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv, find_dotenv
|
||||||
|
|
||||||
|
load_dotenv(find_dotenv("./.env"))
|
||||||
|
except ImportError:
|
||||||
|
print("dotenv not installed, skipping...")
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_ID = os.environ.get("MODEL_ID", "plugin-id")
|
||||||
|
MODEL_NAME = os.environ.get("MODEL_NAME", "Plugin Model")
|
||||||
|
18
main.py
18
main.py
@ -14,9 +14,23 @@ from utils import get_last_user_message, stream_message_template
|
|||||||
from schemas import OpenAIChatCompletionForm
|
from schemas import OpenAIChatCompletionForm
|
||||||
from config import MODEL_ID, MODEL_NAME
|
from config import MODEL_ID, MODEL_NAME
|
||||||
|
|
||||||
from pipelines.pipeline import get_response
|
from pipelines.examples.llamaindex_ollama_github_pipeline import (
|
||||||
|
get_response,
|
||||||
|
on_startup,
|
||||||
|
on_shutdown,
|
||||||
|
)
|
||||||
|
|
||||||
app = FastAPI(docs_url="/docs", redoc_url=None)
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
await on_startup()
|
||||||
|
yield
|
||||||
|
await on_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
@ -89,3 +89,13 @@ def get_response(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return response["llm"]["replies"][0]
|
return response["llm"]["replies"][0]
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup():
|
||||||
|
# This function is called when the server is started.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown():
|
||||||
|
# This function is called when the server is stopped.
|
||||||
|
pass
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from typing import List, Union, Generator
|
from typing import List, Union, Generator
|
||||||
from schemas import OpenAIChatMessage
|
from schemas import OpenAIChatMessage
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||||
from llama_index.llms.ollama import Ollama
|
from llama_index.llms.ollama import Ollama
|
||||||
@ -12,41 +14,9 @@ Settings.embed_model = OllamaEmbedding(
|
|||||||
)
|
)
|
||||||
Settings.llm = Ollama(model="llama3")
|
Settings.llm = Ollama(model="llama3")
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
github_token = os.environ.get("GITHUB_TOKEN")
|
index = None
|
||||||
owner = "open-webui"
|
documents = None
|
||||||
repo = "open-webui"
|
|
||||||
branch = "main"
|
|
||||||
|
|
||||||
github_client = GithubClient(github_token=github_token, verbose=True)
|
|
||||||
|
|
||||||
documents = GithubRepositoryReader(
|
|
||||||
github_client=github_client,
|
|
||||||
owner=owner,
|
|
||||||
repo=repo,
|
|
||||||
use_parser=False,
|
|
||||||
verbose=False,
|
|
||||||
filter_directories=(
|
|
||||||
["docs"],
|
|
||||||
GithubRepositoryReader.FilterType.INCLUDE,
|
|
||||||
),
|
|
||||||
filter_file_extensions=(
|
|
||||||
[
|
|
||||||
".png",
|
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
".gif",
|
|
||||||
".svg",
|
|
||||||
".ico",
|
|
||||||
"json",
|
|
||||||
".ipynb",
|
|
||||||
],
|
|
||||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
|
||||||
),
|
|
||||||
).load_data(branch=branch)
|
|
||||||
|
|
||||||
index = VectorStoreIndex.from_documents(documents)
|
|
||||||
|
|
||||||
|
|
||||||
def get_response(
|
def get_response(
|
||||||
@ -62,3 +32,54 @@ def get_response(
|
|||||||
response = query_engine.query(user_message)
|
response = query_engine.query(user_message)
|
||||||
|
|
||||||
return response.response_gen
|
return response.response_gen
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup():
|
||||||
|
global index, documents
|
||||||
|
|
||||||
|
github_token = os.environ.get("GITHUB_TOKEN")
|
||||||
|
owner = "open-webui"
|
||||||
|
repo = "plugin-server"
|
||||||
|
branch = "main"
|
||||||
|
|
||||||
|
github_client = GithubClient(github_token=github_token, verbose=True)
|
||||||
|
|
||||||
|
reader = GithubRepositoryReader(
|
||||||
|
github_client=github_client,
|
||||||
|
owner=owner,
|
||||||
|
repo=repo,
|
||||||
|
use_parser=False,
|
||||||
|
verbose=False,
|
||||||
|
filter_file_extensions=(
|
||||||
|
[
|
||||||
|
".png",
|
||||||
|
".jpg",
|
||||||
|
".jpeg",
|
||||||
|
".gif",
|
||||||
|
".svg",
|
||||||
|
".ico",
|
||||||
|
"json",
|
||||||
|
".ipynb",
|
||||||
|
],
|
||||||
|
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
|
reader._loop = loop
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load data from the branch
|
||||||
|
documents = await asyncio.to_thread(reader.load_data, branch=branch)
|
||||||
|
index = VectorStoreIndex.from_documents(documents)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
print(documents)
|
||||||
|
print(index)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown():
|
||||||
|
# This function is called when the pipeline is stopped.
|
||||||
|
pass
|
||||||
|
@ -30,3 +30,13 @@ def get_response(
|
|||||||
response = query_engine.query(user_message)
|
response = query_engine.query(user_message)
|
||||||
|
|
||||||
return response.response_gen
|
return response.response_gen
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup():
|
||||||
|
# This function is called when the server is started.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown():
|
||||||
|
# This function is called when the server is stopped.
|
||||||
|
pass
|
||||||
|
@ -25,3 +25,13 @@ def get_response(
|
|||||||
response = query_engine.query(user_message)
|
response = query_engine.query(user_message)
|
||||||
|
|
||||||
return response.response_gen
|
return response.response_gen
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup():
|
||||||
|
# This function is called when the server is started.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown():
|
||||||
|
# This function is called when the server is stopped.
|
||||||
|
pass
|
||||||
|
@ -11,3 +11,13 @@ def get_response(
|
|||||||
print(user_message)
|
print(user_message)
|
||||||
|
|
||||||
return f"rag response to: {user_message}"
|
return f"rag response to: {user_message}"
|
||||||
|
|
||||||
|
|
||||||
|
async def on_startup():
|
||||||
|
# This function is called when the server is started.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown():
|
||||||
|
# This function is called when the server is stopped.
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user