mirror of
https://github.com/gpt-omni/mini-omni
synced 2024-11-21 15:27:37 +00:00
fix typo
This commit is contained in:
parent
e95181ba9c
commit
042b9500a8
@ -19,7 +19,7 @@ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
|
|||||||
return torch.multinomial(probs, num_samples=1)
|
return torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
|
|
||||||
def sample_top_p(logits_A: torch.Tensor, top_p: float) -> torch.Tensor:
|
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
||||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||||
# Example:
|
# Example:
|
||||||
|
Loading…
Reference in New Issue
Block a user