feat: direct connections integration

This commit is contained in:
Timothy Jaeryang Baek
2025-02-12 22:56:33 -08:00
parent 304ce2a14d
commit c83e68282d
6 changed files with 387 additions and 94 deletions

View File

@@ -139,7 +139,12 @@ async def update_task_config(
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -198,6 +203,7 @@ async def generate_title(
}
),
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"task": str(TASKS.TITLE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -225,7 +231,12 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"},
)
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -261,6 +272,7 @@ async def generate_chat_tags(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"task": str(TASKS.TAGS_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -281,7 +293,12 @@ async def generate_chat_tags(
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -321,6 +338,7 @@ async def generate_image_prompt(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -356,7 +374,12 @@ async def generate_queries(
detail=f"Query generation is disabled",
)
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -392,6 +415,7 @@ async def generate_queries(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"task": str(TASKS.QUERY_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -431,7 +455,12 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -467,6 +496,7 @@ async def generate_autocompletion(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
@@ -488,7 +518,12 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -531,7 +566,11 @@ async def generate_emoji(
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"task": str(TASKS.EMOJI_GENERATION),
"task_body": form_data,
},
}
try:
@@ -548,7 +587,13 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if request.state.direct and request.state.model:
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"metadata": {
**(request.state.metadata if request.state.metadata else {}),
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,