From aac25eac9ebf7037008a7e89cdff4eb38a0efd2a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 23 May 2025 01:29:48 +0400 Subject: [PATCH] refac: reranker Co-Authored-By: Tornike Gurgenidze --- backend/open_webui/retrieval/models/base_reranker.py | 8 ++++++++ backend/open_webui/retrieval/models/colbert.py | 4 +++- backend/open_webui/retrieval/models/external.py | 4 +++- 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 backend/open_webui/retrieval/models/base_reranker.py diff --git a/backend/open_webui/retrieval/models/base_reranker.py b/backend/open_webui/retrieval/models/base_reranker.py new file mode 100644 index 000000000..6be7a5649 --- /dev/null +++ b/backend/open_webui/retrieval/models/base_reranker.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Optional, List, Tuple + + +class BaseReranker(ABC): + @abstractmethod + def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: + pass diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py index 5b7499fd1..7ec888437 100644 --- a/backend/open_webui/retrieval/models/colbert.py +++ b/backend/open_webui/retrieval/models/colbert.py @@ -7,11 +7,13 @@ from colbert.modeling.checkpoint import Checkpoint from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.models.base_reranker import BaseReranker + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ColBERT: +class ColBERT(BaseReranker): def __init__(self, name, **kwargs) -> None: log.info("ColBERT: Loading model", name) self.device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index 187d66e38..5ebc3e52e 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -3,12 +3,14 @@ import requests from typing import Optional, List, Tuple from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.models.base_reranker import BaseReranker + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ExternalReranker: +class ExternalReranker(BaseReranker): def __init__( self, api_key: str,