enh: colbert rerank support

This commit is contained in:
Timothy J. Baek 2024-09-16 11:46:39 +02:00
parent db0c576f48
commit b38986a0aa
2 changed files with 75 additions and 13 deletions

View File

@ -10,6 +10,9 @@ from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
import numpy as np
import torch
import requests
import validators
@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -193,18 +198,76 @@ def update_reranking_model(
update_model: bool = False,
):
if reranking_model:
import sentence_transformers
if reranking_model in ["jinaai/jina-colbert-v2"]:
try:
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)
except:
log.error("CrossEncoder error")
app.state.sentence_transformer_rf = None
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
class Colbert:
    """Reranker backed by a ColBERT checkpoint.

    Exposes a ``predict(sentences)`` method so an instance can be stored in
    ``app.state.sentence_transformer_rf`` in place of a CrossEncoder.
    """

    def __init__(self, name) -> None:
        """Load the ColBERT checkpoint *name* with a default ``ColBERTConfig``."""
        self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())

    @staticmethod
    def calculate_similarity_scores(query_embeddings, document_embeddings):
        """Score every document against the query with late interaction.

        Declared as a static method because it reads no instance state; this
        keeps both ``self.calculate_similarity_scores(q, d)`` and
        ``Colbert.calculate_similarity_scores(q, d)`` working.

        Args:
            query_embeddings: 3-D tensor; presumably (queries, query tokens,
                dim) -- confirm against the checkpoint output. The first
                dimension must be 1 or equal to the number of documents.
            document_embeddings: 3-D tensor; presumably (documents, document
                tokens, dim) -- confirm against the checkpoint output.

        Returns:
            A 1-D ``numpy`` array of ``float32`` with one score per document,
            softmax-normalized across the documents.

        Raises:
            ValueError: If either input is not 3-D or the number of queries
                is neither 1 nor the number of documents.
        """
        # Validate dimensions to ensure compatibility
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There should be either one query or queries equal to the number of documents."
            )

        # Swap the last two axes so the batched matmul contracts over the
        # embedding dimension.
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)

        # Pairwise similarity of every document token with every query token.
        computed_scores = torch.matmul(
            document_embeddings, transposed_query_embeddings
        )

        # For each query token keep its best-matching document token ...
        maximum_scores = torch.max(computed_scores, dim=1).values

        # ... and add those maxima up to get one relevance score per document.
        final_scores = maximum_scores.sum(dim=1)

        normalized_scores = torch.softmax(final_scores, dim=0)

        # Tensor.numpy() only works on CPU tensors that do not require grad,
        # so detach and move to host memory first.
        return normalized_scores.detach().cpu().numpy().astype(np.float32)

    def predict(self, sentences):
        """Return relevance scores for ``(query, document)`` pairs.

        Only the query of the first pair is used; every pair contributes its
        document.

        Args:
            sentences: Sequence of ``(query, document)`` pairs.

        Returns:
            A 1-D ``float32`` ``numpy`` array with one score per pair; empty
            when *sentences* is empty.
        """
        # Nothing to rank: avoid indexing into an empty sequence.
        if len(sentences) == 0:
            return np.empty(0, dtype=np.float32)

        query = sentences[0][0]
        docs = [pair[1] for pair in sentences]

        # Embedding the documents (first element of the returned value holds
        # the embeddings).
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]

        # Embedding the query; take the single query's embedding.
        embedded_query = self.ckpt.queryFromText([query], bsize=32)[0]

        # Calculate retrieval scores for the query against all documents
        return self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )
app.state.sentence_transformer_rf = Colbert(reranking_model)
else:
import sentence_transformers
try:
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)
except:
log.error("CrossEncoder error")
app.state.sentence_transformer_rf = None
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
else:
app.state.sentence_transformer_rf = None

View File

@ -232,8 +232,7 @@ def query_collection_with_hybrid_search(
if error:
raise Exception(
"Hybrid search failed for all collections. Using "
"Non hybrid search as fallback."
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
)
return merge_and_sort_query_results(results, k=k, reverse=True)