diff --git a/pipelines/examples/anthropic_manifold_pipeline.py b/pipelines/examples/anthropic_manifold_pipeline.py
index d472a90..6b3d05d 100644
--- a/pipelines/examples/anthropic_manifold_pipeline.py
+++ b/pipelines/examples/anthropic_manifold_pipeline.py
@@ -71,14 +71,20 @@ class Pipeline:
     def stream_response(
         self, model_id: str, messages: List[dict], body: dict
     ) -> Generator:
+        max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
+        temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
+        top_k = body.get("top_k") if body.get("top_k") is not None else 40
+        top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
+        stop_sequences = body.get("stop") if body.get("stop") is not None else []
+
         stream = self.client.messages.create(
             model=model_id,
             messages=messages,
-            max_tokens=body.get("max_tokens", 4096),
-            temperature=body.get("temperature", 0.8),
-            top_k=body.get("top_k", 40),
-            top_p=body.get("top_p", 0.9),
-            stop_sequences=body.get("stop", []),
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            stop_sequences=stop_sequences,
             stream=True,
         )
 
@@ -89,13 +89,19 @@ class Pipeline:
             yield chunk.delta.text
 
     def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
+        max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
+        temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
+        top_k = body.get("top_k") if body.get("top_k") is not None else 40
+        top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
+        stop_sequences = body.get("stop") if body.get("stop") is not None else []
+
         response = self.client.messages.create(
             model=model_id,
             messages=messages,
-            max_tokens=body.get("max_tokens", 4096),
-            temperature=body.get("temperature", 0.8),
-            top_k=body.get("top_k", 40),
-            top_p=body.get("top_p", 0.9),
-            stop_sequences=body.get("stop", []),
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            stop_sequences=stop_sequences,
         )
         return response.content[0].text