mirror of
				https://github.com/open-webui/pipelines
				synced 2025-06-26 18:15:58 +00:00 
			
		
		
		
	fix: anthropic system
This commit is contained in:
		
							parent
							
								
									a6daafe1f2
								
							
						
					
					
						commit
						0f610ca7eb
					
				| @ -17,6 +17,8 @@ from typing import List, Union, Generator, Iterator | |||||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||||
| import requests | import requests | ||||||
| 
 | 
 | ||||||
|  | from utils.pipelines.main import pop_system_message | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class Pipeline: | class Pipeline: | ||||||
|     class Valves(BaseModel): |     class Valves(BaseModel): | ||||||
| @ -80,6 +82,8 @@ class Pipeline: | |||||||
|     def stream_response( |     def stream_response( | ||||||
|         self, model_id: str, messages: List[dict], body: dict |         self, model_id: str, messages: List[dict], body: dict | ||||||
|     ) -> Generator: |     ) -> Generator: | ||||||
|  |         system_message, messages = pop_system_message(messages) | ||||||
|  | 
 | ||||||
|         max_tokens = ( |         max_tokens = ( | ||||||
|             body.get("max_tokens") if body.get("max_tokens") is not None else 4096 |             body.get("max_tokens") if body.get("max_tokens") is not None else 4096 | ||||||
|         ) |         ) | ||||||
| @ -91,14 +95,19 @@ class Pipeline: | |||||||
|         stop_sequences = body.get("stop") if body.get("stop") is not None else [] |         stop_sequences = body.get("stop") if body.get("stop") is not None else [] | ||||||
| 
 | 
 | ||||||
|         stream = self.client.messages.create( |         stream = self.client.messages.create( | ||||||
|             model=model_id, |             **{ | ||||||
|             messages=messages, |                 "model": model_id, | ||||||
|             max_tokens=max_tokens, |                 **( | ||||||
|             temperature=temperature, |                     {"system": system_message} if system_message else {} | ||||||
|             top_k=top_k, |                 ),  # Add system message if it exists (optional | ||||||
|             top_p=top_p, |                 "messages": messages, | ||||||
|             stop_sequences=stop_sequences, |                 "max_tokens": max_tokens, | ||||||
|             stream=True, |                 "temperature": temperature, | ||||||
|  |                 "top_k": top_k, | ||||||
|  |                 "top_p": top_p, | ||||||
|  |                 "stop_sequences": stop_sequences, | ||||||
|  |                 "stream": True, | ||||||
|  |             } | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         for chunk in stream: |         for chunk in stream: | ||||||
| @ -108,6 +117,8 @@ class Pipeline: | |||||||
|                 yield chunk.delta.text |                 yield chunk.delta.text | ||||||
| 
 | 
 | ||||||
|     def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str: |     def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str: | ||||||
|  |         system_message, messages = pop_system_message(messages) | ||||||
|  | 
 | ||||||
|         max_tokens = ( |         max_tokens = ( | ||||||
|             body.get("max_tokens") if body.get("max_tokens") is not None else 4096 |             body.get("max_tokens") if body.get("max_tokens") is not None else 4096 | ||||||
|         ) |         ) | ||||||
| @ -119,12 +130,17 @@ class Pipeline: | |||||||
|         stop_sequences = body.get("stop") if body.get("stop") is not None else [] |         stop_sequences = body.get("stop") if body.get("stop") is not None else [] | ||||||
| 
 | 
 | ||||||
|         response = self.client.messages.create( |         response = self.client.messages.create( | ||||||
|             model=model_id, |             **{ | ||||||
|             messages=messages, |                 "model": model_id, | ||||||
|             max_tokens=max_tokens, |                 **( | ||||||
|             temperature=temperature, |                     {"system": system_message} if system_message else {} | ||||||
|             top_k=top_k, |                 ),  # Add system message if it exists (optional | ||||||
|             top_p=top_p, |                 "messages": messages, | ||||||
|             stop_sequences=stop_sequences, |                 "max_tokens": max_tokens, | ||||||
|  |                 "temperature": temperature, | ||||||
|  |                 "top_k": top_k, | ||||||
|  |                 "top_p": top_p, | ||||||
|  |                 "stop_sequences": stop_sequences, | ||||||
|  |             } | ||||||
|         ) |         ) | ||||||
|         return response.content[0].text |         return response.content[0].text | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ from typing import List | |||||||
| from schemas import OpenAIChatMessage | from schemas import OpenAIChatMessage | ||||||
| 
 | 
 | ||||||
| import inspect | import inspect | ||||||
| from typing import get_type_hints, Literal | from typing import get_type_hints, Literal, Tuple | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def stream_message_template(model: str, message: str): | def stream_message_template(model: str, message: str): | ||||||
| @ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str: | |||||||
|     return None |     return None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def get_system_message(messages: List[dict]) -> dict: | ||||||
|  |     for message in messages: | ||||||
|  |         if message["role"] == "system": | ||||||
|  |             return message | ||||||
|  |     return None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def remove_system_message(messages: List[dict]) -> List[dict]: | ||||||
|  |     return [message for message in messages if message["role"] != "system"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]: | ||||||
|  |     return get_system_message(messages), remove_system_message(messages) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def add_or_update_system_message(content: str, messages: List[dict]): | def add_or_update_system_message(content: str, messages: List[dict]): | ||||||
|     """ |     """ | ||||||
|     Adds a new system message at the beginning of the messages list |     Adds a new system message at the beginning of the messages list | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user