From 05746a9960269ccc4faf7dc29cc189f43cb2c539 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Thu, 24 Oct 2024 13:48:56 +0000
Subject: [PATCH] feat: ollama non streaming case working

---
 backend/open_webui/apps/ollama/main.py |  2 ++
 backend/open_webui/main.py             | 23 ++++++++++++++++++-----
 backend/open_webui/utils/payload.py    |  2 ++
 3 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py
index f835e3175..f3ca2876c 100644
--- a/backend/open_webui/apps/ollama/main.py
+++ b/backend/open_webui/apps/ollama/main.py
@@ -730,6 +730,7 @@ async def generate_completion(
 class ChatMessage(BaseModel):
     role: str
     content: str
+    tool_calls: Optional[list[dict]] = None
     images: Optional[list[str]] = None
 
 
@@ -741,6 +742,7 @@ class GenerateChatCompletionForm(BaseModel):
     template: Optional[str] = None
     stream: Optional[bool] = None
     keep_alive: Optional[Union[int, str]] = None
+    tools: Optional[list[dict]] = None
 
 
 def get_ollama_url(url_idx: Optional[int], model: str):
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8cf5bfa6f..7e81f82be 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -567,16 +567,28 @@ async def handle_streaming_response(request: Request, response: Response,
 
 async def handle_nonstreaming_response(request: Request, response: Response, tools: dict, user: UserModel, data_items: list) -> JSONResponse:
     # It only should be one response since we are in the non streaming scenario
+    content = ''
     async for data in response.body_iterator:
-        content = data
+        content += data.decode()
 
     citations = []
     response_dict = json.loads(content)
     body = json.loads(request._body)
 
+    if app.state.MODELS[body["model"]]["owned_by"] == "ollama":
+        is_ollama = True
+    is_openai = not is_ollama
+
+    while (is_ollama and "tool_calls" in response_dict.get("message", {})) or \
+          (is_openai and response_dict["choices"][0]["finish_reason"] == "tool_calls"):
+        if is_ollama:
+            message = response_dict.get("message", {})
+            tool_calls = message.get("tool_calls", [])
+            if message:
+                body["messages"].append(message)
+        else:
+            body["messages"].append(response_dict["choices"][0]["message"])
+            tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
-    while response_dict["choices"][0]["finish_reason"] == "tool_calls":
-        body["messages"].append(response_dict["choices"][0]["message"])
-        tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
         for tool_call in tool_calls:
             tool_function_name = tool_call["function"]["name"]
             if not tool_call["function"]["arguments"]:
@@ -601,7 +613,7 @@ async def handle_nonstreaming_response(request: Request, response: Response,
             # Append the tool output to the messages
             body["messages"].append({
                 "role": "tool",
-                "tool_call_id" : tool_call["id"],
+                "tool_call_id" : tool_call.get("id",""),
                 "name": tool_function_name,
                 "content": tool_output
             })
@@ -1315,6 +1327,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
+        log.debug(f"{form_data=}")
         form_data = GenerateChatCompletionForm(**form_data)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
         if form_data.stream:
diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py
index 72aec6a6c..57d7a2919 100644
--- a/backend/open_webui/utils/payload.py
+++ b/backend/open_webui/utils/payload.py
@@ -104,6 +104,8 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
     ollama_payload["model"] = openai_payload.get("model")
     ollama_payload["messages"] = openai_payload.get("messages")
     ollama_payload["stream"] = openai_payload.get("stream", False)
+    if "tools" in openai_payload:
+        ollama_payload["tools"] = openai_payload["tools"]
 
     # If there are advanced parameters in the payload, format them in Ollama's options field
     ollama_options = {}
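Note on the response shapes this patch branches on (not part of the commit above): in the non-streaming case, Ollama's /api/chat reports requested tool calls under message.tool_calls, while OpenAI-compatible responses signal them via choices[0].finish_reason == "tool_calls". The snippet below is a minimal, self-contained Python sketch of that dispatch under those assumed shapes; the sample response, get_weather, and run_tool are hypothetical names used only for illustration and are not Open WebUI code.

    import json

    def extract_tool_calls(response_dict: dict, is_ollama: bool) -> list[dict]:
        """Pull the tool calls out of a non-streaming chat completion response."""
        if is_ollama:
            # Assumed Ollama shape: {"message": {"role": "assistant", "tool_calls": [...]}}
            return response_dict.get("message", {}).get("tool_calls", [])
        # Assumed OpenAI shape: {"choices": [{"finish_reason": "tool_calls", "message": {...}}]}
        choice = response_dict["choices"][0]
        if choice.get("finish_reason") == "tool_calls":
            return choice["message"].get("tool_calls", [])
        return []

    def run_tool(name: str, arguments: dict) -> str:
        # Hypothetical stand-in for the real tool dispatch in the handler.
        return json.dumps({"tool": name, "echo": arguments})

    # Example: an Ollama-style non-streaming response that requests one tool call.
    ollama_response = {
        "message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
            ],
        }
    }

    messages = []
    for call in extract_tool_calls(ollama_response, is_ollama=True):
        fn = call["function"]
        messages.append({
            "role": "tool",
            "tool_call_id": call.get("id", ""),  # Ollama tool calls may carry no id
            "name": fn["name"],
            "content": run_tool(fn["name"], fn["arguments"]),
        })

    print(messages)

After the tool messages are appended, the handler re-issues the chat request with the extended message list, which is why the patch loops until the backend stops asking for tool calls.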