From 042b9500a88a0f0663ff7ab06bc1d892a3ee32b4 Mon Sep 17 00:00:00 2001 From: Xiaoming Liu Date: Thu, 12 Sep 2024 14:16:10 +0800 Subject: [PATCH] fix typo --- litgpt/generate/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index ad59048..2c5d23b 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -19,7 +19,7 @@ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: return torch.multinomial(probs, num_samples=1) -def sample_top_p(logits_A: torch.Tensor, top_p: float) -> torch.Tensor: +def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor: sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Example: