diff --git a/backend/main.py b/backend/main.py index 0a0587159..11c78645b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -170,6 +170,13 @@ app.state.MODELS = {} origins = ["*"] +################################## +# +# ChatCompletion Middleware +# +################################## + + async def get_function_call_response( messages, files, tool_id, template, task_model_id, user ): @@ -469,6 +476,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) +################################## +# +# Pipeline Middleware +# +################################## + def filter_pipeline(payload, user): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} @@ -628,7 +641,6 @@ async def update_embedding_function(request: Request, call_next): app.mount("/ws", socket_app) - app.mount("/ollama", ollama_app) app.mount("/openai", openai_app) @@ -730,6 +742,104 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.post("/api/chat/completions") +async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + print(model) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion(form_data, user=user) + else: + return await generate_openai_chat_completion(form_data, user=user) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + data = form_data + model_id = data["model"] + + filters = [ + model + for model in app.state.MODELS.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + + print(model_id) + + if model_id in app.state.MODELS: + model = app.state.MODELS[model_id] + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": {"id": user.id, "name": user.name, "role": user.role}, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except: + pass + + else: + pass + + return data + + +################################## +# +# Task Endpoints +# +################################## + + +# TODO: Refactor task API endpoints below into a separate file + + @app.get("/api/task/config") async def get_task_config(user=Depends(get_verified_user)): return { @@ -1015,92 +1125,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ) -@app.post("/api/chat/completions") -async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = app.state.MODELS[model_id] - print(model) - - if model["owned_by"] == "ollama": - return await generate_ollama_chat_completion(form_data, user=user) - else: - return await generate_openai_chat_completion(form_data, user=user) +################################## +# +# Pipelines Endpoints +# +################################## -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - data = form_data - model_id = data["model"] - - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - print(model_id) - - if model_id in app.state.MODELS: - model = app.state.MODELS[model_id] - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": {"id": user.id, "name": user.name, "role": user.role}, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except: - pass - - else: - pass - - return data +# TODO: Refactor pipelines API endpoints below into a separate file @app.get("/api/pipelines/list") @@ -1423,6 +1455,13 @@ async def update_pipeline_valves( ) +################################## +# +# Config Endpoints +# +################################## + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA @@ -1486,6 +1525,9 @@ async def update_model_filter_config( } +# TODO: webhook endpoint should be under config endpoints + + @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index 12cb0386d..6e35616f2 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -30,9 +30,10 @@ let boilerplate = `from pydantic import BaseModel from typing import Optional + class Filter: class Valves(BaseModel): - max_turns: int + max_turns: int = 4 pass def __init__(self): @@ -42,14 +43,14 @@ class Filter: # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings, # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'. - self.valves = self.Valves(**{"max_turns": 10}) + self.valves = self.Valves(**{"max_turns": 2}) pass def inlet(self, body: dict, user: Optional[dict] = None) -> dict: # Modify the request body or validate it before processing by the chat completion API. # This function is the pre-processor for the API where various checks on the input can be performed. # It can also modify the request before sending it to the API. - + print("inlet") print(body) print(user) @@ -65,7 +66,7 @@ class Filter: def outlet(self, body: dict, user: Optional[dict] = None) -> dict: # Modify or analyze the response body after processing by the API. - # This function is the post-processor for the API, which can be used to modify the response + # This function is the post-processor for the API, which can be used to modify the response # or perform additional checks and analytics. print(f"outlet") print(body)