openapi-servers/servers/sql/main.py
import os
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional
# --- LLM/SQL libraries ---
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.llms.openai import OpenAI
from langchain_community.utilities import SQLDatabase
from sqlalchemy.exc import SQLAlchemyError

# -- Load DB URL from environment variable --
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    raise RuntimeError("DATABASE_URL environment variable must be set.")
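# DATABASE_URL uses the standard SQLAlchemy URL format, e.g.
#   postgresql://user:password@host:5432/dbname
#   sqlite:///./example.db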
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Set this in your environment

# -------------------------------
# Pydantic models
# -------------------------------
class SQLChatInput(BaseModel):
    query: str = Field(
        ...,
        description="Your question or task in natural language (e.g. 'Show me the top 10 customers by sales.')",
    )


class SQLChatOutput(BaseModel):
    sql: str = Field(..., description="SQL that was executed")
    answer: str = Field(..., description="Answer to your query, from the database")
    raw_result: Optional[list] = Field(
        None, description="Raw result rows (list of dict/tuples)"
    )
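

# Example /chat_sql exchange (illustrative values only):
#   request:  {"query": "Show me the top 10 customers by sales."}
#   response: {"sql": "SELECT ...", "answer": "...", "raw_result": [...]}
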
# -------------------------------
# API Setup
# -------------------------------
app = FastAPI(
    title="Chat with SQL API",
    version="1.0.0",
    description=(
        "Chat in natural language with any SQL database using LLMs. "
        "Query and analyze your data conversationally!"
    ),
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
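# Note: CORS is wide open here; requests from any origin, with any method
# and headers, are accepted.
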
# -------------------------------
# LLM + SQL Chain Setup (singleton)
# -------------------------------
def get_chain():
    # Initialize the reflected SQLAlchemy database wrapper
    db = SQLDatabase.from_uri(DATABASE_URL)
    # LLM instance: using OpenAI GPT (or swap in your preferred model)
    llm = OpenAI(
        temperature=0, openai_api_key=OPENAI_API_KEY, model_name="gpt-3.5-turbo"
    )
    return SQLDatabaseChain.from_llm(
        llm, db, verbose=True, return_sql=True, return_intermediate_steps=True
    )
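

# Build the chain once at import time and reuse it for every request.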
sql_chain = get_chain()
# -------------------------------
# Schema endpoint
# -------------------------------
@app.get("/schema", summary="Get database schema overview")
def get_db_schema():
    """
    Returns the tables and columns for the currently connected database.
    """
    try:
        db = sql_chain.database
        return db.get_table_info()
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to retrieve schema info: {e}"
        )

# -------------------------------
# Chatting endpoint
# -------------------------------
@app.post(
    "/chat_sql", response_model=SQLChatOutput, summary="Chat with your SQL database"
)
def chat_sql(data: SQLChatInput):
    """
    Enter a natural language instruction/question and get an answer from your database.
    """
    try:
        # Run the chain on the natural-language query
        result = sql_chain({"query": data.query})
        # Example result shape: {'result': 'Answer in plain text', 'intermediate_steps': {'sql_cmd': sql, ...}}
        answer = result["result"]
        sql = None
        raw_result = None
        if "intermediate_steps" in result and "sql_cmd" in result["intermediate_steps"]:
            sql = result["intermediate_steps"]["sql_cmd"]
        if "intermediate_steps" in result and "result" in result["intermediate_steps"]:
            raw_result = result["intermediate_steps"]["result"]
        return SQLChatOutput(sql=sql or "", answer=answer, raw_result=raw_result)
    except SQLAlchemyError as e:
        raise HTTPException(status_code=400, detail=f"Database error: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")
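

# --- Example usage (a minimal sketch; host/port are assumptions) ---
# Run the server:
#   uvicorn main:app --host 0.0.0.0 --port 8000
# Query it:
#   curl -X POST http://localhost:8000/chat_sql \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Show me the top 10 customers by sales."}'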