feat: ollama non-streaming case working

Samuel 2024-10-24 13:48:56 +00:00
parent bedcb6cce8
commit 05746a9960
3 changed files with 22 additions and 5 deletions


@@ -730,6 +730,7 @@ async def generate_completion(
 class ChatMessage(BaseModel):
     role: str
     content: str
+    tool_calls: Optional[list[dict]] = None
     images: Optional[list[str]] = None
@@ -741,6 +742,7 @@ class GenerateChatCompletionForm(BaseModel):
     template: Optional[str] = None
     stream: Optional[bool] = None
     keep_alive: Optional[Union[int, str]] = None
+    tools: Optional[list[dict]] = None
 
 
 def get_ollama_url(url_idx: Optional[int], model: str):
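Note: with default Pydantic settings, undeclared keys are silently dropped during validation, so declaring tool_calls and tools is what lets tool data reach the Ollama endpoint at all. A minimal sketch of an assistant message that now round-trips (the weather tool is a made-up example; the field shape follows Ollama's /api/chat tool calls):

    msg = ChatMessage(
        role="assistant",
        content="",
        # Ollama-style tool call: function name plus already-parsed arguments
        tool_calls=[{"function": {"name": "get_current_weather", "arguments": {"city": "Berlin"}}}],
    )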


@@ -567,16 +567,28 @@ async def handle_streaming_response(request: Request, response: Response,
 async def handle_nonstreaming_response(request: Request, response: Response,
                                        tools: dict, user: UserModel, data_items: list) -> JSONResponse:
     # There should be only one response, since we are in the non-streaming scenario
+    content = ''
     async for data in response.body_iterator:
-        content = data
+        content += data.decode()
 
     citations = []
     response_dict = json.loads(content)
     body = json.loads(request._body)
-    while response_dict["choices"][0]["finish_reason"] == "tool_calls":
-        body["messages"].append(response_dict["choices"][0]["message"])
-        tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
+    is_ollama = False
+    if app.state.MODELS[body["model"]]["owned_by"] == "ollama":
+        is_ollama = True
+    is_openai = not is_ollama
+    while (is_ollama and "tool_calls" in response_dict.get("message", {})) or \
+          (is_openai and response_dict["choices"][0]["finish_reason"] == "tool_calls"):
+        if is_ollama:
+            message = response_dict.get("message", {})
+            tool_calls = message.get("tool_calls", [])
+            if message:
+                body["messages"].append(message)
+        else:
+            body["messages"].append(response_dict["choices"][0]["message"])
+            tool_calls = response_dict["choices"][0]["message"].get("tool_calls", [])
         for tool_call in tool_calls:
             tool_function_name = tool_call["function"]["name"]
             if not tool_call["function"]["arguments"]:
@@ -601,7 +613,7 @@ async def handle_nonstreaming_response(request: Request, response: Response,
             # Append the tool output to the messages
             body["messages"].append({
                 "role": "tool",
-                "tool_call_id": tool_call["id"],
+                "tool_call_id": tool_call.get("id", ""),
                 "name": tool_function_name,
                 "content": tool_output
             })
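The loop above is the standard non-streaming tool round trip: append the assistant's tool-call message, execute each requested tool, append the outputs as "role": "tool" messages, and re-query the model (presumably later in the loop body, not shown in this hunk) until it stops requesting tools. The is_ollama/is_openai split exists because the two backends signal tool use with different response shapes; roughly (values illustrative):

    # Ollama /api/chat, non-streaming:
    {"message": {"role": "assistant", "content": "", "tool_calls": [...]}, "done": true}

    # OpenAI chat completions:
    {"choices": [{"message": {"role": "assistant", "tool_calls": [...]}, "finish_reason": "tool_calls"}]}

Ollama's tool-call objects also carry no "id" field, which is why tool_call["id"] became tool_call.get("id", "") in the hunk above.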
@@ -1315,6 +1327,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
+        log.debug(f"{form_data=}")
         form_data = GenerateChatCompletionForm(**form_data)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
         if form_data.stream:
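The new log line uses Python 3.8+ self-documenting f-strings: f"{expr=}" expands to the expression text plus its repr, so the converted payload is logged together with its name. For example:

    stream = False
    print(f"{stream=}")  # prints: stream=False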


@@ -104,6 +104,8 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
     ollama_payload["model"] = openai_payload.get("model")
     ollama_payload["messages"] = openai_payload.get("messages")
     ollama_payload["stream"] = openai_payload.get("stream", False)
+    if "tools" in openai_payload:
+        ollama_payload["tools"] = openai_payload["tools"]
 
     # If there are advanced parameters in the payload, format them in Ollama's options field
     ollama_options = {}
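Since Ollama's /api/chat accepts the same OpenAI-style tool schema, the tools list can pass through unchanged. A minimal usage sketch (model name and tool definition are illustrative):

    openai_payload = {
        "model": "llama3.1",
        "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
        "stream": False,
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
    }

    ollama_payload = convert_payload_openai_to_ollama(openai_payload)
    assert ollama_payload["tools"] == openai_payload["tools"]  # passed through verbatim
    assert ollama_payload["stream"] is False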