mirror of
https://github.com/open-webui/pipelines
synced 2025-05-11 08:01:08 +00:00
Merge branch 'open-webui:main' into mlx-manifold
This commit is contained in:
commit
becb4f6a75
@ -10,6 +10,7 @@ requirements: llama_index, sqlalchemy, psycopg2-binary
|
||||
|
||||
from typing import List, Union, Generator, Iterator
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.core.query_engine import NLSQLTableQueryEngine
|
||||
from llama_index.core import SQLDatabase, PromptTemplate
|
||||
@ -17,23 +18,43 @@ from sqlalchemy import create_engine
|
||||
|
||||
|
||||
class Pipeline:
|
||||
class Valves(BaseModel):
|
||||
DB_HOST: str
|
||||
DB_PORT: str
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
DB_DATABASE: str
|
||||
DB_TABLES: list[str]
|
||||
OLLAMA_HOST: str
|
||||
TEXT_TO_SQL_MODEL: str
|
||||
|
||||
|
||||
# Update valves/ environment variables based on your selected database
|
||||
def __init__(self):
|
||||
self.PG_HOST = os.environ["PG_HOST"]
|
||||
self.PG_PORT = os.environ["PG_PORT"]
|
||||
self.PG_USER = os.environ["PG_USER"]
|
||||
self.PG_PASSWORD = os.environ["PG_PASSWORD"]
|
||||
self.PG_DB = os.environ["PG_DB"]
|
||||
self.ollama_host = "http://host.docker.internal:11434" # Make sure to update with the URL of your Ollama host, such at http://localhost:11434 or remote server address
|
||||
self.model = "phi3:medium-128k" # Model to use for text-to-SQL generation
|
||||
self.name = "Database RAG Pipeline"
|
||||
self.engine = None
|
||||
self.nlsql_response = ""
|
||||
self.tables = ["db_table"] # Update to the name of the database table you want to get data from
|
||||
|
||||
# Initialize
|
||||
self.valves = self.Valves(
|
||||
**{
|
||||
"pipelines": ["*"], # Connect to all pipelines
|
||||
"DB_HOST": os.environ["PG_HOST"], # Database hostname
|
||||
"DB_PORT": os.environ["PG_PORT"], # Database port
|
||||
"DB_USER": os.environ["PG_USER"], # User to connect to the database with
|
||||
"DB_PASSWORD": os.environ["PG_PASSWORD"], # Password to connect to the database with
|
||||
"DB_DATABASE": os.environ["PG_DB"], # Database to select on the DB instance
|
||||
"DB_TABLES": ["albums"], # Table(s) to run queries against
|
||||
"OLLAMA_HOST": "http://host.docker.internal:11434", # Make sure to update with the URL of your Ollama host, such as http://localhost:11434 or remote server address
|
||||
"TEXT_TO_SQL_MODEL": "phi3:latest" # Model to use for text-to-SQL generation
|
||||
}
|
||||
)
|
||||
|
||||
def init_db_connection(self):
|
||||
self.engine = create_engine(f"postgresql+psycopg2://{self.PG_USER}:{self.PG_PASSWORD}@{self.PG_HOST}:{self.PG_PORT}/{self.PG_DB}")
|
||||
# Update your DB connection string based on selected DB engine - current connection string is for Postgres
|
||||
self.engine = create_engine(f"postgresql+psycopg2://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
|
||||
return self.engine
|
||||
|
||||
|
||||
async def on_startup(self):
|
||||
# This function is called when the server is started.
|
||||
self.init_db_connection()
|
||||
@ -48,10 +69,10 @@ class Pipeline:
|
||||
# Debug logging is required to see what SQL query is generated by the LlamaIndex library; enable on Pipelines server if needed
|
||||
|
||||
# Create database reader for Postgres
|
||||
sql_database = SQLDatabase(self.engine, include_tables=self.tables)
|
||||
sql_database = SQLDatabase(self.engine, include_tables=self.valves.DB_TABLES)
|
||||
|
||||
# Set up LLM connection; uses phi3 model with 128k context limit since some queries have returned 20k+ tokens
|
||||
llm = Ollama(model=self.model, base_url=self.ollama_host, request_timeout=180.0, context_window=30000)
|
||||
llm = Ollama(model=self.valves.TEXT_TO_SQL_MODEL, base_url=self.valves.OLLAMA_HOST, request_timeout=180.0, context_window=30000)
|
||||
|
||||
# Set up the custom prompt used when generating SQL queries from text
|
||||
text_to_sql_prompt = """
|
||||
@ -78,7 +99,7 @@ class Pipeline:
|
||||
|
||||
query_engine = NLSQLTableQueryEngine(
|
||||
sql_database=sql_database,
|
||||
tables=self.tables,
|
||||
tables=self.valves.DB_TABLES,
|
||||
llm=llm,
|
||||
embed_model="local",
|
||||
text_to_sql_prompt=text_to_sql_template,
|
||||
@ -88,4 +109,3 @@ class Pipeline:
|
||||
response = query_engine.query(user_message)
|
||||
|
||||
return response.response_gen
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user