This commit is contained in:
Quantuary 2025-03-18 07:19:54 -07:00 committed by GitHub
commit 8296c84249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 50 deletions

View File

@ -17,15 +17,12 @@ import os
from typing import Iterator, List, Union from typing import Iterator, List, Union
import vertexai import vertexai
from google import genai
from google.genai import types
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from vertexai.generative_models import ( from vertexai.generative_models import (Content, GenerationConfig,
Content, GenerativeModel, HarmBlockThreshold,
GenerationConfig, HarmCategory, Part)
GenerativeModel,
HarmBlockThreshold,
HarmCategory,
Part,
)
class Pipeline: class Pipeline:
@ -46,14 +43,15 @@ class Pipeline:
**{ **{
"GOOGLE_PROJECT_ID": os.getenv("GOOGLE_PROJECT_ID", ""), "GOOGLE_PROJECT_ID": os.getenv("GOOGLE_PROJECT_ID", ""),
"GOOGLE_CLOUD_REGION": os.getenv("GOOGLE_CLOUD_REGION", ""), "GOOGLE_CLOUD_REGION": os.getenv("GOOGLE_CLOUD_REGION", ""),
"GOOGLE_APPLICATION_CREDENTIALS": os.getenv("GOOGLE_APPLICATION_CREDENTIALS", ""),
"USE_PERMISSIVE_SAFETY": False, "USE_PERMISSIVE_SAFETY": False,
} }
) )
self.pipelines = [ self.pipelines = [
{"id": "gemini-1.5-flash-001", "name": "Gemini 1.5 Flash"}, {"id": "gemini-2.0-flash", "name": "Gemini 2.0 Flash"},
{"id": "gemini-1.5-pro-001", "name": "Gemini 1.5 Pro"}, {"id": "gemini-2.0-flash-lite-preview-02-05", "name": "Gemini 2.0 Flash Lite"},
{"id": "gemini-flash-experimental", "name": "Gemini 1.5 Flash Experimental"}, {"id": "gemini-2.0-pro-exp-02-05", "name": "Gemini 2.0 Pro"},
{"id": "gemini-pro-experimental", "name": "Gemini 1.5 Pro Experimental"}, {"id": "gemini-2.0-flash-thinking-exp-01-21", "name": "Gemini 2.0 Flash Thinking"},
] ]
async def on_startup(self) -> None: async def on_startup(self) -> None:
@ -91,55 +89,85 @@ class Pipeline:
(msg["content"] for msg in messages if msg["role"] == "system"), None (msg["content"] for msg in messages if msg["role"] == "system"), None
) )
model = GenerativeModel( client = genai.Client(
model_name=model_id, vertexai=True,
system_instruction=system_message, project=self.valves.GOOGLE_PROJECT_ID,
location=self.valves.GOOGLE_CLOUD_REGION,
) )
if body.get("title", False): # If chat title generation is requested contents = [
contents = [Content(role="user", parts=[Part.from_text(user_message)])] types.Content(
else: role="user",
contents = self.build_conversation_history(messages) parts=[types.Part(text=user_message)]
)
]
generation_config = GenerationConfig( generate_content_config = types.GenerateContentConfig(
temperature=body.get("temperature", 0.7), temperature=body.get("temperature", 0.7),
top_p=body.get("top_p", 0.9), top_p=body.get("top_p", 0.95),
top_k=body.get("top_k", 40),
max_output_tokens=body.get("max_tokens", 8192), max_output_tokens=body.get("max_tokens", 8192),
stop_sequences=body.get("stop", []), response_modalities=["TEXT"],
safety_settings=[
types.SafetySetting(
category="HARM_CATEGORY_HATE_SPEECH",
threshold="OFF"
),
types.SafetySetting(
category="HARM_CATEGORY_DANGEROUS_CONTENT",
threshold="OFF"
),
types.SafetySetting(
category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
threshold="OFF"
),
types.SafetySetting(
category="HARM_CATEGORY_HARASSMENT",
threshold="OFF"
)
],
) )
if self.valves.USE_PERMISSIVE_SAFETY: is_streaming = body.get("stream", False)
safety_settings = { try:
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, response = client.models.generate_content_stream(
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, model=model_id,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, contents=contents,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, config=generate_content_config,
}
else:
safety_settings = body.get("safety_settings")
response = model.generate_content(
contents,
stream=body.get("stream", False),
generation_config=generation_config,
safety_settings=safety_settings,
) )
if body.get("stream", False): if is_streaming:
return self.stream_response(response) def stream_chunks():
try:
for chunk in response:
if chunk and chunk.text:
yield chunk.text
except Exception as e:
print(f"Streaming error: {e}")
yield f"Error during streaming: {str(e)}"
return stream_chunks()
else: else:
return response.text return ''.join(chunk.text for chunk in response)
except Exception as e: except Exception as e:
print(f"Error generating content: {e}") error_msg = f"Generation error: {str(e)}"
return f"An error occurred: {str(e)}" print(error_msg)
return error_msg
except Exception as e:
error_msg = f"Pipeline error: {str(e)}"
print(error_msg)
return error_msg
def stream_response(self, response): def stream_response(self, response):
try:
for chunk in response: for chunk in response:
if chunk.text: if chunk and chunk.text:
print(f"Chunk: {chunk.text}") print(f"Chunk: {chunk.text}")
yield chunk.text yield chunk.text
except Exception as e:
print(f"Stream response error: {e}")
yield f"Error during streaming: {str(e)}"
def build_conversation_history(self, messages: List[dict]) -> List[Content]: def build_conversation_history(self, messages: List[dict]) -> List[Content]:
contents = [] contents = []

View File

@ -16,7 +16,7 @@ httpx
# AI libraries # AI libraries
openai openai
anthropic anthropic
google-generativeai google-genai==1.2.0
vertexai vertexai
# Database # Database