refac: colbert cuda support

This commit is contained in:
Timothy J. Baek 2024-09-16 12:42:48 +02:00
parent b7f0759485
commit 06debb322b

View File

@ -203,12 +203,19 @@ def update_reranking_model(
class Colbert: class Colbert:
def __init__(self, name) -> None: def __init__(self, name) -> None:
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()) self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
self.device
)
pass pass
def calculate_similarity_scores( def calculate_similarity_scores(
self, query_embeddings, document_embeddings self, query_embeddings, document_embeddings
): ):
query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility # Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3: if query_embeddings.dim() != 3:
raise ValueError( raise ValueError(
@ -237,7 +244,7 @@ def update_reranking_model(
normalized_scores = torch.softmax(final_scores, dim=0) normalized_scores = torch.softmax(final_scores, dim=0)
return normalized_scores.numpy().astype(np.float32) return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences): def predict(self, sentences):