Fix Gemini

This commit is contained in:
Justin Hayes 2024-06-28 10:50:39 -04:00 committed by GitHub
parent 3a48d80123
commit 48f77c5798
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -72,56 +72,62 @@ class Pipeline:
def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Iterator]:
"""The pipe function (connects open-webui to google-genai)
Args:
user_message (str): The last message input by the user
model_id (str): The model to use
messages (List[dict]): The chat history
body (dict): The raw request body in OpenAI's "chat/completions" style
Returns:
str: The complete response
Yields:
Iterator[str]: Yields a new message part every time it is received
"""
print(f"pipe:{__name__}")
system_prompt = None
google_messages = []
for message in messages:
google_role = ""
if message["role"] == "user":
google_role = "user"
elif message["role"] == "assistant":
google_role = "model"
elif message["role"] == "system":
if message["role"] == "system":
system_prompt = message["content"]
continue # System promt is not inyected as a message
google_messages.append(
genai.protos.Content(
role=google_role,
parts=[
genai.protos.Part(
text=message["content"],
),
],
continue
google_role = "user" if message["role"] == "user" else "model"
try:
content = message.get("content", "")
if isinstance(content, list):
# Handle potential multi-modal content
parts = []
for item in content:
if item["type"] == "text":
parts.append({"text": item["text"]})
# Add handling for other content types if necessary
else:
parts = [{"text": content}]
google_messages.append({
"role": google_role,
"parts": parts
})
except Exception as e:
print(f"Error processing message: {e}")
print(f"Problematic message: {message}")
# You might want to skip this message or handle the error differently
try:
model = genai.GenerativeModel(
f"models/{model_id}",
generation_config=genai.GenerationConfig(
temperature=body.get("temperature", 0.7),
top_p=body.get("top_p", 1.0),
top_k=body.get("top_k", 1),
max_output_tokens=body.get("max_tokens", 1024),
)
)
response = genai.GenerativeModel(
f"models/{model_id}", # we have to add the "models/" part again
system_instruction=system_prompt,
).generate_content(
google_messages,
stream=body["stream"],
)
response = model.generate_content(
google_messages,
stream=body["stream"],
)
if body["stream"]:
for chunk in response:
yield chunk.text
return ""
if body["stream"]:
for chunk in response:
yield chunk.text
return ""
return response.text
return response.text
except Exception as e:
print(f"Error generating content: {e}")
return f"An error occurred: {str(e)}"