Fix request param handling: fall back to defaults when params are missing or explicitly set to None

This commit is contained in:
Justin Hayes 2024-05-29 09:49:16 -04:00 committed by GitHub
parent 3c8ce1a03b
commit 420ecff7ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -71,14 +71,20 @@ class Pipeline:
def stream_response( def stream_response(
self, model_id: str, messages: List[dict], body: dict self, model_id: str, messages: List[dict], body: dict
) -> Generator: ) -> Generator:
max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []
stream = self.client.messages.create( stream = self.client.messages.create(
model=model_id, model=model_id,
messages=messages, messages=messages,
max_tokens=body.get("max_tokens", 4096), max_tokens=max_tokens,
temperature=body.get("temperature", 0.8), temperature=temperature,
top_k=body.get("top_k", 40), top_k=top_k,
top_p=body.get("top_p", 0.9), top_p=top_p,
stop_sequences=body.get("stop", []), stop_sequences=stop_sequences,
stream=True, stream=True,
) )
@ -89,13 +95,19 @@ class Pipeline:
yield chunk.delta.text yield chunk.delta.text
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
    """Return the complete (non-streaming) response text for *messages*.

    Request parameters are read from *body*; a key that is absent OR
    explicitly set to ``None`` falls back to its default. Note that plain
    ``body.get(key, default)`` would NOT do this — it passes an explicit
    ``None`` straight through to the API, which is the bug this handling
    exists to avoid.

    :param model_id: identifier of the model to query.
    :param messages: chat messages in the API's ``{"role": ..., "content": ...}`` dict form.
    :param body: raw request body; optional keys ``max_tokens``,
        ``temperature``, ``top_k``, ``top_p``, ``stop``.
    :return: text of the first content block of the response.
    """

    def _param(key: str, default):
        # Single lookup instead of calling body.get() twice per parameter,
        # preserving the None-coalescing semantics described above.
        value = body.get(key)
        return default if value is None else value

    response = self.client.messages.create(
        model=model_id,
        messages=messages,
        max_tokens=_param("max_tokens", 4096),
        temperature=_param("temperature", 0.8),
        top_k=_param("top_k", 40),
        top_p=_param("top_p", 0.9),
        stop_sequences=_param("stop", []),
    )
    return response.content[0].text