diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py
index c605b7211..2acb4fde9 100644
--- a/backend/open_webui/apps/rag/main.py
+++ b/backend/open_webui/apps/rag/main.py
@@ -10,6 +10,9 @@
 from datetime import datetime
 from pathlib import Path
 from typing import Iterator, Optional, Sequence, Union
+
+import numpy as np
+import torch
 
 import requests
 import validators
@@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -193,18 +198,90 @@ def update_reranking_model(
     update_model: bool = False,
 ):
     if reranking_model:
-        import sentence_transformers
+        if reranking_model in ["jinaai/jina-colbert-v2"]:
 
-        try:
-            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                get_model_path(reranking_model, update_model),
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-            )
-        except:
-            log.error("CrossEncoder error")
-            app.state.sentence_transformer_rf = None
-            app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+            class Colbert:
+                """ColBERT reranker with a CrossEncoder-like ``predict`` method."""
+
+                def __init__(self, name) -> None:
+                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
+
+                def calculate_similarity_scores(
+                    self, query_embeddings, document_embeddings
+                ):
+                    """Return softmax-normalized float32 scores for the documents.
+
+                    Both arguments must be 3-D tensors; the layout is assumed to
+                    be (batch, tokens, dim) -- only the rank is validated here.
+                    """
+                    # Validate dimensions to ensure compatibility
+                    if query_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
+                        )
+                    if document_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
+                        )
+                    if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
+                        raise ValueError(
+                            "There should be either one query or queries equal to the number of documents."
+                        )
+
+                    # Transpose the query embeddings to align for matrix multiplication
+                    transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
+                    # Compute similarity scores using batch matrix multiplication
+                    computed_scores = torch.matmul(
+                        document_embeddings, transposed_query_embeddings
+                    )
+                    # Apply max pooling to extract the highest semantic similarity across each document's sequence
+                    maximum_scores = torch.max(computed_scores, dim=1).values
+
+                    # Sum up the maximum scores across features to get the overall document relevance scores
+                    final_scores = maximum_scores.sum(dim=1)
+
+                    normalized_scores = torch.softmax(final_scores, dim=0)
+
+                    # detach()/cpu() first: Tensor.numpy() rejects tensors that
+                    # live on an accelerator or still require grad.
+                    return normalized_scores.detach().cpu().numpy().astype(np.float32)
+
+                def predict(self, sentences):
+                    """Score (query, doc) pairs using the query of the first pair."""
+                    if not sentences:
+                        # Nothing to rank; avoid indexing into an empty list.
+                        return np.array([], dtype=np.float32)
+
+                    query = sentences[0][0]
+                    docs = [i[1] for i in sentences]
+
+                    # Embedding the documents
+                    embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
+                    # Embedding the queries
+                    embedded_queries = self.ckpt.queryFromText([query], bsize=32)
+                    embedded_query = embedded_queries[0]
+
+                    # Calculate retrieval scores for the query against all documents
+                    scores = self.calculate_similarity_scores(
+                        embedded_query.unsqueeze(0), embedded_docs
+                    )
+
+                    return scores
+
+            app.state.sentence_transformer_rf = Colbert(reranking_model)
+        else:
+            import sentence_transformers
+
+            try:
+                app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+                    get_model_path(reranking_model, update_model),
+                    device=DEVICE_TYPE,
+                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+                )
+            except Exception:
+                log.exception("CrossEncoder error")
+                app.state.sentence_transformer_rf = None
+                app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
     else:
         app.state.sentence_transformer_rf = None
 
diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py
index bdec5f9da..7c0d02984 100644
--- a/backend/open_webui/apps/rag/utils.py
+++ b/backend/open_webui/apps/rag/utils.py
@@ -232,8 +232,7 @@ def query_collection_with_hybrid_search(
 
     if error:
         raise Exception(
-            "Hybrid search failed for all collections. Using "
-            "Non hybrid search as fallback."
+            "Hybrid search failed for all collections. Using Non hybrid search as fallback."
         )
 
     return merge_and_sort_query_results(results, k=k, reverse=True)