mirror of
https://github.com/open-webui/open-webui
synced 2025-01-18 00:30:51 +00:00
feat: preset backend logic
This commit is contained in:
parent
7d2ab168f1
commit
88d053833d
@ -875,15 +875,88 @@ async def generate_chat_completion(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
model_id = get_model_id_from_custom_model_id(form_data.model)
|
||||
model = model_id
|
||||
|
||||
log.debug(
|
||||
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
|
||||
form_data.model_dump_json(exclude_none=True).encode()
|
||||
)
|
||||
)
|
||||
|
||||
payload = {
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
model_id = form_data.model
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
print(model_info)
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
if model_info.params:
|
||||
payload["options"] = {}
|
||||
|
||||
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
|
||||
payload["options"]["mirostat_eta"] = model_info.params.get(
|
||||
"mirostat_eta", None
|
||||
)
|
||||
payload["options"]["mirostat_tau"] = model_info.params.get(
|
||||
"mirostat_tau", None
|
||||
)
|
||||
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
|
||||
|
||||
payload["options"]["repeat_last_n"] = model_info.params.get(
|
||||
"repeat_last_n", None
|
||||
)
|
||||
payload["options"]["repeat_penalty"] = model_info.params.get(
|
||||
"frequency_penalty", None
|
||||
)
|
||||
|
||||
payload["options"]["temperature"] = model_info.params.get(
|
||||
"temperature", None
|
||||
)
|
||||
payload["options"]["seed"] = model_info.params.get("seed", None)
|
||||
|
||||
# TODO: add "stop" back in
|
||||
# payload["stop"] = model_info.params.get("stop", None)
|
||||
|
||||
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
|
||||
|
||||
payload["options"]["num_predict"] = model_info.params.get(
|
||||
"max_tokens", None
|
||||
)
|
||||
payload["options"]["top_k"] = model_info.params.get("top_k", None)
|
||||
|
||||
payload["options"]["top_p"] = model_info.params.get("top_p", None)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
|
||||
if url_idx == None:
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
if ":" not in payload["model"]:
|
||||
payload["model"] = f"{payload['model']}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if payload["model"] in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -893,23 +966,12 @@ async def generate_chat_completion(
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
print(payload)
|
||||
|
||||
r = None
|
||||
|
||||
# payload = {
|
||||
# **form_data.model_dump_json(exclude_none=True).encode(),
|
||||
# "model": model,
|
||||
# "messages": form_data.messages,
|
||||
|
||||
# }
|
||||
|
||||
log.debug(
|
||||
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
|
||||
form_data.model_dump_json(exclude_none=True).encode()
|
||||
)
|
||||
)
|
||||
|
||||
def get_request():
|
||||
nonlocal form_data
|
||||
nonlocal payload
|
||||
nonlocal r
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
@ -918,7 +980,7 @@ async def generate_chat_completion(
|
||||
|
||||
def stream_content():
|
||||
try:
|
||||
if form_data.stream:
|
||||
if payload.get("stream", None):
|
||||
yield json.dumps({"id": request_id, "done": False}) + "\n"
|
||||
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
@ -936,7 +998,7 @@ async def generate_chat_completion(
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/chat",
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
data=json.dumps(payload),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@ -992,14 +1054,56 @@ async def generate_openai_chat_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
|
||||
payload = {
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
model_id = form_data.model
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
print(model_info)
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
if model_info.params:
|
||||
payload["temperature"] = model_info.params.get("temperature", None)
|
||||
payload["top_p"] = model_info.params.get("top_p", None)
|
||||
payload["max_tokens"] = model_info.params.get("max_tokens", None)
|
||||
payload["frequency_penalty"] = model_info.params.get(
|
||||
"frequency_penalty", None
|
||||
)
|
||||
payload["seed"] = model_info.params.get("seed", None)
|
||||
# TODO: add "stop" back in
|
||||
# payload["stop"] = model_info.params.get("stop", None)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
if ":" not in payload["model"]:
|
||||
payload["model"] = f"{payload['model']}:latest"
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if payload["model"] in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion(
|
||||
r = None
|
||||
|
||||
def get_request():
|
||||
nonlocal form_data
|
||||
nonlocal payload
|
||||
nonlocal r
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion(
|
||||
|
||||
def stream_content():
|
||||
try:
|
||||
if form_data.stream:
|
||||
if payload.get("stream"):
|
||||
yield json.dumps(
|
||||
{"request_id": request_id, "done": False}
|
||||
) + "\n"
|
||||
@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion(
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/v1/chat/completions",
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
data=json.dumps(payload),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
@ -315,41 +315,87 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
# TODO: Remove below after gpt-4-vision fix from Open AI
|
||||
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
|
||||
|
||||
payload = None
|
||||
|
||||
try:
|
||||
body = body.decode("utf-8")
|
||||
body = json.loads(body)
|
||||
if "chat/completions" in path:
|
||||
body = body.decode("utf-8")
|
||||
body = json.loads(body)
|
||||
|
||||
print(app.state.MODELS)
|
||||
payload = {**body}
|
||||
|
||||
model = app.state.MODELS[body.get("model")]
|
||||
model_id = body.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
idx = model["urlIdx"]
|
||||
if model_info:
|
||||
print(model_info)
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
body["user"] = {"name": user.name, "id": user.id}
|
||||
body["title"] = (
|
||||
True if body["stream"] == False and body["max_tokens"] == 50 else False
|
||||
)
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
|
||||
# This is a workaround until OpenAI fixes the issue with this model
|
||||
if body.get("model") == "gpt-4-vision-preview":
|
||||
if "max_tokens" not in body:
|
||||
body["max_tokens"] = 4000
|
||||
log.debug("Modified body_dict:", body)
|
||||
if model_info.params:
|
||||
payload["temperature"] = model_info.params.get("temperature", None)
|
||||
payload["top_p"] = model_info.params.get("top_p", None)
|
||||
payload["max_tokens"] = model_info.params.get("max_tokens", None)
|
||||
payload["frequency_penalty"] = model_info.params.get(
|
||||
"frequency_penalty", None
|
||||
)
|
||||
payload["seed"] = model_info.params.get("seed", None)
|
||||
# TODO: add "stop" back in
|
||||
# payload["stop"] = model_info.params.get("stop", None)
|
||||
|
||||
# Fix for ChatGPT calls failing because the num_ctx key is in body
|
||||
if "num_ctx" in body:
|
||||
# If 'num_ctx' is in the dictionary, delete it
|
||||
# Leaving it there generates an error with the
|
||||
# OpenAI API (Feb 2024)
|
||||
del body["num_ctx"]
|
||||
if model_info.params.get("system", None):
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None)
|
||||
+ message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
print(app.state.MODELS)
|
||||
model = app.state.MODELS[payload.get("model")]
|
||||
|
||||
idx = model["urlIdx"]
|
||||
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
payload["user"] = {"name": user.name, "id": user.id}
|
||||
payload["title"] = (
|
||||
True
|
||||
if payload["stream"] == False and payload["max_tokens"] == 50
|
||||
else False
|
||||
)
|
||||
|
||||
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
|
||||
# This is a workaround until OpenAI fixes the issue with this model
|
||||
if payload.get("model") == "gpt-4-vision-preview":
|
||||
if "max_tokens" not in payload:
|
||||
payload["max_tokens"] = 4000
|
||||
log.debug("Modified payload:", payload)
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
body = json.dumps(body)
|
||||
except json.JSONDecodeError as e:
|
||||
log.error("Error loading request body into a dictionary:", e)
|
||||
|
||||
print(payload)
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
@ -368,7 +414,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
r = requests.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
data=body,
|
||||
data=payload if payload else body,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user