Mirror of https://github.com/open-webui/open-webui (synced 2025-05-17 20:05:08 +00:00)
Merge pull request #9068 from df-cgdm/main
**feat**: Add user-related headers when calling an external embedding API
This commit is contained in: f6f8c08cb0
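The diff below threads the authenticated user through Open WebUI's retrieval pipeline so that, when `ENABLE_FORWARD_USER_INFO_HEADERS` is enabled, calls to external embedding APIs (OpenAI- and Ollama-compatible endpoints) carry `X-OpenWebUI-User-Name`, `X-OpenWebUI-User-Id`, `X-OpenWebUI-User-Email`, and `X-OpenWebUI-User-Role` headers identifying who triggered the request.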
```diff
@@ -15,8 +15,9 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
+from open_webui.models.users import UserModel
 
-from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
+from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE, ENABLE_FORWARD_USER_INFO_HEADERS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
```
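The two new imports supply the pieces the rest of the diff relies on: `UserModel` for the type hints on the new `user` parameters, and the `ENABLE_FORWARD_USER_INFO_HEADERS` environment flag that gates header forwarding.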
```diff
@@ -64,6 +65,7 @@ def query_doc(
     collection_name: str,
     query_embedding: list[float],
     k: int,
+    user: UserModel = None,
 ):
     try:
         result = VECTOR_DB_CLIENT.search(
```
```diff
@@ -256,29 +258,32 @@ def get_embedding_function(
     embedding_function,
     url,
     key,
-    embedding_batch_size,
+    embedding_batch_size
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, user=None: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
             url=url,
             key=key,
+            user=user
         )
 
-        def generate_multiple(query, func):
+        def generate_multiple(query, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size], user=user))
                 return embeddings
             else:
-                return func(query)
+                return func(query, user)
 
-        return lambda query: generate_multiple(query, func)
+        return lambda query, user=None: generate_multiple(query, user, func)
+    else:
+        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
 def get_sources_from_files(
```
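To make the new contract concrete, here is a minimal, runnable sketch of the batching wrapper that `get_embedding_function` builds, with the optional `user` threaded through every batch call. `fake_embed` and `batch_size` are illustrative stand-ins, not code from the PR:

```python
# Sketch of the user-aware batching wrapper (names are illustrative).
from typing import Callable, Union

batch_size = 2  # stands in for embedding_batch_size


def fake_embed(texts: list[str], user=None) -> list[list[float]]:
    # Stand-in for generate_embeddings; a real engine would call an HTTP
    # API and could read user.name / user.id to build forwarding headers.
    return [[float(len(t))] for t in texts]


def generate_multiple(query: Union[str, list[str]], user, func: Callable):
    # Split large inputs into batches so each upstream request stays
    # small, threading the same user through every batch.
    if isinstance(query, list):
        embeddings = []
        for i in range(0, len(query), batch_size):
            embeddings.extend(func(query[i : i + batch_size], user=user))
        return embeddings
    return func([query], user=user)[0]


# The public surface keeps `user` optional, so old call sites still work:
embed = lambda query, user=None: generate_multiple(query, user, fake_embed)

print(embed(["a", "bb", "ccc"]))  # -> [[1.0], [2.0], [3.0]], two batches
print(embed("hello"))             # -> [5.0], single-string path
```

Defaulting `user=None` on the outer lambda is what lets pre-existing callers that pass only a query keep working unchanged.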
```diff
@@ -423,7 +428,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -431,6 +436,16 @@ def generate_openai_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
```
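The `**({...} if ... else {})` idiom above merges the identity headers only when forwarding is enabled and a user is present; because a conditional expression evaluates the dict literal only when the condition is truthy, the `user.*` attribute accesses are guarded too. A runnable sketch of the pattern, with `DummyUser` and `forward_enabled` as stand-ins for the real user model and environment flag:

```python
# Sketch of conditionally merging identity headers via dict unpacking.
from dataclasses import dataclass


@dataclass
class DummyUser:
    id: str
    name: str
    email: str
    role: str


forward_enabled = True  # stands in for ENABLE_FORWARD_USER_INFO_HEADERS


def build_headers(key: str, user: DummyUser | None) -> dict[str, str]:
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {key}",
        # The inner dict literal (and its user.* accesses) is evaluated
        # only when the condition is truthy, so None users are safe here.
        **(
            {
                "X-OpenWebUI-User-Name": user.name,
                "X-OpenWebUI-User-Id": user.id,
                "X-OpenWebUI-User-Email": user.email,
                "X-OpenWebUI-User-Role": user.role,
            }
            if forward_enabled and user
            else {}
        ),
    }


print(build_headers("sk-...", DummyUser("u1", "Ada", "ada@example.com", "admin")))
print(build_headers("sk-...", None))  # falls back to the base headers only
```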
```diff
@@ -446,7 +461,7 @@ def generate_openai_batch_embeddings(
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -454,6 +469,16 @@ def generate_ollama_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
```
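Note the asymmetry: the Ollama variant's condition is `if ENABLE_FORWARD_USER_INFO_HEADERS` alone, without the `and user` guard used in the OpenAI variant, so calling it with `user=None` while forwarding is enabled would raise `AttributeError` the moment `user.name` is evaluated. Mirroring the `ENABLE_FORWARD_USER_INFO_HEADERS and user` condition would make the two paths consistent.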
```diff
@@ -472,22 +497,23 @@ def generate_ollama_batch_embeddings(
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
+    user = kwargs.get("user")
 
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{"model": model, "texts": [text], "url": url, "key": key, "user": user}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
 
         return embeddings[0] if isinstance(text, str) else embeddings
 
```
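`generate_embeddings` pulls the new argument out of `**kwargs`, so callers that never pass it get `None` (via `kwargs.get("user")`) and behave exactly as before. A hedged sketch of that dispatch shape, with stand-in engine functions in place of the real HTTP-backed ones:

```python
# Sketch of kwargs-based engine dispatch (engine functions are stand-ins).
from typing import Union


def ollama_batch(model, texts, url, key, user=None):
    return [[0.0] for _ in texts]  # placeholder vectors


def openai_batch(model, texts, url, key, user=None):
    return [[1.0] for _ in texts]  # placeholder vectors


def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")  # absent key -> None, old callers unaffected

    texts = text if isinstance(text, list) else [text]
    if engine == "ollama":
        embeddings = ollama_batch(model, texts, url, key, user=user)
    elif engine == "openai":
        embeddings = openai_batch(model, texts, url, key, user=user)
    else:
        raise ValueError(f"Unknown embedding engine: {engine}")
    # A single-string input unwraps to a single vector.
    return embeddings[0] if isinstance(text, str) else embeddings


print(generate_embeddings("openai", "m", "hi"))        # [1.0]
print(generate_embeddings("ollama", "m", ["a", "b"]))  # [[0.0], [0.0]]
```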
```diff
@@ -71,7 +71,7 @@ def upload_file(
         )
 
         try:
-            process_file(request, ProcessFileForm(file_id=id))
+            process_file(request, ProcessFileForm(file_id=id), user=user)
             file_item = Files.get_file_by_id(id=id)
         except Exception as e:
             log.exception(e)
@@ -193,7 +193,9 @@ async def update_file_data_content_by_id(
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
             process_file(
-                request, ProcessFileForm(file_id=id, content=form_data.content)
+                request,
+                ProcessFileForm(file_id=id, content=form_data.content),
+                user=user
             )
             file = Files.get_file_by_id(id=id)
         except Exception as e:
```
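With the plumbing in place, the file endpoints pass the authenticated `user` into `process_file`, and the knowledge-base endpoints below do the same, so ingestion triggered from any entry point can attach identity headers downstream.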
```diff
@@ -289,7 +289,9 @@ def add_file_to_knowledge_by_id(
         # Add content to the vector database
         try:
             process_file(
-                request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+                request,
+                ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+                user=user
             )
         except Exception as e:
             log.debug(e)
@@ -372,7 +374,9 @@ def update_file_from_knowledge_by_id(
         # Add content to the vector database
         try:
             process_file(
-                request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+                request,
+                ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+                user=user
             )
         except Exception as e:
             raise HTTPException(
```
```diff
@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {"created_at": memory.created_at},
             }
         ],
@@ -82,7 +82,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
         limit=form_data.k,
     )
 
@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
@@ -160,7 +160,7 @@ async def update_memory_by_id(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
```
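In the memory endpoints, `request.app.state.EMBEDDING_FUNCTION` is the `lambda query, user=None: ...` built by `get_embedding_function`, so the user is simply supplied as the second positional argument.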
```diff
@@ -666,6 +666,7 @@ def save_docs_to_vector_db(
     overwrite: bool = False,
     split: bool = True,
     add: bool = False,
+    user=None,
 ) -> bool:
     def _get_docs_info(docs: list[Document]) -> str:
         docs_info = set()
@@ -781,7 +782,8 @@ def save_docs_to_vector_db(
             )
 
             embeddings = embedding_function(
-                list(map(lambda x: x.replace("\n", " "), texts))
+                list(map(lambda x: x.replace("\n", " "), texts)),
+                user=user
             )
 
             items = [
@@ -939,6 +941,7 @@ def process_file(
                 "hash": hash,
             },
            add=(True if form_data.collection_name else False),
+            user=user
        )
 
        if result:
@@ -996,7 +999,7 @@ def process_text(
     text_content = form_data.content
     log.debug(f"text_content: {text_content}")
 
-    result = save_docs_to_vector_db(request, docs, collection_name)
+    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
     if result:
         return {
             "status": True,
@@ -1029,7 +1032,7 @@ def process_youtube_video(
         content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {content}")
 
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1070,7 +1073,7 @@ def process_web(
         content = " ".join([doc.page_content for doc in docs])
 
         log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1286,7 +1289,7 @@ def process_web_search(
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
         docs = loader.load()
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
```
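Every ingestion path (`process_text`, `process_youtube_video`, `process_web`, `process_web_search`, and `process_file` above) now forwards `user=user` into `save_docs_to_vector_db`, which in turn passes it to the embedding function when vectors are generated.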
```diff
@@ -1320,7 +1323,7 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1328,12 +1331,14 @@ def query_doc_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                user=user
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+                user=user
             )
     except Exception as e:
         log.exception(e)
```
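The hybrid-search helpers expect a one-argument `embedding_function`, so rather than widening their signatures, the handler binds the current user into a closure. A minimal sketch of the pattern (`embedding_fn` and `query_doc_sketch` are illustrative names, not the PR's code):

```python
# Sketch of binding a user into a one-argument embedding callable.
def embedding_fn(query: str, user=None) -> list[float]:
    # Stand-in for request.app.state.EMBEDDING_FUNCTION.
    return [float(len(query)), float(len(user or "anon"))]


def query_doc_sketch(query: str, embedding_function) -> list[float]:
    # The helper stays user-agnostic: it only ever calls f(query).
    return embedding_function(query)


user = "u-123"
result = query_doc_sketch("hello", lambda q: embedding_fn(q, user=user))
print(result)  # [5.0, 5.0] -- the user was captured by the closure
```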
```diff
@@ -1362,7 +1367,7 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1375,7 +1380,7 @@ def query_collection_handler(
             return query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
 
@@ -1523,6 +1528,7 @@ def process_files_batch(
             docs=all_docs,
             collection_name=collection_name,
             add=True,
+            user=user,
         )
 
         # Update all files with collection name
```
```diff
@@ -634,7 +634,7 @@ async def chat_completion_files_handler(
                 lambda: get_sources_from_files(
                     files=files,
                     queries=queries,
-                    embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                    embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                     k=request.app.state.config.TOP_K,
                     reranking_function=request.app.state.rf,
                     r=request.app.state.config.RELEVANCE_THRESHOLD,
```
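The middleware change closes the loop: source retrieval during chat completions uses the same user-bound embedding lambda, so embeddings generated at query time are attributed to the requesting user just like those generated at ingestion time.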