mirror of
https://github.com/open-webui/open-webui
synced 2025-04-22 15:33:21 +00:00
refac: colbert cuda support
This commit is contained in:
parent
b7f0759485
commit
06debb322b
@ -203,12 +203,19 @@ def update_reranking_model(
|
|||||||
|
|
||||||
class Colbert:
|
class Colbert:
|
||||||
def __init__(self, name) -> None:
|
def __init__(self, name) -> None:
|
||||||
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig()).to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def calculate_similarity_scores(
|
def calculate_similarity_scores(
|
||||||
self, query_embeddings, document_embeddings
|
self, query_embeddings, document_embeddings
|
||||||
):
|
):
|
||||||
|
|
||||||
|
query_embeddings = query_embeddings.to(self.device)
|
||||||
|
document_embeddings = document_embeddings.to(self.device)
|
||||||
|
|
||||||
# Validate dimensions to ensure compatibility
|
# Validate dimensions to ensure compatibility
|
||||||
if query_embeddings.dim() != 3:
|
if query_embeddings.dim() != 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -237,7 +244,7 @@ def update_reranking_model(
|
|||||||
|
|
||||||
normalized_scores = torch.softmax(final_scores, dim=0)
|
normalized_scores = torch.softmax(final_scores, dim=0)
|
||||||
|
|
||||||
return normalized_scores.numpy().astype(np.float32)
|
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
def predict(self, sentences):
|
def predict(self, sentences):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user