mirror of
https://github.com/open-webui/open-webui
synced 2025-04-05 13:15:36 +00:00
refac
This commit is contained in:
parent
3dc40030a1
commit
cafc5413f5
@ -256,7 +256,7 @@ def query_collection(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
results = []
|
results = []
|
||||||
for query in queries:
|
for query in queries:
|
||||||
query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
|
query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
if collection_name:
|
if collection_name:
|
||||||
try:
|
try:
|
||||||
@ -334,11 +334,11 @@ def get_embedding_function(
|
|||||||
embedding_batch_size,
|
embedding_batch_size,
|
||||||
):
|
):
|
||||||
if embedding_engine == "":
|
if embedding_engine == "":
|
||||||
return lambda query, prefix, user=None: embedding_function.encode(
|
return lambda query, prefix=None, user=None: embedding_function.encode(
|
||||||
query, prompt=prefix if prefix else None
|
query, prompt=prefix if prefix else None
|
||||||
).tolist()
|
).tolist()
|
||||||
elif embedding_engine in ["ollama", "openai"]:
|
elif embedding_engine in ["ollama", "openai"]:
|
||||||
func = lambda query, prefix, user=None: generate_embeddings(
|
func = lambda query, prefix=None, user=None: generate_embeddings(
|
||||||
engine=embedding_engine,
|
engine=embedding_engine,
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
text=query,
|
text=query,
|
||||||
@ -363,7 +363,7 @@ def get_embedding_function(
|
|||||||
else:
|
else:
|
||||||
return func(query, prefix, user)
|
return func(query, prefix, user)
|
||||||
|
|
||||||
return lambda query, prefix, user=None: generate_multiple(
|
return lambda query, prefix=None, user=None: generate_multiple(
|
||||||
query, prefix, user, func
|
query, prefix, user, func
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -57,7 +57,9 @@ async def add_memory(
|
|||||||
{
|
{
|
||||||
"id": memory.id,
|
"id": memory.id,
|
||||||
"text": memory.content,
|
"text": memory.content,
|
||||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||||
|
memory.content, user=user
|
||||||
|
),
|
||||||
"metadata": {"created_at": memory.created_at},
|
"metadata": {"created_at": memory.created_at},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -82,7 +84,7 @@ async def query_memory(
|
|||||||
):
|
):
|
||||||
results = VECTOR_DB_CLIENT.search(
|
results = VECTOR_DB_CLIENT.search(
|
||||||
collection_name=f"user-memory-{user.id}",
|
collection_name=f"user-memory-{user.id}",
|
||||||
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
|
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
|
||||||
limit=form_data.k,
|
limit=form_data.k,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
|
|||||||
{
|
{
|
||||||
"id": memory.id,
|
"id": memory.id,
|
||||||
"text": memory.content,
|
"text": memory.content,
|
||||||
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
|
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||||
|
memory.content, user=user
|
||||||
|
),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"created_at": memory.created_at,
|
"created_at": memory.created_at,
|
||||||
"updated_at": memory.updated_at,
|
"updated_at": memory.updated_at,
|
||||||
@ -161,7 +165,7 @@ async def update_memory_by_id(
|
|||||||
"id": memory.id,
|
"id": memory.id,
|
||||||
"text": memory.content,
|
"text": memory.content,
|
||||||
"vector": request.app.state.EMBEDDING_FUNCTION(
|
"vector": request.app.state.EMBEDDING_FUNCTION(
|
||||||
memory.content, user
|
memory.content, user=user
|
||||||
),
|
),
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"created_at": memory.created_at,
|
"created_at": memory.created_at,
|
||||||
|
@ -1518,8 +1518,8 @@ def query_doc_handler(
|
|||||||
return query_doc_with_hybrid_search(
|
return query_doc_with_hybrid_search(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||||
query, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=request.app.state.rf,
|
||||||
@ -1569,8 +1569,8 @@ def query_collection_handler(
|
|||||||
return query_collection_with_hybrid_search(
|
return query_collection_with_hybrid_search(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||||
query, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.rf,
|
reranking_function=request.app.state.rf,
|
||||||
@ -1586,8 +1586,8 @@ def query_collection_handler(
|
|||||||
return query_collection(
|
return query_collection(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
|
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
|
||||||
query, user=user
|
query, prefix=prefix, user=user
|
||||||
),
|
),
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
)
|
)
|
||||||
@ -1666,7 +1666,7 @@ if ENV == "dev":
|
|||||||
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
|
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
|
||||||
return {
|
return {
|
||||||
"result": request.app.state.EMBEDDING_FUNCTION(
|
"result": request.app.state.EMBEDDING_FUNCTION(
|
||||||
text, RAG_EMBEDDING_QUERY_PREFIX
|
text, prefix=RAG_EMBEDDING_QUERY_PREFIX
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user