Merge branch 'open-webui:main' into mlx-manifold

commit becb4f6a75
Justin Hayes, 2024-07-03 10:54:08 -04:00, committed by GitHub

@@ -10,6 +10,7 @@ requirements: llama_index, sqlalchemy, psycopg2-binary
 
 from typing import List, Union, Generator, Iterator
 import os
 from pydantic import BaseModel
+from llama_index.llms.ollama import Ollama
 from llama_index.core.query_engine import NLSQLTableQueryEngine
 from llama_index.core import SQLDatabase, PromptTemplate
@@ -17,23 +18,43 @@ from sqlalchemy import create_engine
 
 
 class Pipeline:
+    class Valves(BaseModel):
+        DB_HOST: str
+        DB_PORT: str
+        DB_USER: str
+        DB_PASSWORD: str
+        DB_DATABASE: str
+        DB_TABLES: list[str]
+        OLLAMA_HOST: str
+        TEXT_TO_SQL_MODEL: str
+
+
+    # Update valves/environment variables based on your selected database
     def __init__(self):
-        self.PG_HOST = os.environ["PG_HOST"]
-        self.PG_PORT = os.environ["PG_PORT"]
-        self.PG_USER = os.environ["PG_USER"]
-        self.PG_PASSWORD = os.environ["PG_PASSWORD"]
-        self.PG_DB = os.environ["PG_DB"]
-        self.ollama_host = "http://host.docker.internal:11434"  # Make sure to update with the URL of your Ollama host, such as http://localhost:11434 or a remote server address
-        self.model = "phi3:medium-128k"  # Model to use for text-to-SQL generation
         self.name = "Database RAG Pipeline"
         self.engine = None
         self.nlsql_response = ""
-        self.tables = ["db_table"]  # Update to the name of the database table you want to get data from
+
+        # Initialize
+        self.valves = self.Valves(
+            **{
+                "pipelines": ["*"],                                  # Connect to all pipelines
+                "DB_HOST": os.environ["PG_HOST"],                    # Database hostname
+                "DB_PORT": os.environ["PG_PORT"],                    # Database port
+                "DB_USER": os.environ["PG_USER"],                    # User to connect to the database with
+                "DB_PASSWORD": os.environ["PG_PASSWORD"],            # Password to connect to the database with
+                "DB_DATABASE": os.environ["PG_DB"],                  # Database to select on the DB instance
+                "DB_TABLES": ["albums"],                             # Table(s) to run queries against
+                "OLLAMA_HOST": "http://host.docker.internal:11434",  # Make sure to update with the URL of your Ollama host, such as http://localhost:11434 or a remote server address
+                "TEXT_TO_SQL_MODEL": "phi3:latest"                   # Model to use for text-to-SQL generation
+            }
+        )
 
     def init_db_connection(self):
-        self.engine = create_engine(f"postgresql+psycopg2://{self.PG_USER}:{self.PG_PASSWORD}@{self.PG_HOST}:{self.PG_PORT}/{self.PG_DB}")
+        # Update your DB connection string based on selected DB engine - current connection string is for Postgres
+        self.engine = create_engine(f"postgresql+psycopg2://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
         return self.engine
 
     async def on_startup(self):
         # This function is called when the server is started.
         self.init_db_connection()
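
Note: a minimal sketch of how the new valves resolve at startup. The environment values below are placeholders for illustration and are not part of this commit:

```python
import os

# Placeholder connection details for illustration only
os.environ.setdefault("PG_HOST", "localhost")
os.environ.setdefault("PG_PORT", "5432")
os.environ.setdefault("PG_USER", "postgres")
os.environ.setdefault("PG_PASSWORD", "secret")
os.environ.setdefault("PG_DB", "chinook")

pipeline = Pipeline()
print(pipeline.valves.DB_TABLES)   # ['albums'] - the valve default from this diff
engine = pipeline.init_db_connection()
print(engine.url)                  # postgresql+psycopg2://postgres:***@localhost:5432/chinook
```

SQLAlchemy masks the password when the engine URL is printed, so the last line is safe to log.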
@@ -48,10 +69,10 @@ class Pipeline:
         # Debug logging is required to see what SQL query is generated by the LlamaIndex library; enable on Pipelines server if needed
 
         # Create database reader for Postgres
-        sql_database = SQLDatabase(self.engine, include_tables=self.tables)
+        sql_database = SQLDatabase(self.engine, include_tables=self.valves.DB_TABLES)
 
         # Set up LLM connection; uses phi3 model with 128k context limit since some queries have returned 20k+ tokens
-        llm = Ollama(model=self.model, base_url=self.ollama_host, request_timeout=180.0, context_window=30000)
+        llm = Ollama(model=self.valves.TEXT_TO_SQL_MODEL, base_url=self.valves.OLLAMA_HOST, request_timeout=180.0, context_window=30000)
 
         # Set up the custom prompt used when generating SQL queries from text
         text_to_sql_prompt = """
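
The refactored LLM hookup can be smoke-tested outside the pipeline. A sketch, assuming the valve defaults above and a reachable Ollama server (none of this is in the diff):

```python
from llama_index.llms.ollama import Ollama

# Mirrors the valve defaults; point these at your own Ollama host and model
llm = Ollama(
    model="phi3:latest",
    base_url="http://host.docker.internal:11434",
    request_timeout=180.0,
    context_window=30000,
)
print(llm.complete("Reply with OK if you are reachable.").text)
```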
@@ -78,7 +99,7 @@ class Pipeline:
         query_engine = NLSQLTableQueryEngine(
             sql_database=sql_database,
-            tables=self.tables,
+            tables=self.valves.DB_TABLES,
             llm=llm,
             embed_model="local",
             text_to_sql_prompt=text_to_sql_template,
             streaming=True
@@ -88,4 +109,3 @@ class Pipeline:
         response = query_engine.query(user_message)
 
         return response.response_gen
-
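
Because the query engine is constructed with streaming enabled, response.response_gen is a generator of streamed tokens. A hypothetical local driver, assuming the standard Pipelines pipe(user_message, model_id, messages, body) signature (the signature is not shown in this diff):

```python
import asyncio

async def main():
    pipeline = Pipeline()
    await pipeline.on_startup()  # opens the database connection

    # pipe() returns response.response_gen, yielding tokens as they arrive
    for token in pipeline.pipe(
        user_message="How many albums are in the database?",
        model_id="phi3:latest",
        messages=[],
        body={},
    ):
        print(token, end="", flush=True)

asyncio.run(main())
```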