mirror of
https://github.com/open-webui/open-webui
synced 2025-03-22 22:07:15 +00:00
enh: colbert rerank support
This commit is contained in:
parent
db0c576f48
commit
b38986a0aa
@ -10,6 +10,9 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, Sequence, Union
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import requests
|
||||
import validators
|
||||
|
||||
@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
|
||||
YoutubeLoader,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from colbert.infra import ColBERTConfig
|
||||
from colbert.modeling.checkpoint import Checkpoint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@ -193,18 +198,76 @@ def update_reranking_model(
|
||||
update_model: bool = False,
|
||||
):
|
||||
if reranking_model:
|
||||
import sentence_transformers
|
||||
if reranking_model in ["jinaai/jina-colbert-v2"]:
|
||||
|
||||
try:
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
get_model_path(reranking_model, update_model),
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
except:
|
||||
log.error("CrossEncoder error")
|
||||
app.state.sentence_transformer_rf = None
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
class Colbert:
|
||||
def __init__(self, name) -> None:
|
||||
self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
|
||||
pass
|
||||
|
||||
def calculate_similarity_scores(query_embeddings, document_embeddings):
|
||||
# Validate dimensions to ensure compatibility
|
||||
if query_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
|
||||
)
|
||||
if document_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
|
||||
)
|
||||
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
|
||||
raise ValueError(
|
||||
"There should be either one query or queries equal to the number of documents."
|
||||
)
|
||||
|
||||
# Transpose the query embeddings to align for matrix multiplication
|
||||
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
|
||||
# Compute similarity scores using batch matrix multiplication
|
||||
computed_scores = torch.matmul(
|
||||
document_embeddings, transposed_query_embeddings
|
||||
)
|
||||
# Apply max pooling to extract the highest semantic similarity across each document's sequence
|
||||
maximum_scores = torch.max(computed_scores, dim=1).values
|
||||
|
||||
# Sum up the maximum scores across features to get the overall document relevance scores
|
||||
final_scores = maximum_scores.sum(dim=1)
|
||||
|
||||
normalized_scores = torch.softmax(final_scores, dim=0)
|
||||
|
||||
return normalized_scores.numpy().astype(np.float32)
|
||||
|
||||
def predict(self, sentences):
|
||||
|
||||
query = sentences[0][0]
|
||||
docs = [i[1] for i in sentences]
|
||||
|
||||
# Embedding the documents
|
||||
embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
|
||||
# Embedding the queries
|
||||
embedded_queries = self.ckpt.queryFromText([query], bsize=32)
|
||||
embedded_query = embedded_queries[0]
|
||||
|
||||
# Calculate retrieval scores for the query against all documents
|
||||
scores = self.calculate_similarity_scores(
|
||||
embedded_query.unsqueeze(0), embedded_docs
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
app.state.sentence_transformer_rf = Colbert(reranking_model)
|
||||
else:
|
||||
import sentence_transformers
|
||||
|
||||
try:
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
get_model_path(reranking_model, update_model),
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
except:
|
||||
log.error("CrossEncoder error")
|
||||
app.state.sentence_transformer_rf = None
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||
else:
|
||||
app.state.sentence_transformer_rf = None
|
||||
|
||||
|
@ -232,8 +232,7 @@ def query_collection_with_hybrid_search(
|
||||
|
||||
if error:
|
||||
raise Exception(
|
||||
"Hybrid search failed for all collections. Using "
|
||||
"Non hybrid search as fallback."
|
||||
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
|
||||
)
|
||||
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
Loading…
Reference in New Issue
Block a user