Merge pull request #470 from kikumoto/feature/update_aws_bedrock_claude_implementation

Feature/update aws bedrock claude implementation
commit ef900c4a3b
Tim Jaeryang Baek, 2025-04-14 08:55:42 -07:00, committed by GitHub
No known key found for this signature in database; GPG Key ID: B5690EEEBB952194


@@ -12,7 +12,7 @@ import base64
 import json
 import logging
 from io import BytesIO
-from typing import List, Union, Generator, Iterator
+from typing import List, Union, Generator, Iterator, Optional, Any

 import boto3
@@ -23,12 +23,23 @@ import requests

 from utils.pipelines.main import pop_system_message

+REASONING_EFFORT_BUDGET_TOKEN_MAP = {
+    "none": None,
+    "low": 1024,
+    "medium": 4096,
+    "high": 16384,
+    "max": 32768,
+}
+
+# Maximum combined token limit for Claude 3.7
+MAX_COMBINED_TOKENS = 64000
+
+
 class Pipeline:
     class Valves(BaseModel):
-        AWS_ACCESS_KEY: str = ""
-        AWS_SECRET_KEY: str = ""
-        AWS_REGION_NAME: str = ""
+        AWS_ACCESS_KEY: Optional[str] = None
+        AWS_SECRET_KEY: Optional[str] = None
+        AWS_REGION_NAME: Optional[str] = None

     def __init__(self):
         self.type = "manifold"
@@ -47,21 +58,25 @@ class Pipeline:
             }
         )

-        self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                    aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                    service_name="bedrock",
-                                    region_name=self.valves.AWS_REGION_NAME)
-        self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                            aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                            service_name="bedrock-runtime",
-                                            region_name=self.valves.AWS_REGION_NAME)
+        self.valves = self.Valves(
+            **{
+                "AWS_ACCESS_KEY": os.getenv("AWS_ACCESS_KEY", ""),
+                "AWS_SECRET_KEY": os.getenv("AWS_SECRET_KEY", ""),
+                "AWS_REGION_NAME": os.getenv(
+                    "AWS_REGION_NAME", os.getenv(
+                        "AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "")
+                    )
+                ),
+            }
+        )

-        self.pipelines = self.get_models()
+        self.update_pipelines()

     async def on_startup(self):
         # This function is called when the server is started.
         print(f"on_startup:{__name__}")
+        self.update_pipelines()
         pass

     async def on_shutdown(self):
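The nested `os.getenv` calls above give the region a three-level fallback: `AWS_REGION_NAME`, then `AWS_REGION`, then `AWS_DEFAULT_REGION`. An equivalent flat rendering of the same lookup (illustrative only, not part of the commit):

```python
import os

def resolve_region_name() -> str:
    # First environment variable that is set wins; unset vars fall through,
    # matching the os.getenv(default=...) chaining in the diff.
    for var in ("AWS_REGION_NAME", "AWS_REGION", "AWS_DEFAULT_REGION"):
        value = os.getenv(var)
        if value is not None:
            return value
    return ""

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
print(resolve_region_name())  # -> "us-west-2" when the other two are unset
```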
@@ -72,40 +87,58 @@ class Pipeline:
     async def on_valves_updated(self):
         # This function is called when the valves are updated.
         print(f"on_valves_updated:{__name__}")
-        self.bedrock = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                    aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                    service_name="bedrock",
-                                    region_name=self.valves.AWS_REGION_NAME)
-        self.bedrock_runtime = boto3.client(aws_access_key_id=self.valves.AWS_ACCESS_KEY,
-                                            aws_secret_access_key=self.valves.AWS_SECRET_KEY,
-                                            service_name="bedrock-runtime",
-                                            region_name=self.valves.AWS_REGION_NAME)
-        self.pipelines = self.get_models()
-
-    def pipelines(self) -> List[dict]:
-        return self.get_models()
+        self.update_pipelines()
+
+    def update_pipelines(self) -> None:
+        try:
+            self.bedrock = boto3.client(service_name="bedrock",
+                                        aws_access_key_id=self.valves.AWS_ACCESS_KEY,
+                                        aws_secret_access_key=self.valves.AWS_SECRET_KEY,
+                                        region_name=self.valves.AWS_REGION_NAME)
+            self.bedrock_runtime = boto3.client(service_name="bedrock-runtime",
+                                                aws_access_key_id=self.valves.AWS_ACCESS_KEY,
+                                                aws_secret_access_key=self.valves.AWS_SECRET_KEY,
+                                                region_name=self.valves.AWS_REGION_NAME)
+            self.pipelines = self.get_models()
+        except Exception as e:
+            print(f"Error: {e}")
+            self.pipelines = [
+                {
+                    "id": "error",
+                    "name": "Could not fetch models from Bedrock, please set up AWS Key/Secret or Instance/Task Role.",
+                },
+            ]

     def get_models(self):
-        if self.valves.AWS_ACCESS_KEY and self.valves.AWS_SECRET_KEY:
-            try:
-                response = self.bedrock.list_foundation_models(byProvider='Anthropic', byInferenceType='ON_DEMAND')
-                return [
-                    {
-                        "id": model["modelId"],
-                        "name": model["modelName"],
-                    }
-                    for model in response["modelSummaries"]
-                ]
-            except Exception as e:
-                print(f"Error: {e}")
-                return [
-                    {
-                        "id": "error",
-                        "name": "Could not fetch models from Bedrock, please update the Access/Secret Key in the valves.",
-                    },
-                ]
-        else:
-            return []
+        try:
+            res = []
+            response = self.bedrock.list_foundation_models(byProvider='Anthropic')
+            for model in response['modelSummaries']:
+                inference_types = model.get('inferenceTypesSupported', [])
+                if "ON_DEMAND" in inference_types:
+                    res.append({'id': model['modelId'], 'name': model['modelName']})
+                elif "INFERENCE_PROFILE" in inference_types:
+                    inferenceProfileId = self.getInferenceProfileId(model['modelArn'])
+                    if inferenceProfileId:
+                        res.append({'id': inferenceProfileId, 'name': model['modelName']})
+            return res
+        except Exception as e:
+            print(f"Error: {e}")
+            return [
+                {
+                    "id": "error",
+                    "name": "Could not fetch models from Bedrock, please check permission.",
+                },
+            ]
+
+    def getInferenceProfileId(self, modelArn: str) -> Optional[str]:
+        response = self.bedrock.list_inference_profiles()
+        for profile in response.get('inferenceProfileSummaries', []):
+            for model in profile.get('models', []):
+                if model.get('modelArn') == modelArn:
+                    return profile['inferenceProfileId']
+        return None

     def pipe(
         self, user_message: str, model_id: str, messages: List[dict], body: dict
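To see what the reworked `get_models` produces, here is a self-contained dry run over a mocked `list_foundation_models` response; the model IDs and the ARN-to-profile mapping are illustrative stand-ins, not live data:

```python
# Shaped like bedrock.list_foundation_models(byProvider='Anthropic')
response = {
    "modelSummaries": [
        {
            "modelId": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "modelName": "Claude 3.5 Sonnet",
            "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20240620-v1:0",
            "inferenceTypesSupported": ["ON_DEMAND"],
        },
        {
            "modelId": "anthropic.claude-3-7-sonnet-20250219-v1:0",
            "modelName": "Claude 3.7 Sonnet",
            "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-7-sonnet-20250219-v1:0",
            "inferenceTypesSupported": ["INFERENCE_PROFILE"],
        },
    ]
}

# Hypothetical ARN -> profile mapping standing in for getInferenceProfileId()
profile_by_arn = {
    "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-7-sonnet-20250219-v1:0":
        "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
}

res = []
for model in response["modelSummaries"]:
    inference_types = model.get("inferenceTypesSupported", [])
    if "ON_DEMAND" in inference_types:
        # On-demand models are listed directly under their own model ID.
        res.append({"id": model["modelId"], "name": model["modelName"]})
    elif "INFERENCE_PROFILE" in inference_types:
        # Profile-only models are listed under their inference profile ID.
        profile_id = profile_by_arn.get(model["modelArn"])
        if profile_id:
            res.append({"id": profile_id, "name": model["modelName"]})

print(res)
# [{'id': 'anthropic.claude-3-5-sonnet-20240620-v1:0', 'name': 'Claude 3.5 Sonnet'},
#  {'id': 'us.anthropic.claude-3-7-sonnet-20250219-v1:0', 'name': 'Claude 3.7 Sonnet'}]
```

This branching is why Claude 3.7, which is only reachable through a cross-region inference profile, now appears in the model list at all.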
@@ -139,11 +172,53 @@ class Pipeline:
         payload = {"modelId": model_id,
                    "messages": processed_messages,
-                   "system": [{'text': system_message if system_message else 'you are an intelligent ai assistant'}],
-                   "inferenceConfig": {"temperature": body.get("temperature", 0.5)},
-                   "additionalModelRequestFields": {"top_k": body.get("top_k", 200), "top_p": body.get("top_p", 0.9)}
+                   "system": [{'text': system_message["content"] if system_message else 'you are an intelligent ai assistant'}],
+                   "inferenceConfig": {
+                       "temperature": body.get("temperature", 0.5),
+                       "topP": body.get("top_p", 0.9),
+                       "maxTokens": body.get("max_tokens", 4096),
+                       "stopSequences": body.get("stop", []),
+                   },
+                   "additionalModelRequestFields": {"top_k": body.get("top_k", 200)}
                    }

         if body.get("stream", False):
+            supports_thinking = "claude-3-7" in model_id
+            reasoning_effort = body.get("reasoning_effort", "none")
+            budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
+
+            # Allow users to input an integer value representing budget tokens
+            if (
+                not budget_tokens
+                and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP
+            ):
+                try:
+                    budget_tokens = int(reasoning_effort)
+                except ValueError as e:
+                    print("Failed to convert reasoning effort to int", e)
+                    budget_tokens = None
+
+            if supports_thinking and budget_tokens:
+                # Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit
+                max_tokens = payload["inferenceConfig"].get("maxTokens", 4096)
+                combined_tokens = budget_tokens + max_tokens
+
+                if combined_tokens > MAX_COMBINED_TOKENS:
+                    error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
+                    print(error_message)
+                    return error_message
+
+                payload["inferenceConfig"]["maxTokens"] = combined_tokens
+                payload["additionalModelRequestFields"]["thinking"] = {
+                    "type": "enabled",
+                    "budget_tokens": budget_tokens,
+                }
+
+                # Thinking requires temperature 1.0 and does not support top_p, top_k
+                payload["inferenceConfig"]["temperature"] = 1.0
+                if "top_k" in payload["additionalModelRequestFields"]:
+                    del payload["additionalModelRequestFields"]["top_k"]
+                if "topP" in payload["inferenceConfig"]:
+                    del payload["inferenceConfig"]["topP"]
+
             return self.stream_response(model_id, payload)
         else:
             return self.get_completion(model_id, payload)
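The reasoning-effort handling resolves `body["reasoning_effort"]` in three ways: a named level from the map, a bare integer string (e.g. `"8000"`), or `"none"`/unparseable, which leaves thinking off. A standalone sketch of the same rules, using the constants defined earlier in the diff:

```python
from typing import Optional

REASONING_EFFORT_BUDGET_TOKEN_MAP = {
    "none": None, "low": 1024, "medium": 4096, "high": 16384, "max": 32768,
}
MAX_COMBINED_TOKENS = 64000

def resolve_budget_tokens(reasoning_effort: str) -> Optional[int]:
    # Named level first; otherwise accept a bare integer string.
    budget = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
    if not budget and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP:
        try:
            budget = int(reasoning_effort)
        except ValueError:
            budget = None
    return budget

print(resolve_budget_tokens("medium"))  # 4096
print(resolve_budget_tokens("8000"))    # 8000 (integer passthrough)
print(resolve_budget_tokens("none"))    # None -> thinking stays disabled

# With budget 4096 and maxTokens 4096, the payload is then rewritten to:
#   inferenceConfig.maxTokens   = 8192 (rejected if > MAX_COMBINED_TOKENS)
#   inferenceConfig.temperature = 1.0, with topP and top_k removed
#   additionalModelRequestFields.thinking = {"type": "enabled", "budget_tokens": 4096}
```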
@@ -152,30 +227,45 @@ class Pipeline:
     def process_image(self, image: str):
         img_stream = None
+        content_type = None
+
         if image["url"].startswith("data:image"):
-            if ',' in image["url"]:
-                base64_string = image["url"].split(',')[1]
-                image_data = base64.b64decode(base64_string)
-
-                img_stream = BytesIO(image_data)
+            mime_type, base64_string = image["url"].split(",", 1)
+            content_type = mime_type.split(":")[1].split(";")[0]
+            image_data = base64.b64decode(base64_string)
+            img_stream = BytesIO(image_data)
         else:
-            img_stream = requests.get(image["url"]).content
+            response = requests.get(image["url"])
+            img_stream = BytesIO(response.content)
+            content_type = response.headers.get('Content-Type', 'image/jpeg')
+
+        media_type = content_type.split('/')[-1] if '/' in content_type else content_type

         return {
-            "image": {"format": "png" if image["url"].endswith(".png") else "jpeg",
-                      "source": {"bytes": img_stream.read()}}
+            "image": {
+                "format": media_type,
+                "source": {"bytes": img_stream.read()}
+            }
         }

     def stream_response(self, model_id: str, payload: dict) -> Generator:
-        if "system" in payload:
-            del payload["system"]
-        if "additionalModelRequestFields" in payload:
-            del payload["additionalModelRequestFields"]
         streaming_response = self.bedrock_runtime.converse_stream(**payload)
+        in_reasoning_context = False
         for chunk in streaming_response["stream"]:
-            if "contentBlockDelta" in chunk:
-                yield chunk["contentBlockDelta"]["delta"]["text"]
+            if in_reasoning_context and "contentBlockStop" in chunk:
+                in_reasoning_context = False
+                yield "\n </think> \n\n"
+            elif "contentBlockDelta" in chunk and "delta" in chunk["contentBlockDelta"]:
+                if "reasoningContent" in chunk["contentBlockDelta"]["delta"]:
+                    if not in_reasoning_context:
+                        yield "<think>"
+                    in_reasoning_context = True
+                    if "text" in chunk["contentBlockDelta"]["delta"]["reasoningContent"]:
+                        yield chunk["contentBlockDelta"]["delta"]["reasoningContent"]["text"]
+                elif "text" in chunk["contentBlockDelta"]["delta"]:
+                    yield chunk["contentBlockDelta"]["delta"]["text"]

     def get_completion(self, model_id: str, payload: dict) -> str:
         response = self.bedrock_runtime.converse(**payload)
         return response['output']['message']['content'][0]['text']
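The reworked `stream_response` wraps Bedrock's `reasoningContent` deltas in `<think>…</think>` markers, which Open WebUI renders as a collapsible reasoning block. A dry run over hand-written chunks (shaped like `converse_stream` output, invented for illustration):

```python
def wrap_reasoning(chunks):
    # Same state machine as stream_response, minus the Bedrock call.
    in_reasoning_context = False
    for chunk in chunks:
        if in_reasoning_context and "contentBlockStop" in chunk:
            in_reasoning_context = False
            yield "\n </think> \n\n"
        elif "contentBlockDelta" in chunk and "delta" in chunk["contentBlockDelta"]:
            delta = chunk["contentBlockDelta"]["delta"]
            if "reasoningContent" in delta:
                if not in_reasoning_context:
                    yield "<think>"
                in_reasoning_context = True
                if "text" in delta["reasoningContent"]:
                    yield delta["reasoningContent"]["text"]
            elif "text" in delta:
                yield delta["text"]

sample = [
    {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Comparing options..."}}}},
    {"contentBlockStop": {}},
    {"contentBlockDelta": {"delta": {"text": "Use an inference profile."}}},
]
print("".join(wrap_reasoning(sample)))
# <think>Comparing options...
#  </think>
#
# Use an inference profile.
```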