refac: colbert cuda support

This commit is contained in:
Timothy J. Baek 2024-09-16 12:42:48 +02:00
parent b7f0759485
commit 06debb322b
1 changed files with 9 additions and 2 deletions

View File

@ -203,12 +203,19 @@ def update_reranking_model(
class Colbert:
def __init__(self, name) -> None:
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
self.device
)
pass
def calculate_similarity_scores(
self, query_embeddings, document_embeddings
):
query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3:
raise ValueError(
@ -237,7 +244,7 @@ def update_reranking_model(
normalized_scores = torch.softmax(final_scores, dim=0)
return normalized_scores.numpy().astype(np.float32)
return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences):