fix request param handling

This commit is contained in:
Justin Hayes 2024-05-29 09:49:16 -04:00 committed by GitHub
parent 3c8ce1a03b
commit 420ecff7ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -71,14 +71,20 @@ class Pipeline:
def stream_response(
self, model_id: str, messages: List[dict], body: dict
) -> Generator:
max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []
stream = self.client.messages.create(
model=model_id,
messages=messages,
max_tokens=body.get("max_tokens", 4096),
temperature=body.get("temperature", 0.8),
top_k=body.get("top_k", 40),
top_p=body.get("top_p", 0.9),
stop_sequences=body.get("stop", []),
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
stream=True,
)
@ -89,13 +95,19 @@ class Pipeline:
yield chunk.delta.text
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []
response = self.client.messages.create(
model=model_id,
messages=messages,
max_tokens=body.get("max_tokens", 4096),
temperature=body.get("temperature", 0.8),
top_k=body.get("top_k", 40),
top_p=body.get("top_p", 0.9),
stop_sequences=body.get("stop", []),
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
)
return response.content[0].text