Quantuary 2025-03-18 07:19:54 -07:00 committed by GitHub
commit 8296c84249
2 changed files with 78 additions and 50 deletions

View File

@@ -17,15 +17,12 @@ import os
from typing import Iterator, List, Union
import vertexai
+from google import genai
+from google.genai import types
from pydantic import BaseModel, Field
-from vertexai.generative_models import (
-    Content,
-    GenerationConfig,
-    GenerativeModel,
-    HarmBlockThreshold,
-    HarmCategory,
-    Part,
-)
+from vertexai.generative_models import (Content, GenerationConfig,
+                                        GenerativeModel, HarmBlockThreshold,
+                                        HarmCategory, Part)
class Pipeline:
@@ -46,14 +43,15 @@ class Pipeline:
**{
"GOOGLE_PROJECT_ID": os.getenv("GOOGLE_PROJECT_ID", ""),
"GOOGLE_CLOUD_REGION": os.getenv("GOOGLE_CLOUD_REGION", ""),
"GOOGLE_APPLICATION_CREDENTIALS": os.getenv("GOOGLE_APPLICATION_CREDENTIALS", ""),
"USE_PERMISSIVE_SAFETY": False,
}
)
self.pipelines = [
{"id": "gemini-1.5-flash-001", "name": "Gemini 1.5 Flash"},
{"id": "gemini-1.5-pro-001", "name": "Gemini 1.5 Pro"},
{"id": "gemini-flash-experimental", "name": "Gemini 1.5 Flash Experimental"},
{"id": "gemini-pro-experimental", "name": "Gemini 1.5 Pro Experimental"},
{"id": "gemini-2.0-flash", "name": "Gemini 2.0 Flash"},
{"id": "gemini-2.0-flash-lite-preview-02-05", "name": "Gemini 2.0 Flash Lite"},
{"id": "gemini-2.0-pro-exp-02-05", "name": "Gemini 2.0 Pro"},
{"id": "gemini-2.0-flash-thinking-exp-01-21", "name": "Gemini 2.0 Flash Thinking"},
]
async def on_startup(self) -> None:
@@ -91,55 +89,85 @@ class Pipeline:
(msg["content"] for msg in messages if msg["role"] == "system"), None
)
-        model = GenerativeModel(
-            model_name=model_id,
-            system_instruction=system_message,
+        client = genai.Client(
+            vertexai=True,
+            project=self.valves.GOOGLE_PROJECT_ID,
+            location=self.valves.GOOGLE_CLOUD_REGION,
)
if body.get("title", False): # If chat title generation is requested
contents = [Content(role="user", parts=[Part.from_text(user_message)])]
else:
contents = self.build_conversation_history(messages)
contents = [
types.Content(
role="user",
parts=[types.Part(text=user_message)]
)
]
-        generation_config = GenerationConfig(
+        generate_content_config = types.GenerateContentConfig(
temperature=body.get("temperature", 0.7),
top_p=body.get("top_p", 0.9),
top_k=body.get("top_k", 40),
top_p=body.get("top_p", 0.95),
max_output_tokens=body.get("max_tokens", 8192),
stop_sequences=body.get("stop", []),
+            response_modalities=["TEXT"],
+            safety_settings=[
+                types.SafetySetting(
+                    category="HARM_CATEGORY_HATE_SPEECH",
+                    threshold="OFF"
+                ),
+                types.SafetySetting(
+                    category="HARM_CATEGORY_DANGEROUS_CONTENT",
+                    threshold="OFF"
+                ),
+                types.SafetySetting(
+                    category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                    threshold="OFF"
+                ),
+                types.SafetySetting(
+                    category="HARM_CATEGORY_HARASSMENT",
+                    threshold="OFF"
+                )
+            ],
)
-        if self.valves.USE_PERMISSIVE_SAFETY:
-            safety_settings = {
-                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
-            }
-        else:
-            safety_settings = body.get("safety_settings")
+        is_streaming = body.get("stream", False)
try:
+            response = client.models.generate_content_stream(
+                model=model_id,
+                contents=contents,
+                config=generate_content_config,
+            )
-            response = model.generate_content(
-                contents,
-                stream=body.get("stream", False),
-                generation_config=generation_config,
-                safety_settings=safety_settings,
-            )
+            if is_streaming:
+                def stream_chunks():
+                    try:
+                        for chunk in response:
+                            if chunk and chunk.text:
+                                yield chunk.text
+                    except Exception as e:
+                        print(f"Streaming error: {e}")
+                        yield f"Error during streaming: {str(e)}"
+                return stream_chunks()
+            else:
+                return ''.join(chunk.text for chunk in response)
if body.get("stream", False):
return self.stream_response(response)
else:
return response.text
+        except Exception as e:
+            error_msg = f"Generation error: {str(e)}"
+            print(error_msg)
+            return error_msg
except Exception as e:
print(f"Error generating content: {e}")
return f"An error occurred: {str(e)}"
error_msg = f"Pipeline error: {str(e)}"
print(error_msg)
return error_msg
def stream_response(self, response):
-        for chunk in response:
-            if chunk.text:
-                print(f"Chunk: {chunk.text}")
-                yield chunk.text
+        try:
+            for chunk in response:
+                if chunk and chunk.text:
+                    print(f"Chunk: {chunk.text}")
+                    yield chunk.text
+        except Exception as e:
+            print(f"Stream response error: {e}")
+            yield f"Error during streaming: {str(e)}"
def build_conversation_history(self, messages: List[dict]) -> List[Content]:
contents = []
@@ -167,4 +195,4 @@ class Pipeline:
role = "user" if message["role"] == "user" else "model"
contents.append(Content(role=role, parts=parts))
-        return contents
+        return contents
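
For orientation, here is a minimal, self-contained sketch of the google-genai call pattern this commit migrates to. It mirrors the calls in the diff above; the project id, region, and prompt are placeholder assumptions, not values from this repository:

    # Minimal sketch of the streaming flow adopted above (google-genai SDK).
    # PROJECT_ID and REGION are placeholders; substitute real values.
    from google import genai
    from google.genai import types

    client = genai.Client(vertexai=True, project="PROJECT_ID", location="REGION")

    config = types.GenerateContentConfig(
        temperature=0.7,
        top_p=0.95,
        max_output_tokens=8192,
    )

    # generate_content_stream returns an iterator of partial responses;
    # chunk.text can be empty for non-text chunks, so guard before printing.
    for chunk in client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=[types.Content(role="user", parts=[types.Part(text="Say hello")])],
        config=config,
    ):
        if chunk.text:
            print(chunk.text, end="", flush=True)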

View File

@@ -16,7 +16,7 @@ httpx
# AI libraries
openai
anthropic
-google-generativeai
+google-genai==1.2.0
vertexai
# Database
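
The dependency change swaps the legacy google-generativeai package for the unified google-genai SDK, pinned at 1.2.0, while keeping vertexai for the imports the pipeline still uses. A quick smoke test after reinstalling dependencies might look like this (a hypothetical check, not part of the commit):

    # Hypothetical post-install check: confirms both SDKs resolve.
    from google import genai        # provided by google-genai==1.2.0
    from google.genai import types  # typed config/content models
    import vertexai                 # legacy SDK, still needed by the remaining imports

    print("google-genai client available:", hasattr(genai, "Client"))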