Mirror of https://github.com/open-webui/open-webui
Synced 2025-06-26 18:26:48 +00:00
Add user-related headers when calling an external embedding API
This commit is contained in:
parent b72150c881
commit 6ca295ec59
@@ -15,8 +15,9 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
+from open_webui.models.users import UserModel
 
-from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
+from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE, ENABLE_FORWARD_USER_INFO_HEADERS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
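The new import assumes ENABLE_FORWARD_USER_INFO_HEADERS is defined in open_webui.env; its definition is not part of this diff. A plausible sketch, assuming the usual pattern of parsing a boolean flag from an environment variable:

    import os

    # Assumption: off by default, enabled with ENABLE_FORWARD_USER_INFO_HEADERS=true
    ENABLE_FORWARD_USER_INFO_HEADERS = (
        os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
    )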
@@ -64,6 +65,7 @@ def query_doc(
     collection_name: str,
     query_embedding: list[float],
     k: int,
+    user: UserModel=None
 ):
     try:
         result = VECTOR_DB_CLIENT.search(
@@ -256,29 +258,32 @@ def get_embedding_function(
     embedding_function,
     url,
     key,
-    embedding_batch_size,
+    embedding_batch_size
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, user=None: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
             url=url,
             key=key,
+            user=user
         )
 
-        def generate_multiple(query, func):
+        def generate_multiple(query, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size], user=user))
                 return embeddings
             else:
-                return func(query)
+                return func(query, user)
 
-        return lambda query: generate_multiple(query, func)
+        return lambda query, user=None: generate_multiple(query, user, func)
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
 def get_sources_from_files(
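Every callable returned by get_embedding_function now takes an optional user that defaults to None, so existing one-argument call sites keep working unchanged. A usage sketch; the argument values are illustrative:

    embedding_fn = get_embedding_function(
        embedding_engine="openai",
        embedding_model="text-embedding-3-small",
        embedding_function=None,  # only used by the local ("") engine
        url="https://api.openai.com/v1",
        key="sk-...",
        embedding_batch_size=32,
    )

    vec = embedding_fn("hello")                      # anonymous: no user headers
    vec = embedding_fn("hello", user=user)           # forwards user info headers
    vecs = embedding_fn(["a", "b", "c"], user=user)  # batched 32 texts at a time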
@@ -423,7 +428,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -431,6 +436,16 @@ def generate_openai_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
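The `**( ... if ... else {} )` splice merges the four X-OpenWebUI-* headers only when the flag is set and a user object is present; otherwise it merges an empty dict and the request is byte-for-byte what it was before. The idiom in isolation (names are illustrative):

    user = None
    forward = True

    headers = {
        "Content-Type": "application/json",
        **({"X-OpenWebUI-User-Id": user.id} if forward and user else {}),
    }
    assert headers == {"Content-Type": "application/json"}  # user is None, branch skipped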
@@ -446,7 +461,7 @@ def generate_openai_batch_embeddings(
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -454,6 +469,16 @@ def generate_ollama_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
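Note that this Ollama-side condition checks only ENABLE_FORWARD_USER_INFO_HEADERS, without the `and user` guard used in the OpenAI variant above, so the `user.name` access would raise AttributeError if the flag is on while user is None. A defensive form mirroring the OpenAI branch would gate on both:

    ... if ENABLE_FORWARD_USER_INFO_HEADERS and user else {}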
@@ -472,22 +497,23 @@ def generate_ollama_batch_embeddings(
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
+    user = kwargs.get("user")
 
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{"model": model, "texts": [text], "url": url, "key": key, "user": user}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
 
         return embeddings[0] if isinstance(text, str) else embeddings
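generate_embeddings now pulls an optional user out of **kwargs and passes it through to whichever batch function matches the engine; a single string yields a single vector, a list yields a list. A usage sketch (the model and URL are illustrative):

    vector = generate_embeddings(
        engine="ollama",
        model="nomic-embed-text",
        text="hello world",
        url="http://localhost:11434",
        key="",
        user=user,  # forwarded as X-OpenWebUI-* headers when the flag is on
    )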
@@ -71,7 +71,7 @@ def upload_file(
         )
 
         try:
-            process_file(request, ProcessFileForm(file_id=id))
+            process_file(request, ProcessFileForm(file_id=id), user=user)
             file_item = Files.get_file_by_id(id=id)
         except Exception as e:
             log.exception(e)
@@ -193,7 +193,9 @@ async def update_file_data_content_by_id(
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
             process_file(
-                request, ProcessFileForm(file_id=id, content=form_data.content)
+                request,
+                ProcessFileForm(file_id=id, content=form_data.content),
+                user=user
             )
             file = Files.get_file_by_id(id=id)
         except Exception as e:
@@ -285,7 +285,9 @@ def add_file_to_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user
         )
     except Exception as e:
         log.debug(e)
@@ -363,7 +365,9 @@ def update_file_from_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user
         )
     except Exception as e:
         raise HTTPException(
@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {"created_at": memory.created_at},
             }
         ],
@@ -82,7 +82,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
         limit=form_data.k,
     )
 
@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
@@ -160,7 +160,7 @@ async def update_memory_by_id(
                 {
                     "id": memory.id,
                     "text": memory.content,
-                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                     "metadata": {
                         "created_at": memory.created_at,
                         "updated_at": memory.updated_at,
@@ -660,6 +660,7 @@ def save_docs_to_vector_db(
     overwrite: bool = False,
     split: bool = True,
     add: bool = False,
+    user = None,
 ) -> bool:
     def _get_docs_info(docs: list[Document]) -> str:
         docs_info = set()
@@ -775,7 +776,8 @@ def save_docs_to_vector_db(
         )
 
         embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts))
+            list(map(lambda x: x.replace("\n", " "), texts)),
+            user = user
         )
 
         items = [
@@ -933,6 +935,7 @@ def process_file(
                     "hash": hash,
                 },
                 add=(True if form_data.collection_name else False),
+                user=user
             )
 
             if result:
@@ -990,7 +993,7 @@ def process_text(
     text_content = form_data.content
     log.debug(f"text_content: {text_content}")
 
-    result = save_docs_to_vector_db(request, docs, collection_name)
+    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
     if result:
         return {
             "status": True,
@@ -1023,7 +1026,7 @@ def process_youtube_video(
         content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {content}")
 
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1064,7 +1067,7 @@ def process_web(
         content = " ".join([doc.page_content for doc in docs])
 
         log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1272,7 +1275,7 @@ def process_web_search(
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
         docs = loader.load()
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1306,7 +1309,7 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1314,12 +1317,14 @@ def query_doc_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                user=user
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+                user=user
             )
     except Exception as e:
         log.exception(e)
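Because query_doc_with_hybrid_search expects a one-argument embedding function, the handler wraps the app-level two-argument EMBEDDING_FUNCTION in a closure that pins the current user; the same wrapping recurs in the collection and chat-completion handlers below. The adaptation in isolation (names are illustrative):

    def embed(query, user=None):  # two-argument interface
        ...

    embedding_function = lambda query: embed(query, user=user)  # one-argument view
    # functools.partial(embed, user=user) would be an equivalent alternative.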
@@ -1348,7 +1353,7 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1361,7 +1366,7 @@ def query_collection_handler(
             return query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query,user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
 
@@ -1509,6 +1514,7 @@ def process_files_batch(
                 docs=all_docs,
                 collection_name=collection_name,
                 add=True,
+                user=user,
             )
 
             # Update all files with collection name
@@ -630,7 +630,7 @@ async def chat_completion_files_handler(
                     lambda: get_sources_from_files(
                         files=files,
                         queries=queries,
-                        embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                        embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query,user=user),
                         k=request.app.state.config.TOP_K,
                         reranking_function=request.app.state.rf,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,