enh: colbert rerank support

This commit is contained in:
Timothy J. Baek 2024-09-16 11:46:39 +02:00
parent db0c576f48
commit b38986a0aa
2 changed files with 75 additions and 13 deletions

View File

@ -10,6 +10,9 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Iterator, Optional, Sequence, Union from typing import Iterator, Optional, Sequence, Union
import numpy as np
import torch
import requests import requests
import validators import validators
@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
YoutubeLoader, YoutubeLoader,
) )
from langchain_core.documents import Document from langchain_core.documents import Document
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -193,6 +198,64 @@ def update_reranking_model(
update_model: bool = False, update_model: bool = False,
): ):
if reranking_model: if reranking_model:
if reranking_model in ["jinaai/jina-colbert-v2"]:
class Colbert:
    """Reranker that scores (query, document) pairs with a ColBERT checkpoint.

    Exposes ``predict(sentences)`` so that an instance can be stored in
    ``app.state.sentence_transformer_rf`` and used as the reranking model.
    """

    def __init__(self, name) -> None:
        """Load the ColBERT checkpoint identified by ``name``.

        Args:
            name: Checkpoint identifier passed straight to ``Checkpoint``
                (e.g. ``"jinaai/jina-colbert-v2"``).
        """
        self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())

    @staticmethod
    def calculate_similarity_scores(query_embeddings, document_embeddings):
        """Return softmax-normalised late-interaction scores, one per document.

        Declared as a static method because it reads no instance state; it
        stays callable both as ``self.calculate_similarity_scores(q, d)`` and
        as ``Colbert.calculate_similarity_scores(q, d)``.

        Args:
            query_embeddings: 3-D tensor whose first dimension is either 1
                (one query shared by all documents) or equal to the number
                of documents. Assumed layout: (batch, tokens, embedding_dim).
            document_embeddings: 3-D tensor, assumed layout
                (documents, tokens, embedding_dim).

        Returns:
            A 1-D ``numpy.float32`` array with one score per document; the
            scores are softmax-normalised across documents and sum to 1.

        Raises:
            ValueError: If either input is not 3-dimensional, or the number
                of queries is neither 1 nor the number of documents.
        """
        # Validate dimensions to ensure compatibility
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There should be either one query or queries equal to the number of documents."
            )

        # Swap the last two axes of the queries so the matmul below contracts
        # over the embedding dimension.
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)

        # Token-level similarities: (documents, doc_tokens, query_tokens).
        # A single query is broadcast across all documents.
        computed_scores = torch.matmul(
            document_embeddings, transposed_query_embeddings
        )

        # For every query token keep its best-matching document token.
        maximum_scores = torch.max(computed_scores, dim=1).values

        # Sum the per-query-token maxima to get one relevance score per document.
        final_scores = maximum_scores.sum(dim=1)

        # Normalise across documents.
        normalized_scores = torch.softmax(final_scores, dim=0)

        # detach()/cpu() make the conversion safe for tensors that live on an
        # accelerator or still track gradients; plain .numpy() raises for both.
        return normalized_scores.detach().cpu().numpy().astype(np.float32)

    def predict(self, sentences):
        """Score every document in ``sentences`` against the query.

        Args:
            sentences: Sequence of ``(query, document)`` pairs. Only the query
                of the first pair is used, so all pairs are assumed to share
                the same query.

        Returns:
            A 1-D ``numpy.float32`` array of scores, one per pair, in input
            order. Empty when ``sentences`` is empty.
        """
        # Nothing to rank: return an empty score array instead of failing on
        # sentences[0] below.
        if len(sentences) == 0:
            return np.empty(0, dtype=np.float32)

        query = sentences[0][0]
        docs = [pair[1] for pair in sentences]

        # Embedding the documents
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]

        # Embedding the queries
        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
        embedded_query = embedded_queries[0]

        # Calculate retrieval scores for the query against all documents;
        # unsqueeze restores the leading batch axis dropped by the indexing.
        scores = self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )

        return scores
app.state.sentence_transformer_rf = Colbert(reranking_model)
else:
import sentence_transformers import sentence_transformers
try: try:

View File

@ -232,8 +232,7 @@ def query_collection_with_hybrid_search(
if error: if error:
raise Exception( raise Exception(
"Hybrid search failed for all collections. Using " "Hybrid search failed for all collections. Using Non hybrid search as fallback."
"Non hybrid search as fallback."
) )
return merge_and_sort_query_results(results, k=k, reverse=True) return merge_and_sort_query_results(results, k=k, reverse=True)