mirror of
https://github.com/open-webui/open-webui
synced 2024-11-22 08:07:55 +00:00
refac: colbert cuda support
This commit is contained in:
parent
b7f0759485
commit
06debb322b
@ -203,12 +203,19 @@ def update_reranking_model(
|
||||
|
||||
class Colbert:
|
||||
def __init__(self, name) -> None:
|
||||
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
|
||||
self.device
|
||||
)
|
||||
pass
|
||||
|
||||
def calculate_similarity_scores(
|
||||
self, query_embeddings, document_embeddings
|
||||
):
|
||||
|
||||
query_embeddings = query_embeddings.to(self.device)
|
||||
document_embeddings = document_embeddings.to(self.device)
|
||||
|
||||
# Validate dimensions to ensure compatibility
|
||||
if query_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
@ -237,7 +244,7 @@ def update_reranking_model(
|
||||
|
||||
normalized_scores = torch.softmax(final_scores, dim=0)
|
||||
|
||||
return normalized_scores.numpy().astype(np.float32)
|
||||
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
||||
|
||||
def predict(self, sentences):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user