open-webui/backend/apps/openai/main.py
2024-01-20 04:17:06 +05:30

172 lines
5.8 KiB
Python

from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
# FastAPI application that serves the OpenAI-related endpoints in this module.
app = FastAPI()

# CORS policy: accept any origin, HTTP method and header, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Upstream connection settings live on app.state so the /url/update and
# /key/update endpoints can replace them at runtime; the initial values
# are the ones imported from config.
app.state.OPENAI_API_KEY = OPENAI_API_KEY
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
class UrlUpdateForm(BaseModel):
    """Request body accepted by the ``/url/update`` endpoint."""

    # New value to store in app.state.OPENAI_API_BASE_URL.
    url: str
class KeyUpdateForm(BaseModel):
    """Request body accepted by the ``/key/update`` endpoint."""

    # New value to store in app.state.OPENAI_API_KEY.
    key: str
@app.get("/url")
async def get_openai_url(user=Depends(get_current_user)):
    """Return the OpenAI API base URL currently held in ``app.state``.

    Raises:
        HTTPException: 401 when the caller is missing or is not an admin.
    """
    # Guard clause: only admins may read the configuration.
    if not user or user.role != "admin":
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.post("/url/update")
async def update_openai_url(
    form_data: UrlUpdateForm,
    user=Depends(get_current_user),
):
    """Replace the OpenAI API base URL in ``app.state`` and echo it back.

    Args:
        form_data: Body carrying the new base URL in ``url``.
        user: Caller resolved by ``get_current_user``.

    Raises:
        HTTPException: 401 when the caller is missing or is not an admin.
    """
    # Guard clause: only admins may change the configuration.
    if not user or user.role != "admin":
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    app.state.OPENAI_API_BASE_URL = form_data.url
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
    """Return the OpenAI API key currently held in ``app.state``.

    Raises:
        HTTPException: 401 when the caller is missing or is not an admin.
    """
    # Guard clause: only admins may read the stored key.
    if not user or user.role != "admin":
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.post("/key/update")
async def update_openai_key(
    form_data: KeyUpdateForm,
    user=Depends(get_current_user),
):
    """Replace the OpenAI API key in ``app.state`` and echo it back.

    Args:
        form_data: Body carrying the new API key in ``key``.
        user: Caller resolved by ``get_current_user``.

    Raises:
        HTTPException: 401 when the caller is missing or is not an admin.
    """
    # Guard clause: only admins may change the stored key.
    if not user or user.role != "admin":
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    app.state.OPENAI_API_KEY = form_data.key
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward a request to the configured OpenAI-compatible API.

    The incoming method and body are relayed to
    ``{app.state.OPENAI_API_BASE_URL}/{path}`` with the server-side API key
    attached as a bearer token. ``text/event-stream`` replies are streamed
    back to the caller; any other reply is parsed as JSON and returned.

    Args:
        path: Path below the upstream base URL to call.
        request: Incoming request whose method and body are forwarded.
        user: Caller resolved by ``get_current_user``.

    Returns:
        A ``StreamingResponse`` for server-sent-event replies, otherwise the
        decoded JSON payload of the upstream response.

    Raises:
        HTTPException: 401 when the caller's role is not allowed or no API
            key is configured; 400 when a non-empty body is not valid UTF-8
            or not valid JSON; otherwise the upstream error status, or 500
            when the upstream could not be reached or its reply could not
            be processed.
    """
    if not user or user.role not in ["user", "admin"]:
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if app.state.OPENAI_API_KEY == "":
        raise HTTPException(
            status_code=401,
            detail=ERROR_MESSAGES.API_KEY_NOT_FOUND,
        )

    target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
    # Log the destination only: the API key is a secret and must not be
    # written to the process output.
    print(target_url)

    try:
        body_str = (await request.body()).decode("utf-8")
    except UnicodeDecodeError as e:
        print("Error decoding request body:", e)
        raise HTTPException(status_code=400, detail="Invalid request body") from e

    # An empty body (e.g. a plain GET) is forwarded untouched.
    body_json = body_str
    if body_str:
        try:
            body_dict = json.loads(body_str)
        except json.JSONDecodeError as e:
            print("Error loading request body into a dictionary:", e)
            raise HTTPException(
                status_code=400, detail="Invalid JSON in request body"
            ) from e

        # Workaround for "gpt-4-vision-preview": force "max_tokens" to 10000
        # until the upstream issue with this model is fixed. Only JSON
        # objects can carry a "model" key, so arrays/scalars pass through.
        if (
            isinstance(body_dict, dict)
            and body_dict.get("model") == "gpt-4-vision-preview"
        ):
            body_dict["max_tokens"] = 10000
            print("Modified body_dict:", body_dict)

        body_json = json.dumps(body_dict)

    headers = {
        "Authorization": f"Bearer {app.state.OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }

    # Bound before the try so the error handler can tell "no response at all"
    # (connection failure) apart from "upstream answered with an error".
    r = None
    try:
        # NOTE(review): requests is synchronous, so this call blocks inside
        # an async endpoint; left as-is to keep behaviour unchanged.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body_json,
            headers=headers,
            stream=True,
        )
        r.raise_for_status()

        # Server-sent events are relayed chunk by chunk.
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )

        response_data = r.json()

        # For the model listing of an "openai" base URL, keep only entries
        # whose id contains "gpt".
        if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
            response_data["data"] = [
                model for model in response_data["data"] if "gpt" in model["id"]
            ]

        return response_data
    except Exception as e:
        print(e)
        error_detail = "Ollama WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except Exception:
                # Body unreadable or not JSON: fall back to the exception text.
                error_detail = f"External: {e}"

        # Compare against None explicitly: a Response is falsy for 4xx/5xx.
        # A non-error upstream status here means local processing failed, so
        # report 500 rather than re-raising with a success code.
        if r is not None and r.status_code >= 400:
            status_code = r.status_code
        else:
            status_code = 500

        raise HTTPException(status_code=status_code, detail=error_detail)