mirror of
https://github.com/open-webui/open-webui
synced 2024-11-17 22:12:51 +00:00
refac: openai
This commit is contained in:
parent
8b6f422d45
commit
c44fc82ecd
@ -345,24 +345,17 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
@app.post("/chat/completions")
|
||||
@app.post("/chat/completions/{url_idx}")
|
||||
async def generate_chat_completion(
|
||||
form_data: dict,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
idx = 0
|
||||
payload = {**form_data}
|
||||
|
||||
body = await request.body()
|
||||
# TODO: Remove below after gpt-4-vision fix from Open AI
|
||||
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
|
||||
|
||||
payload = None
|
||||
|
||||
try:
|
||||
if "chat/completions" in path:
|
||||
body = body.decode("utf-8")
|
||||
body = json.loads(body)
|
||||
|
||||
payload = {**body}
|
||||
|
||||
model_id = body.get("model")
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
@ -374,17 +367,13 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if model_info.params:
|
||||
if model_info.params.get("temperature", None) is not None:
|
||||
payload["temperature"] = float(
|
||||
model_info.params.get("temperature")
|
||||
)
|
||||
payload["temperature"] = float(model_info.params.get("temperature"))
|
||||
|
||||
if model_info.params.get("top_p", None):
|
||||
payload["top_p"] = int(model_info.params.get("top_p", None))
|
||||
|
||||
if model_info.params.get("max_tokens", None):
|
||||
payload["max_tokens"] = int(
|
||||
model_info.params.get("max_tokens", None)
|
||||
)
|
||||
payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
|
||||
|
||||
if model_info.params.get("frequency_penalty", None):
|
||||
payload["frequency_penalty"] = int(
|
||||
@ -411,8 +400,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None)
|
||||
+ message["content"]
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
@ -423,11 +411,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
model = app.state.MODELS[payload.get("model")]
|
||||
|
||||
idx = model["urlIdx"]
|
||||
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
@ -443,15 +431,12 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
log.error("Error loading request body into a dictionary:", e)
|
||||
|
||||
print(payload)
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
target_url = f"{url}/{path}"
|
||||
print(payload)
|
||||
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
@ -464,9 +449,72 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
try:
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
data=payload if payload else body,
|
||||
method="POST",
|
||||
url=f"{url}/chat/completions",
|
||||
data=payload,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
|
||||
# Check if response is SSE
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
r.content,
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
background=BackgroundTask(
|
||||
cleanup_response, response=r, session=session
|
||||
),
|
||||
)
|
||||
else:
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
print(res)
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except:
|
||||
error_detail = f"External: {e}"
|
||||
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
|
||||
finally:
|
||||
if not streaming and session:
|
||||
if r:
|
||||
r.close()
|
||||
await session.close()
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
idx = 0
|
||||
|
||||
body = await request.body()
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
target_url = f"{url}/{path}"
|
||||
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
|
||||
try:
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
data=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user