mirror of
https://github.com/open-webui/open-webui
synced 2024-12-29 15:25:29 +00:00
enh: ollama /v1/completion
endpoint support
This commit is contained in:
parent
976676a482
commit
1439f6862d
@ -1032,6 +1032,82 @@ class OpenAIChatCompletionForm(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompletionForm(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
@app.post("/v1/completions/{url_idx}")
|
||||||
|
async def generate_openai_completion(
|
||||||
|
form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
form_data = OpenAICompletionForm(**form_data)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
|
||||||
|
if "metadata" in payload:
|
||||||
|
del payload["metadata"]
|
||||||
|
|
||||||
|
model_id = form_data.model
|
||||||
|
if ":" not in model_id:
|
||||||
|
model_id = f"{model_id}:latest"
|
||||||
|
|
||||||
|
model_info = Models.get_model_by_id(model_id)
|
||||||
|
if model_info:
|
||||||
|
if model_info.base_model_id:
|
||||||
|
payload["model"] = model_info.base_model_id
|
||||||
|
params = model_info.params.model_dump()
|
||||||
|
|
||||||
|
if params:
|
||||||
|
payload = apply_model_params_to_body_openai(params, payload)
|
||||||
|
|
||||||
|
# Check if user has access to the model
|
||||||
|
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||||
|
if not (
|
||||||
|
user.id == model_info.user_id
|
||||||
|
or has_access(
|
||||||
|
user.id, type="read", access_control=model_info.access_control
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if user.role != "admin":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if ":" not in payload["model"]:
|
||||||
|
payload["model"] = f"{payload['model']}:latest"
|
||||||
|
|
||||||
|
url = await get_ollama_url(url_idx, payload["model"])
|
||||||
|
log.info(f"url: {url}")
|
||||||
|
|
||||||
|
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||||
|
prefix_id = api_config.get("prefix_id", None)
|
||||||
|
|
||||||
|
if prefix_id:
|
||||||
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||||
|
|
||||||
|
return await post_streaming_url(
|
||||||
|
f"{url}/v1/completions",
|
||||||
|
json.dumps(payload),
|
||||||
|
stream=payload.get("stream", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
@app.post("/v1/chat/completions/{url_idx}")
|
@app.post("/v1/chat/completions/{url_idx}")
|
||||||
async def generate_openai_chat_completion(
|
async def generate_openai_chat_completion(
|
||||||
|
Loading…
Reference in New Issue
Block a user