Merge pull request #66 from Lollipop/patch-1

fix typo
This commit is contained in:
mini-omni 2024-09-13 12:33:50 +08:00 committed by GitHub
commit 5f87f73785
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,7 @@ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
return torch.multinomial(probs, num_samples=1)
def sample_top_p(logits_A: torch.Tensor, top_p: float) -> torch.Tensor:
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Example: