Merge pull request #286 from abdessalaam/patch-1

Create azure_jais_core42_pipeline.py
This commit is contained in:
Timothy Jaeryang Baek 2024-10-19 13:56:19 -07:00 committed by GitHub
commit 0a0856479a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -0,0 +1,215 @@
"""
title: Jais Azure Pipeline with Stream Handling Fix
author: Abdessalaam Al-Alestini
date: 2024-06-20
version: 1.3
license: MIT
description: A pipeline for generating text using the Jais model via Azure AI Inference API, with fixed stream handling.
About Jais: https://inceptionai.ai/jais/
requirements: azure-ai-inference
environment_variables: AZURE_INFERENCE_CREDENTIAL, AZURE_INFERENCE_ENDPOINT, MODEL_ID
"""
import os
import json
import logging
from typing import List, Union, Generator, Iterator, Tuple
from pydantic import BaseModel
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential
from azure.ai.inference.models import SystemMessage, UserMessage, AssistantMessage
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def pop_system_message(messages: List[dict]) -> Tuple[str, List[dict]]:
"""
Extract the system message from the list of messages.
Args:
messages (List[dict]): List of message dictionaries.
Returns:
Tuple[str, List[dict]]: A tuple containing the system message (or empty string) and the updated list of messages.
"""
system_message = ""
updated_messages = []
for message in messages:
if message['role'] == 'system':
system_message = message['content']
else:
updated_messages.append(message)
return system_message, updated_messages
class Pipeline:
class Valves(BaseModel):
AZURE_INFERENCE_CREDENTIAL: str = ""
AZURE_INFERENCE_ENDPOINT: str = ""
MODEL_ID: str = "jais-30b-chat"
def __init__(self):
self.type = "manifold"
self.id = "jais-azure"
self.name = "jais-azure/"
self.valves = self.Valves(
**{
"AZURE_INFERENCE_CREDENTIAL":
os.getenv("AZURE_INFERENCE_CREDENTIAL",
"your-azure-inference-key-here"),
"AZURE_INFERENCE_ENDPOINT":
os.getenv("AZURE_INFERENCE_ENDPOINT",
"your-azure-inference-endpoint-here"),
"MODEL_ID":
os.getenv("MODEL_ID", "jais-30b-chat"),
})
self.update_client()
def update_client(self):
self.client = ChatCompletionsClient(
endpoint=self.valves.AZURE_INFERENCE_ENDPOINT,
credential=AzureKeyCredential(
self.valves.AZURE_INFERENCE_CREDENTIAL))
def get_jais_models(self):
return [
{
"id": "jais-30b-chat",
"name": "Jais 30B Chat"
},
]
async def on_startup(self):
logger.info(f"on_startup:{__name__}")
pass
async def on_shutdown(self):
logger.info(f"on_shutdown:{__name__}")
pass
async def on_valves_updated(self):
self.update_client()
def pipelines(self) -> List[dict]:
return self.get_jais_models()
def pipe(self, user_message: str, model_id: str, messages: List[dict],
body: dict) -> Union[str, Generator, Iterator]:
try:
logger.debug(
f"Received request - user_message: {user_message}, model_id: {model_id}"
)
logger.debug(f"Messages: {json.dumps(messages, indent=2)}")
logger.debug(f"Body: {json.dumps(body, indent=2)}")
# Remove unnecessary keys
for key in ['user', 'chat_id', 'title']:
body.pop(key, None)
system_message, messages = pop_system_message(messages)
# Prepare messages for Jais
jais_messages = [SystemMessage(
content=system_message)] if system_message else []
jais_messages += [
UserMessage(content=msg['content']) if msg['role'] == 'user'
else SystemMessage(content=msg['content']) if msg['role']
== 'system' else AssistantMessage(content=msg['content'])
for msg in messages
]
# Prepare the payload
allowed_params = {
'temperature', 'max_tokens', 'presence_penalty',
'frequency_penalty', 'top_p'
}
filtered_body = {
k: v
for k, v in body.items() if k in allowed_params
}
logger.debug(f"Prepared Jais messages: {jais_messages}")
logger.debug(f"Filtered body: {filtered_body}")
is_stream = body.get("stream", False)
if is_stream:
return self.stream_response(jais_messages, filtered_body)
else:
return self.get_completion(jais_messages, filtered_body)
except Exception as e:
logger.error(f"Error in pipe: {str(e)}", exc_info=True)
return json.dumps({"error": str(e)})
def stream_response(self, jais_messages: List[Union[SystemMessage, UserMessage, AssistantMessage]], params: dict) -> str:
try:
complete_response = ""
response = self.client.complete(messages=jais_messages,
model=self.valves.MODEL_ID,
stream=True,
**params)
for update in response:
if update.choices:
delta_content = update.choices[0].delta.content
if delta_content:
complete_response += delta_content
return complete_response
except Exception as e:
logger.error(f"Error in stream_response: {str(e)}", exc_info=True)
return json.dumps({"error": str(e)})
def get_completion(self, jais_messages: List[Union[SystemMessage, UserMessage, AssistantMessage]], params: dict) -> str:
try:
response = self.client.complete(messages=jais_messages,
model=self.valves.MODEL_ID,
**params)
if response.choices:
result = response.choices[0].message.content
logger.debug(f"Completion result: {result}")
return result
else:
logger.warning("No choices in completion response")
return ""
except Exception as e:
logger.error(f"Error in get_completion: {str(e)}", exc_info=True)
return json.dumps({"error": str(e)})
# TEST CASE TO RUN THE PIPELINE
if __name__ == "__main__":
pipeline = Pipeline()
messages = [{
"role": "user",
"content": "How many languages are in the world?"
}]
body = {
"temperature": 0.5,
"max_tokens": 150,
"presence_penalty": 0.1,
"frequency_penalty": 0.8,
"stream": True # Change to True to test streaming
}
result = pipeline.pipe(user_message="How many languages are in the world?",
model_id="jais-30b-chat",
messages=messages,
body=body)
# Handle streaming result
if isinstance(result, str):
content = json.dumps({"content": result}, ensure_ascii=False)
print(content)
else:
complete_response = ""
for part in result:
content_delta = json.loads(part).get("delta")
if content_delta:
complete_response += content_delta
print(json.dumps({"content": complete_response}, ensure_ascii=False))