mirror of
https://github.com/open-webui/openapi-servers
synced 2025-06-26 18:17:04 +00:00
feat: sql server
This commit is contained in:
parent
c6a791d1e1
commit
2cc73881eb
121
servers/sql/main.py
Normal file
121
servers/sql/main.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import os
|
||||||
|
from fastapi import FastAPI, HTTPException, Query
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# --- LLM/SQL libraries ---
|
||||||
|
from langchain_experimental.sql import SQLDatabaseChain
|
||||||
|
from langchain_community.llms.openai import OpenAI
|
||||||
|
from langchain_community.utilities import SQLDatabase
|
||||||
|
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
# -- Load DB URL from environment variable --
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||||
|
if not DATABASE_URL:
|
||||||
|
raise RuntimeError("DATABASE_URL environment variable must be set.")
|
||||||
|
|
||||||
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # Set this in your environment
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Pydantic models
|
||||||
|
# -------------------------------
|
||||||
|
class SQLChatInput(BaseModel):
|
||||||
|
query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Your question or task in natural language (e.g. 'Show me the top 10 customers by sales.')",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLChatOutput(BaseModel):
|
||||||
|
sql: str = Field(..., description="SQL that was executed")
|
||||||
|
answer: str = Field(..., description="Answer to your query, from the database")
|
||||||
|
raw_result: Optional[list] = Field(
|
||||||
|
None, description="Raw result rows (list of dict/tuples)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# API Setup
|
||||||
|
# -------------------------------
|
||||||
|
app = FastAPI(
|
||||||
|
title="Chat with SQL API",
|
||||||
|
version="1.0.0",
|
||||||
|
description=(
|
||||||
|
"Chat in natural language with any SQL database using LLMs. "
|
||||||
|
"Query and analyze your data conversationally!"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# LLM + SQL Chain Setup (singleton)
|
||||||
|
# -------------------------------
|
||||||
|
def get_chain():
|
||||||
|
# Initiate reflected SQLAlchemy DB
|
||||||
|
db = SQLDatabase.from_uri(DATABASE_URL)
|
||||||
|
# LLM instance: using OpenAI GPT (or swap for your preferred)
|
||||||
|
llm = OpenAI(
|
||||||
|
temperature=0, openai_api_key=OPENAI_API_KEY, model_name="gpt-3.5-turbo"
|
||||||
|
)
|
||||||
|
return SQLDatabaseChain.from_llm(
|
||||||
|
llm, db, verbose=True, return_sql=True, return_intermediate_steps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
sql_chain = get_chain()
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Schema endpoint
|
||||||
|
# -------------------------------
|
||||||
|
@app.get("/schema", summary="Get database schema overview")
|
||||||
|
def get_db_schema():
|
||||||
|
"""
|
||||||
|
Returns the tables and columns for the currently connected database.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db = sql_chain.database
|
||||||
|
return db.get_table_info()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to retrieve schema info: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Chatting endpoint
|
||||||
|
# -------------------------------
|
||||||
|
@app.post(
|
||||||
|
"/chat_sql", response_model=SQLChatOutput, summary="Chat with your SQL database"
|
||||||
|
)
|
||||||
|
def chat_sql(data: SQLChatInput):
|
||||||
|
"""
|
||||||
|
Enter a natural language instruction/question, get answer from your database.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Run chain
|
||||||
|
result = sql_chain({"query": data.query})
|
||||||
|
# result example: {'result': 'Answer in plain text', 'intermediate_steps': {'sql_cmd': sql, ...}}
|
||||||
|
answer = result["result"]
|
||||||
|
sql = None
|
||||||
|
raw_result = None
|
||||||
|
if "intermediate_steps" in result and "sql_cmd" in result["intermediate_steps"]:
|
||||||
|
sql = result["intermediate_steps"]["sql_cmd"]
|
||||||
|
if "intermediate_steps" in result and "result" in result["intermediate_steps"]:
|
||||||
|
raw_result = result["intermediate_steps"]["result"]
|
||||||
|
return SQLChatOutput(sql=sql or "", answer=answer, raw_result=raw_result)
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Database error: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Error: {e}")
|
11
servers/sql/requirements.txt
Normal file
11
servers/sql/requirements.txt
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn[standard]
|
||||||
|
pydantic
|
||||||
|
python-multipart
|
||||||
|
|
||||||
|
langchain
|
||||||
|
langchain_community
|
||||||
|
langchain-openai
|
||||||
|
langchain-experimental
|
||||||
|
sentence_transformers
|
||||||
|
sqlalchemy
|
Loading…
Reference in New Issue
Block a user