diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 423c1ab3c..c0495f3fa 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -203,12 +203,19 @@ def update_reranking_model( class Colbert: def __init__(self, name) -> None: - self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to( + self.device + ) pass def calculate_similarity_scores( self, query_embeddings, document_embeddings ): + + query_embeddings = query_embeddings.to(self.device) + document_embeddings = document_embeddings.to(self.device) + # Validate dimensions to ensure compatibility if query_embeddings.dim() != 3: raise ValueError( @@ -237,7 +244,7 @@ def update_reranking_model( normalized_scores = torch.softmax(final_scores, dim=0) - return normalized_scores.numpy().astype(np.float32) + return normalized_scores.detach().cpu().numpy().astype(np.float32) def predict(self, sentences):