diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 82a37a752..e0b376097 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -1032,6 +1032,82 @@ class OpenAIChatCompletionForm(BaseModel): model_config = ConfigDict(extra="allow") +class OpenAICompletionForm(BaseModel): + model: str + prompt: str + + model_config = ConfigDict(extra="allow") + + +@app.post("/v1/completions") +@app.post("/v1/completions/{url_idx}") +async def generate_openai_completion( + form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user) +): + try: + form_data = OpenAICompletionForm(**form_data) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=400, + detail=str(e), + ) + + payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} + if "metadata" in payload: + del payload["metadata"] + + model_id = form_data.model + if ":" not in model_id: + model_id = f"{model_id}:latest" + + model_info = Models.get_model_by_id(model_id) + if model_info: + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + params = model_info.params.model_dump() + + if params: + payload = apply_model_params_to_body_openai(params, payload) + + # Check if user has access to the model + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: + if not ( + user.id == model_info.user_id + or has_access( + user.id, type="read", access_control=model_info.access_control + ) + ): + raise HTTPException( + status_code=403, + detail="Model not found", + ) + else: + if user.role != "admin": + raise HTTPException( + status_code=403, + detail="Model not found", + ) + + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" + + url = await get_ollama_url(url_idx, payload["model"]) + log.info(f"url: {url}") + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + prefix_id = api_config.get("prefix_id", None) + + if prefix_id: + payload["model"] = payload["model"].replace(f"{prefix_id}.", "") + + return await post_streaming_url( + f"{url}/v1/completions", + json.dumps(payload), + stream=payload.get("stream", False), + ) + + @app.post("/v1/chat/completions") @app.post("/v1/chat/completions/{url_idx}") async def generate_openai_chat_completion(