Merge pull request #489 from Eisaichen/main

Update google_manifold_pipeline.py
This commit is contained in:
Tim Jaeryang Baek 2025-04-12 14:10:09 -07:00 committed by GitHub
commit 67699acde9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -5,7 +5,7 @@ date: 2024-06-06
version: 1.3
license: MIT
description: A pipeline for generating text using Google's GenAI models in Open-WebUI.
requirements: google-generativeai
requirements: google-genai
environment_variables: GOOGLE_API_KEY
"""
@@ -14,8 +14,11 @@ import os
from pydantic import BaseModel, Field
import google.generativeai as genai
from google.generativeai.types import GenerationConfig
from google import genai
from google.genai import types
from PIL import Image
from io import BytesIO
import base64
class Pipeline:
@@ -24,8 +27,9 @@ class Pipeline:
class Valves(BaseModel):
"""Options to change from the WebUI"""
GOOGLE_API_KEY: str = ""
USE_PERMISSIVE_SAFETY: bool = Field(default=False)
GOOGLE_API_KEY: str = Field(default="",description="Google Generative AI API key")
USE_PERMISSIVE_SAFETY: bool = Field(default=False,description="Use permissive safety settings")
GENERATE_IMAGE: bool = Field(default=False,description="Allow image generation")
def __init__(self):
self.type = "manifold"
@@ -34,19 +38,20 @@ class Pipeline:
self.valves = self.Valves(**{
"GOOGLE_API_KEY": os.getenv("GOOGLE_API_KEY", ""),
"USE_PERMISSIVE_SAFETY": False
"USE_PERMISSIVE_SAFETY": False,
"GENERATE_IMAGE": False
})
self.pipelines = []
genai.configure(api_key=self.valves.GOOGLE_API_KEY)
self.update_pipelines()
if self.valves.GOOGLE_API_KEY:
self.update_pipelines()
async def on_startup(self) -> None:
"""This function is called when the server is started."""
print(f"on_startup:{__name__}")
genai.configure(api_key=self.valves.GOOGLE_API_KEY)
self.update_pipelines()
if self.valves.GOOGLE_API_KEY:
self.update_pipelines()
async def on_shutdown(self) -> None:
"""This function is called when the server is stopped."""
@@ -57,22 +62,23 @@ class Pipeline:
"""This function is called when the valves are updated."""
print(f"on_valves_updated:{__name__}")
genai.configure(api_key=self.valves.GOOGLE_API_KEY)
self.update_pipelines()
if self.valves.GOOGLE_API_KEY:
self.update_pipelines()
def update_pipelines(self) -> None:
"""Update the available models from Google GenAI"""
if self.valves.GOOGLE_API_KEY:
client = genai.Client(api_key=self.valves.GOOGLE_API_KEY)
try:
models = genai.list_models()
models = client.models.list()
self.pipelines = [
{
                        "id": model.name[7:],  # the "models/" part messes up the URL
"name": model.display_name,
}
for model in models
if "generateContent" in model.supported_generation_methods
if "generateContent" in model.supported_actions
if model.name[:7] == "models/"
]
except Exception:
@@ -92,13 +98,13 @@ class Pipeline:
return "Error: GOOGLE_API_KEY is not set"
try:
genai.configure(api_key=self.valves.GOOGLE_API_KEY)
client = genai.Client(api_key=self.valves.GOOGLE_API_KEY)
if model_id.startswith("google_genai."):
model_id = model_id[12:]
model_id = model_id.lstrip(".")
if not model_id.startswith("gemini-"):
if not (model_id.startswith("gemini-") or model_id.startswith("learnlm-") or model_id.startswith("gemma-")):
return f"Error: Invalid model name format: {model_id}"
print(f"Pipe function called for model: {model_id}")
@ -127,50 +133,78 @@ class Pipeline:
"role": "user" if message["role"] == "user" else "model",
"parts": [{"text": message["content"]}]
})
if "gemini-1.5" in model_id:
model = genai.GenerativeModel(model_name=model_id, system_instruction=system_message)
else:
if system_message:
contents.insert(0, {"role": "user", "parts": [{"text": f"System: {system_message}"}]})
model = genai.GenerativeModel(model_name=model_id)
print(f"{contents}")
generation_config = GenerationConfig(
temperature=body.get("temperature", 0.7),
top_p=body.get("top_p", 0.9),
top_k=body.get("top_k", 40),
max_output_tokens=body.get("max_tokens", 8192),
stop_sequences=body.get("stop", []),
)
generation_config = {
"temperature": body.get("temperature", 0.7),
"top_p": body.get("top_p", 0.9),
"top_k": body.get("top_k", 40),
"max_output_tokens": body.get("max_tokens", 8192),
"stop_sequences": body.get("stop", []),
"response_modalities": ['Text']
}
if self.valves.GENERATE_IMAGE and model_id.startswith("gemini-2.0-flash-exp"):
generation_config["response_modalities"].append("Image")
if self.valves.USE_PERMISSIVE_SAFETY:
safety_settings = {
genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai.types.HarmBlockThreshold.BLOCK_NONE,
genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai.types.HarmBlockThreshold.BLOCK_NONE,
genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai.types.HarmBlockThreshold.BLOCK_NONE,
}
safety_settings = [
types.SafetySetting(category='HARM_CATEGORY_HARASSMENT', threshold='OFF'),
types.SafetySetting(category='HARM_CATEGORY_HATE_SPEECH', threshold='OFF'),
types.SafetySetting(category='HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold='OFF'),
types.SafetySetting(category='HARM_CATEGORY_DANGEROUS_CONTENT', threshold='OFF'),
types.SafetySetting(category='HARM_CATEGORY_CIVIC_INTEGRITY', threshold='OFF')
]
generation_config = types.GenerateContentConfig(**generation_config, safety_settings=safety_settings)
else:
safety_settings = body.get("safety_settings")
generation_config = types.GenerateContentConfig(**generation_config)
response = model.generate_content(
contents,
generation_config=generation_config,
safety_settings=safety_settings,
stream=body.get("stream", False),
)
if system_message:
contents.insert(0, {"role": "user", "parts": [{"text": f"System: {system_message}"}]})
if body.get("stream", False):
response = client.models.generate_content_stream(
model = model_id,
contents = contents,
config = generation_config,
)
return self.stream_response(response)
else:
return response.text
response = client.models.generate_content(
model = model_id,
contents = contents,
config = generation_config,
)
for part in response.candidates[0].content.parts:
if part.text is not None:
return part.text
elif part.inline_data is not None:
try:
image_data = base64.b64decode(part.inline_data.data)
image = Image.open(BytesIO((image_data)))
content_type = part.inline_data.mime_type
return "Image not supported yet."
except Exception as e:
print(f"Error processing image: {e}")
return "Error processing image."
except Exception as e:
print(f"Error generating content: {e}")
return f"An error occurred: {str(e)}"
return f"{e}"
def stream_response(self, response):
for chunk in response:
if chunk.text:
yield chunk.text
for candidate in chunk.candidates:
if candidate.content.parts is not None:
for part in candidate.content.parts:
if part.text is not None:
yield chunk.text
elif part.inline_data is not None:
try:
image_data = base64.b64decode(part.inline_data.data)
image = Image.open(BytesIO(image_data))
content_type = part.inline_data.mime_type
yield "Image not supported yet."
except Exception as e:
print(f"Error processing image: {e}")
yield "Error processing image."