diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index ad59048..2c5d23b 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -19,7 +19,7 @@ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: return torch.multinomial(probs, num_samples=1) -def sample_top_p(logits_A: torch.Tensor, top_p: float) -> torch.Tensor: +def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor: sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Example: