open-webui/open-webui (https://github.com/open-webui/open-webui)
enh: colbert rerank support
commit b38986a0aa
parent db0c576f48
@@ -10,6 +10,9 @@ from datetime import datetime
 from pathlib import Path
 from typing import Iterator, Optional, Sequence, Union
 
+
+import numpy as np
+import torch
 import requests
 import validators
 
@@ -114,6 +117,8 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -193,6 +198,64 @@ def update_reranking_model(
     update_model: bool = False,
 ):
     if reranking_model:
+        if reranking_model in ["jinaai/jina-colbert-v2"]:
+
+            class Colbert:
+                def __init__(self, name) -> None:
+                    self.ckpt = Checkpoint(name, colbert_config=ColBERTConfig())
+                    pass
+
+                def calculate_similarity_scores(self, query_embeddings, document_embeddings):
+                    # Validate dimensions to ensure compatibility
+                    if query_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
+                        )
+                    if document_embeddings.dim() != 3:
+                        raise ValueError(
+                            f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
+                        )
+                    if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
+                        raise ValueError(
+                            "There should be either one query or queries equal to the number of documents."
+                        )
+
+                    # Transpose the query embeddings to align for matrix multiplication
+                    transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
+                    # Compute similarity scores using batch matrix multiplication
+                    computed_scores = torch.matmul(
+                        document_embeddings, transposed_query_embeddings
+                    )
+                    # Apply max pooling to extract the highest semantic similarity across each document's sequence
+                    maximum_scores = torch.max(computed_scores, dim=1).values
+
+                    # Sum up the maximum scores across query tokens to get the overall document relevance scores
+                    final_scores = maximum_scores.sum(dim=1)
+
+                    normalized_scores = torch.softmax(final_scores, dim=0)
+
+                    return normalized_scores.numpy().astype(np.float32)
+
+                def predict(self, sentences):
+
+                    query = sentences[0][0]
+                    docs = [i[1] for i in sentences]
+
+                    # Embedding the documents
+                    embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
+                    # Embedding the queries
+                    embedded_queries = self.ckpt.queryFromText([query], bsize=32)
+                    embedded_query = embedded_queries[0]
+
+                    # Calculate retrieval scores for the query against all documents
+                    scores = self.calculate_similarity_scores(
+                        embedded_query.unsqueeze(0), embedded_docs
+                    )
+
+                    return scores
+
+            app.state.sentence_transformer_rf = Colbert(reranking_model)
+        else:
-        import sentence_transformers
+            import sentence_transformers
 
-        try:
+            try:
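For reference, the calculate_similarity_scores method above is ColBERT-style MaxSim late-interaction scoring. A minimal sketch of the same computation on toy tensors (all shapes and values below are illustrative, not taken from the commit):

import torch

# Toy token-level embeddings: one query with 2 tokens, three documents with
# 4 tokens each, embedding dimension 8 (shapes mirror what the class expects).
query_embeddings = torch.randn(1, 2, 8)      # (num_queries, query_tokens, dim)
document_embeddings = torch.randn(3, 4, 8)   # (num_docs, doc_tokens, dim)

# Token-by-token dot products, broadcast over the document batch:
# result shape (num_docs, doc_tokens, query_tokens)
token_scores = torch.matmul(document_embeddings, query_embeddings.permute(0, 2, 1))

# MaxSim: for each query token, keep its best-matching document token ...
max_per_query_token = token_scores.max(dim=1).values    # (num_docs, query_tokens)

# ... then sum over query tokens for one relevance score per document
doc_scores = max_per_query_token.sum(dim=1)             # (num_docs,)

# The commit additionally softmax-normalizes the scores across documents
print(torch.softmax(doc_scores, dim=0))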
@@ -232,8 +232,7 @@ def query_collection_with_hybrid_search(
 
     if error:
         raise Exception(
-            "Hybrid search failed for all collections. Using "
-            "Non hybrid search as fallback."
+            "Hybrid search failed for all collections. Using Non hybrid search as fallback."
         )
 
     return merge_and_sort_query_results(results, k=k, reverse=True)
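The Colbert wrapper keeps the predict() call shape of the sentence_transformers.CrossEncoder it stands in for, taking (query, document) pairs and returning one score per document, so the existing reranking path can call it unchanged. A rough usage sketch (the query, documents, and standalone setup are made up for illustration; loading the checkpoint needs the colbert-ai package and a model download):

# Hypothetical standalone use of the Colbert class added above.
reranker = Colbert("jinaai/jina-colbert-v2")

query = "what is late interaction retrieval?"
docs = [
    "ColBERT scores documents with token-level MaxSim late interaction.",
    "Stephen Colbert hosts a late-night talk show.",
]

# Same pair format CrossEncoder.predict() uses: one (query, document) pair per candidate.
scores = reranker.predict([(query, doc) for doc in docs])
print(scores)  # float32 array, softmax-normalized over the two candidates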