fix: openai proxy

This commit is contained in:
Timothy J. Baek 2024-05-29 11:28:42 -07:00
parent 37c87e3a14
commit e427ef767b
2 changed files with 31 additions and 17 deletions

View File

@ -338,21 +338,36 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
model_info.params = model_info.params.model_dump() model_info.params = model_info.params.model_dump()
if model_info.params: if model_info.params:
payload["temperature"] = model_info.params.get("temperature", None) if model_info.params.get("temperature", None):
payload["top_p"] = model_info.params.get("top_p", None) payload["temperature"] = int(
payload["max_tokens"] = model_info.params.get("max_tokens", None) model_info.params.get("temperature")
payload["frequency_penalty"] = model_info.params.get( )
"frequency_penalty", None
) if model_info.params.get("top_p", None):
payload["seed"] = model_info.params.get("seed", None) payload["top_p"] = int(model_info.params.get("top_p", None))
payload["stop"] = (
[ if model_info.params.get("max_tokens", None):
bytes(stop, "utf-8").decode("unicode_escape") payload["max_tokens"] = int(
for stop in model_info.params["stop"] model_info.params.get("max_tokens", None)
] )
if model_info.params.get("stop", None)
else None if model_info.params.get("frequency_penalty", None):
) payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None): if model_info.params.get("system", None):
# Check if the payload already has a system message # Check if the payload already has a system message
@ -376,7 +391,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
else: else:
pass pass
print(app.state.MODELS)
model = app.state.MODELS[payload.get("model")] model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"] idx = model["urlIdx"]
@ -442,6 +456,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
print(res)
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except: except:

View File

@ -82,7 +82,6 @@ async def update_model_by_id(
else: else:
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id) model = Models.insert_new_model(form_data, user.id)
print(model)
if model: if model:
return model return model
else: else: