From 130e9cec3586295f05b2f369281ba942db266628 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Fri, 28 Jun 2024 11:47:04 -0400 Subject: [PATCH] Refac --- .../providers/google_manifold_pipeline.py | 112 ++++++++++-------- 1 file changed, 63 insertions(+), 49 deletions(-) diff --git a/examples/pipelines/providers/google_manifold_pipeline.py b/examples/pipelines/providers/google_manifold_pipeline.py index 38fbfd2..c74c668 100644 --- a/examples/pipelines/providers/google_manifold_pipeline.py +++ b/examples/pipelines/providers/google_manifold_pipeline.py @@ -15,6 +15,7 @@ import os from pydantic import BaseModel import google.generativeai as genai +from google.generativeai.types import GenerationConfig class Pipeline: @@ -81,65 +82,78 @@ class Pipeline: def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Iterator]: - print(f"Pipe function called for model: {model_id}") - print(f"Stream mode: {body['stream']}") - - system_prompt = None - google_messages = [] - - for message in messages: - if message["role"] == "system": - system_prompt = message["content"] - continue - - google_role = "user" if message["role"] == "user" else "model" - - try: - content = message.get("content", "") - if isinstance(content, list): - parts = [] - for item in content: - if item["type"] == "text": - parts.append({"text": item["text"]}) - else: - parts = [{"text": content}] - - google_messages.append({ - "role": google_role, - "parts": parts - }) - except Exception as e: - print(f"Error processing message: {e}") - print(f"Problematic message: {message}") + if not self.valves.GOOGLE_API_KEY: + return "Error: GOOGLE_API_KEY is not set" try: - model = genai.GenerativeModel( - f"models/{model_id}", - generation_config=genai.GenerationConfig( - temperature=body.get("temperature", 0.7), - top_p=body.get("top_p", 1.0), - top_k=body.get("top_k", 1), - max_output_tokens=body.get("max_tokens", 1024), - ) + genai.configure(api_key=self.valves.GOOGLE_API_KEY) + + if model_id.startswith("google_genai."): + model_id = model_id[12:] + model_id = model_id.lstrip(".") + + if not model_id.startswith("gemini-"): + return f"Error: Invalid model name format: {model_id}" + + print(f"Pipe function called for model: {model_id}") + print(f"Stream mode: {body.get('stream', False)}") + + system_message = next((msg["content"] for msg in messages if msg["role"] == "system"), None) + + contents = [] + for message in messages: + if message["role"] != "system": + if isinstance(message.get("content"), list): + parts = [] + for content in message["content"]: + if content["type"] == "text": + parts.append({"text": content["text"]}) + elif content["type"] == "image_url": + image_url = content["image_url"]["url"] + if image_url.startswith("data:image"): + image_data = image_url.split(",")[1] + parts.append({"inline_data": {"mime_type": "image/jpeg", "data": image_data}}) + else: + parts.append({"image_url": image_url}) + contents.append({"role": message["role"], "parts": parts}) + else: + contents.append({ + "role": "user" if message["role"] == "user" else "model", + "parts": [{"text": message["content"]}] + }) + + if system_message: + contents.insert(0, {"role": "user", "parts": [{"text": f"System: {system_message}"}]}) + + model = genai.GenerativeModel(model_name=model_id) + + generation_config = GenerationConfig( + temperature=body.get("temperature", 0.7), + top_p=body.get("top_p", 0.9), + top_k=body.get("top_k", 40), + max_output_tokens=body.get("max_tokens", 8192), + stop_sequences=body.get("stop", []), ) + safety_settings = body.get("safety_settings") + response = model.generate_content( - google_messages, - stream=body["stream"], + contents, + generation_config=generation_config, + safety_settings=safety_settings, + stream=body.get("stream", False), ) - if body["stream"]: - print("Streaming response") - return (chunk.text for chunk in response) + if body.get("stream", False): + return self.stream_response(response) else: - print("Non-streaming response") - result = response.text - print(f"Generated content: {result}") - return result + return response.text except Exception as e: print(f"Error generating content: {e}") return f"An error occurred: {str(e)}" - finally: - print("Pipe function completed") + def stream_response(self, response): + for chunk in response: + if chunk.text: + yield chunk.text