mirror of
https://github.com/open-webui/open-webui
synced 2024-12-28 06:42:47 +00:00
enh: ollama /v1/completion
endpoint support
This commit is contained in:
parent
976676a482
commit
1439f6862d
@ -1032,6 +1032,82 @@ class OpenAIChatCompletionForm(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class OpenAICompletionForm(BaseModel):
    """Request body accepted by the ``/v1/completions`` endpoints.

    Only ``model`` and ``prompt`` are declared; any additional keys sent by
    the client are accepted because the model is configured with
    ``extra="allow"``.
    """

    # Accept undeclared fields instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    # Identifier of the model the completion is requested from.
    model: str
    # Prompt text to complete.
    prompt: str
|
||||
|
||||
|
||||
@app.post("/v1/completions")
@app.post("/v1/completions/{url_idx}")
async def generate_openai_completion(
    form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    """Validate a completion request and forward it to ``{url}/v1/completions``.

    Args:
        form_data: Raw request body; it is validated by constructing an
            ``OpenAICompletionForm`` from it.
        url_idx: Optional index handed to ``get_ollama_url`` together with
            the resolved model name to pick the upstream URL.
        user: The caller, injected through ``get_verified_user``.

    Returns:
        The result of ``post_streaming_url`` for the upstream request.

    Raises:
        HTTPException: 400 if ``form_data`` cannot be turned into an
            ``OpenAICompletionForm``; 403 ("Model not found") if the model
            is registered and a ``user``-role caller neither owns it nor has
            read access (unless ``BYPASS_MODEL_ACCESS_CONTROL`` is set), or
            if the model is not registered and the caller is not an admin.
    """
    try:
        form_data = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        # Chain the cause so the original failure stays in the traceback.
        raise HTTPException(
            status_code=400,
            detail=str(e),
        ) from e

    payload = form_data.model_dump(exclude_none=True, exclude=["metadata"])
    # Defensive: make sure "metadata" is never forwarded upstream, whether or
    # not the exclude above already removed it.
    payload.pop("metadata", None)

    # Registered model ids carry a tag; default to ":latest" when none given.
    model_id = form_data.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        # A registered model may point at a base model; forward that instead.
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Check if user has access to the model
        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
            is_owner = user.id == model_info.user_id
            if not (
                is_owner
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif user.role != "admin":
        # Models without a registered entry are only usable by admins.
        raise HTTPException(
            status_code=403,
            detail="Model not found",
        )

    # The forwarded name may have been swapped for a base model id above, so
    # the tag default has to be applied again.
    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)

    if prefix_id:
        # Strip only a *leading* "<prefix_id>." from the model name. A plain
        # str.replace would also delete later occurrences of that substring
        # (e.g. prefix "a" turning "a.llama.cpp" into "llamcpp").
        model_prefix = f"{prefix_id}."
        if payload["model"].startswith(model_prefix):
            payload["model"] = payload["model"][len(model_prefix):]

    return await post_streaming_url(
        f"{url}/v1/completions",
        json.dumps(payload),
        stream=payload.get("stream", False),
    )
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@app.post("/v1/chat/completions/{url_idx}")
|
||||
async def generate_openai_chat_completion(
|
||||
|
Loading…
Reference in New Issue
Block a user