This commit is contained in:
Timothy Jaeryang Baek 2025-03-31 14:13:27 -07:00
parent 3dc40030a1
commit cafc5413f5
3 changed files with 19 additions and 15 deletions

View File

@ -256,7 +256,7 @@ def query_collection(
) -> dict: ) -> dict:
results = [] results = []
for query in queries: for query in queries:
query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
for collection_name in collection_names: for collection_name in collection_names:
if collection_name: if collection_name:
try: try:
@ -334,11 +334,11 @@ def get_embedding_function(
embedding_batch_size, embedding_batch_size,
): ):
if embedding_engine == "": if embedding_engine == "":
return lambda query, prefix, user=None: embedding_function.encode( return lambda query, prefix=None, user=None: embedding_function.encode(
query, prompt=prefix if prefix else None query, prompt=prefix if prefix else None
).tolist() ).tolist()
elif embedding_engine in ["ollama", "openai"]: elif embedding_engine in ["ollama", "openai"]:
func = lambda query, prefix, user=None: generate_embeddings( func = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine, engine=embedding_engine,
model=embedding_model, model=embedding_model,
text=query, text=query,
@ -363,7 +363,7 @@ def get_embedding_function(
else: else:
return func(query, prefix, user) return func(query, prefix, user)
return lambda query, prefix, user=None: generate_multiple( return lambda query, prefix=None, user=None: generate_multiple(
query, prefix, user, func query, prefix, user, func
) )
else: else:

View File

@ -57,7 +57,9 @@ async def add_memory(
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user), "vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"metadata": {"created_at": memory.created_at}, "metadata": {"created_at": memory.created_at},
} }
], ],
@ -82,7 +84,7 @@ async def query_memory(
): ):
results = VECTOR_DB_CLIENT.search( results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)], vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
limit=form_data.k, limit=form_data.k,
) )
@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
{ {
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user), "vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"metadata": { "metadata": {
"created_at": memory.created_at, "created_at": memory.created_at,
"updated_at": memory.updated_at, "updated_at": memory.updated_at,
@ -161,7 +165,7 @@ async def update_memory_by_id(
"id": memory.id, "id": memory.id,
"text": memory.content, "text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION( "vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user memory.content, user=user
), ),
"metadata": { "metadata": {
"created_at": memory.created_at, "created_at": memory.created_at,

View File

@ -1518,8 +1518,8 @@ def query_doc_handler(
return query_doc_with_hybrid_search( return query_doc_with_hybrid_search(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, user=user query, prefix=prefix, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
@ -1569,8 +1569,8 @@ def query_collection_handler(
return query_collection_with_hybrid_search( return query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
queries=[form_data.query], queries=[form_data.query],
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, user=user query, prefix=prefix, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
@ -1586,8 +1586,8 @@ def query_collection_handler(
return query_collection( return query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
queries=[form_data.query], queries=[form_data.query],
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, user=user query, prefix=prefix, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
) )
@ -1666,7 +1666,7 @@ if ENV == "dev":
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
return { return {
"result": request.app.state.EMBEDDING_FUNCTION( "result": request.app.state.EMBEDDING_FUNCTION(
text, RAG_EMBEDDING_QUERY_PREFIX text, prefix=RAG_EMBEDDING_QUERY_PREFIX
) )
} }