diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py
index 18ce7a607..1d98d37ff 100644
--- a/backend/apps/socket/main.py
+++ b/backend/apps/socket/main.py
@@ -52,7 +52,6 @@ async def user_join(sid, data):

     user = Users.get_user_by_id(data["id"])
     if user:
-
         SESSION_POOL[sid] = user.id
         if user.id in USER_POOL:
             USER_POOL[user.id].append(sid)
@@ -80,7 +79,6 @@ def get_models_in_use():

 @sio.on("usage")
 async def usage(sid, data):
-
     model_id = data["model"]

     # Cancel previous callback if there is one
@@ -139,7 +137,7 @@ async def disconnect(sid):
         print(f"Unknown session ID {sid} disconnected")


-async def get_event_emitter(request_info):
+def get_event_emitter(request_info):
     async def __event_emitter__(event_data):
         await sio.emit(
             "chat-events",
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
     return __event_emitter__


-async def get_event_call(request_info):
+def get_event_call(request_info):
     async def __event_call__(event_data):
         response = await sio.call(
             "chat-events",
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index 97165a11b..331713b07 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -1,9 +1,6 @@
-from fastapi import FastAPI, Depends
-from fastapi.routing import APIRoute
+from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from sqlalchemy.orm import Session
 from apps.webui.routers import (
     auths,
     users,
@@ -27,7 +24,6 @@ from utils.task import prompt_template


 from config import (
-    WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
     ADMIN_EMAIL,
     WEBUI_AUTH,
@@ -55,7 +51,7 @@ import uuid
 import time
 import json

-from typing import Iterator, Generator, AsyncGenerator, Optional
+from typing import Iterator, Generator, AsyncGenerator
 from pydantic import BaseModel

 app = FastAPI()
@@ -127,60 +123,60 @@ async def get_status():
     }


+def get_function_module(pipe_id: str):
+    # Check if function is already loaded
+    if pipe_id not in app.state.FUNCTIONS:
+        function_module, _, _ = load_function_module_by_id(pipe_id)
+        app.state.FUNCTIONS[pipe_id] = function_module
+    else:
+        function_module = app.state.FUNCTIONS[pipe_id]
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(pipe_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+    return function_module
+
+
 async def get_pipe_models():
     pipes = Functions.get_functions_by_type("pipe", active_only=True)
     pipe_models = []

     for pipe in pipes:
-        # Check if function is already loaded
-        if pipe.id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(
-                pipe.id
-            )
-            app.state.FUNCTIONS[pipe.id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe.id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(pipe.id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
+        function_module = get_function_module(pipe.id)

         # Check if function is a manifold
         if hasattr(function_module, "type"):
-            if function_module.type == "manifold":
-                manifold_pipes = []
+            if not function_module.type == "manifold":
+                continue
+            manifold_pipes = []

-                # Check if pipes is a function or a list
-                if callable(function_module.pipes):
-                    manifold_pipes = function_module.pipes()
-                else:
-                    manifold_pipes = function_module.pipes
+            # Check if pipes is a function or a list
+            if callable(function_module.pipes):
+                manifold_pipes = function_module.pipes()
+            else:
+                manifold_pipes = function_module.pipes

-                for p in manifold_pipes:
-                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
-                    manifold_pipe_name = p["name"]
+            for p in manifold_pipes:
+                manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+                manifold_pipe_name = p["name"]

-                    if hasattr(function_module, "name"):
-                        manifold_pipe_name = (
-                            f"{function_module.name}{manifold_pipe_name}"
-                        )
+                if hasattr(function_module, "name"):
+                    manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"

-                    pipe_flag = {"type": pipe.type}
-                    if hasattr(function_module, "ChatValves"):
-                        pipe_flag["valves_spec"] = function_module.ChatValves.schema()
+                pipe_flag = {"type": pipe.type}
+                if hasattr(function_module, "ChatValves"):
+                    pipe_flag["valves_spec"] = function_module.ChatValves.schema()

-                    pipe_models.append(
-                        {
-                            "id": manifold_pipe_id,
-                            "name": manifold_pipe_name,
-                            "object": "model",
-                            "created": pipe.created_at,
-                            "owned_by": "openai",
-                            "pipe": pipe_flag,
-                        }
-                    )
+                pipe_models.append(
+                    {
+                        "id": manifold_pipe_id,
+                        "name": manifold_pipe_name,
+                        "object": "model",
+                        "created": pipe.created_at,
+                        "owned_by": "openai",
+                        "pipe": pipe_flag,
+                    }
+                )
         else:
             pipe_flag = {"type": "pipe"}
             if hasattr(function_module, "ChatValves"):
@@ -200,162 +196,179 @@ async def get_pipe_models():
     return pipe_models


+async def execute_pipe(pipe, params):
+    if inspect.iscoroutinefunction(pipe):
+        return await pipe(**params)
+    else:
+        return pipe(**params)
+
+
+async def get_message(res: str | Generator | AsyncGenerator) -> str:
+    if isinstance(res, str):
+        return res
+    if isinstance(res, Generator):
+        return "".join(map(str, res))
+    if isinstance(res, AsyncGenerator):
+        return "".join([str(stream) async for stream in res])
+
+
+def get_final_message(form_data: dict, message: str | None = None) -> dict:
+    choice = {
+        "index": 0,
+        "logprobs": None,
+        "finish_reason": "stop",
+    }
+
+    # If message is None, we're dealing with a chunk
+    if not message:
+        choice["delta"] = {}
+    else:
+        choice["message"] = {"role": "assistant", "content": message}
+
+    return {
+        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+        "created": int(time.time()),
+        "model": form_data["model"],
+        "object": "chat.completion" if message is not None else "chat.completion.chunk",
+        "choices": [choice],
+    }
+
+
+def process_line(form_data: dict, line):
+    if isinstance(line, BaseModel):
+        line = line.model_dump_json()
+        line = f"data: {line}"
+    if isinstance(line, dict):
+        line = f"data: {json.dumps(line)}"
+
+    try:
+        line = line.decode("utf-8")
+    except Exception:
+        pass
+
+    if line.startswith("data:"):
+        return f"{line}\n\n"
+    else:
+        line = stream_message_template(form_data["model"], line)
+        return f"data: {json.dumps(line)}\n\n"
+
+
+def get_pipe_id(form_data: dict) -> str:
+    pipe_id = form_data["model"]
+    if "." in pipe_id:
+        pipe_id, _ = pipe_id.split(".", 1)
+    print(pipe_id)
+    return pipe_id
+
+
+def get_params_dict(pipe, form_data, user, extra_params, function_module):
+    pipe_id = get_pipe_id(form_data)
+    # Get the signature of the function
+    sig = inspect.signature(pipe)
+    params = {"body": form_data}
+
+    for key, value in extra_params.items():
+        if key in sig.parameters:
+            params[key] = value
+
+    if "__user__" in sig.parameters:
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        try:
+            if hasattr(function_module, "UserValves"):
+                __user__["valves"] = function_module.UserValves(
+                    **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+                )
+        except Exception as e:
+            print(e)
+
+        params["__user__"] = __user__
+    return params
+
+
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)

-    metadata = None
-    if "metadata" in form_data:
-        metadata = form_data["metadata"]
-        del form_data["metadata"]
+    metadata = form_data.pop("metadata", None)

-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
+    __event_emitter__ = __event_call__ = __task__ = None

     if metadata:
-        if (
-            metadata.get("session_id")
-            and metadata.get("chat_id")
-            and metadata.get("message_id")
-        ):
-            __event_emitter__ = await get_event_emitter(metadata)
-            __event_call__ = await get_event_call(metadata)
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)

-        if metadata.get("task"):
-            __task__ = metadata.get("task")
+    if not model_info:
+        return

-    if model_info:
-        if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
+    if model_info.base_model_id:
+        form_data["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
+    params = model_info.params.model_dump()

-        if model_info.params:
-            if model_info.params.get("temperature", None) is not None:
-                form_data["temperature"] = float(model_info.params.get("temperature"))
+    if params:
+        mappings = {
+            "temperature": float,
+            "top_p": int,
+            "max_tokens": int,
+            "frequency_penalty": int,
+            "seed": lambda x: x,
+            "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+        }

-            if model_info.params.get("top_p", None):
-                form_data["top_p"] = int(model_info.params.get("top_p", None))
+        for key, cast_func in mappings.items():
+            if (value := params.get(key)) is not None:
+                form_data[key] = cast_func(value)

-            if model_info.params.get("max_tokens", None):
-                form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if model_info.params.get("frequency_penalty", None):
-                form_data["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if model_info.params.get("seed", None):
-                form_data["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None):
-                form_data["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-        system = model_info.params.get("system", None)
-        if system:
-            system = prompt_template(
-                system,
-                **(
-                    {
-                        "user_name": user.name,
-                        "user_location": (
-                            user.info.get("location") if user.info else None
-                        ),
-                    }
-                    if user
-                    else {}
-                ),
-            )
-            # Check if the payload already has a system message
-            # If not, add a system message to the payload
-            if form_data.get("messages"):
-                for message in form_data["messages"]:
-                    if message.get("role") == "system":
-                        message["content"] = system + message["content"]
-                        break
-                else:
-                    form_data["messages"].insert(
-                        0,
-                        {
-                            "role": "system",
-                            "content": system,
-                        },
-                    )
+    system = params.get("system", None)
+    if not system:
+        return
+    if user:
+        template_params = {
+            "user_name": user.name,
+            "user_location": user.info.get("location") if user.info else None,
+        }
     else:
-        pass
+        template_params = {}
+
+    system = prompt_template(system, **template_params)
+
+    # Check if the payload already has a system message
+    # If not, add a system message to the payload
+    for message in form_data.get("messages", []):
+        if message.get("role") == "system":
+            message["content"] = system + message["content"]
+            break
+    else:
+        if form_data.get("messages"):
+            form_data["messages"].insert(0, {"role": "system", "content": system})
+
+    extra_params = {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }


     async def job():
-        pipe_id = form_data["model"]
-        if "." in pipe_id:
-            pipe_id, sub_pipe_id = pipe_id.split(".", 1)
-        print(pipe_id)
-
-        # Check if function is already loaded
-        if pipe_id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(
-                pipe_id
-            )
-            app.state.FUNCTIONS[pipe_id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe_id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-
-            valves = Functions.get_function_valves_by_id(pipe_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
+        pipe_id = get_pipe_id(form_data)
+        function_module = get_function_module(pipe_id)

         pipe = function_module.pipe
-
-        # Get the signature of the function
-        sig = inspect.signature(pipe)
-        params = {"body": form_data}
-
-        if "__user__" in sig.parameters:
-            __user__ = {
-                "id": user.id,
-                "email": user.email,
-                "name": user.name,
-                "role": user.role,
-            }
-
-            try:
-                if hasattr(function_module, "UserValves"):
-                    __user__["valves"] = function_module.UserValves(
-                        **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-                    )
-            except Exception as e:
-                print(e)
-
-            params = {**params, "__user__": __user__}
-
-        if "__event_emitter__" in sig.parameters:
-            params = {**params, "__event_emitter__": __event_emitter__}
-
-        if "__event_call__" in sig.parameters:
-            params = {**params, "__event_call__": __event_call__}
-
-        if "__task__" in sig.parameters:
-            params = {**params, "__task__": __task__}
+        params = get_params_dict(pipe, form_data, user, extra_params, function_module)

         if form_data["stream"]:

             async def stream_content():
                 try:
-                    if inspect.iscoroutinefunction(pipe):
-                        res = await pipe(**params)
-                    else:
-                        res = pipe(**params)
+                    res = await execute_pipe(pipe, params)

                     # Directly return if the response is a StreamingResponse
                     if isinstance(res, StreamingResponse):
@@ -377,107 +390,32 @@ async def generate_function_chat_completion(form_data, user):

                 if isinstance(res, Iterator):
                     for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        if isinstance(line, dict):
-                            line = f"data: {json.dumps(line)}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data["model"], line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-                if isinstance(res, str) or isinstance(res, Generator):
-                    finish_message = {
-                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": form_data["model"],
-                        "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-                    yield f"data: {json.dumps(finish_message)}\n\n"
-                    yield f"data: [DONE]"
+                        yield process_line(form_data, line)

                 if isinstance(res, AsyncGenerator):
                     async for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        if isinstance(line, dict):
-                            line = f"data: {json.dumps(line)}"
+                        yield process_line(form_data, line)

-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data["model"], line)
-                            yield f"data: {json.dumps(line)}\n\n"
+                if isinstance(res, str) or isinstance(res, Generator):
+                    finish_message = get_final_message(form_data)
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield "data: [DONE]"

             return StreamingResponse(stream_content(), media_type="text/event-stream")
         else:
-
             try:
-                if inspect.iscoroutinefunction(pipe):
-                    res = await pipe(**params)
-                else:
-                    res = pipe(**params)
+                res = await execute_pipe(pipe, params)

-                if isinstance(res, StreamingResponse):
-                    return res
             except Exception as e:
                 print(f"Error: {e}")
                 return {"error": {"detail": str(e)}}

-            if isinstance(res, dict):
+            if isinstance(res, StreamingResponse) or isinstance(res, dict):
                 return res
-            elif isinstance(res, BaseModel):
+            if isinstance(res, BaseModel):
                 return res.model_dump()
-            else:
-                message = ""
-                if isinstance(res, str):
-                    message = res
-                elif isinstance(res, Generator):
-                    for stream in res:
-                        message = f"{message}{stream}"
-                elif isinstance(res, AsyncGenerator):
-                    async for stream in res:
-                        message = f"{message}{stream}"
-                return {
-                    "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data["model"],
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
+            message = await get_message(res)
+            return get_final_message(form_data, message)

     return await job()

diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py
index 3b128c7d6..8277d1d0b 100644
--- a/backend/apps/webui/models/models.py
+++ b/backend/apps/webui/models/models.py
@@ -1,13 +1,11 @@
-import json
 import logging
-from typing import Optional
+from typing import Optional, List

 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy import Column, BigInteger, Text

 from apps.webui.internal.db import Base, JSONField, get_db
-from typing import List, Union, Optional

 from config import SRC_LOG_LEVELS

 import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):


 class ModelsTable:
-
     def insert_new_model(
         self, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@ class ModelsTable:
             }
         )
         try:
-
             with get_db() as db:
-
                 result = Model(**model.model_dump())
                 db.add(result)
                 db.commit()
@@ -144,13 +139,11 @@ class ModelsTable:

     def get_all_models(self) -> List[ModelModel]:
         with get_db() as db:
-
            return [ModelModel.model_validate(model) for model in db.query(Model).all()]

     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:
-
                 model = db.get(Model, id)
                 return ModelModel.model_validate(model)
         except:
@@ -178,7 +171,6 @@ class ModelsTable:
     def delete_model_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Model).filter_by(id=id).delete()
                 db.commit()

diff --git a/backend/main.py b/backend/main.py
index 360f5f415..915abe333 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -13,8 +13,6 @@ import aiohttp
 import requests
 import mimetypes
 import shutil
-import os
-import uuid
 import inspect

 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse


-from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
+from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         message_id = body["id"]
         del body["id"]

-        __event_emitter__ = await get_event_emitter(
+        __event_emitter__ = get_event_emitter(
             {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
         )
-        __event_call__ = await get_event_call(
+        __event_call__ = get_event_call(
             {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
         )

@@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                             status_code=r.status_code,
                             content=res,
                         )
-                except:
+                except Exception:
                     pass

             else:
                 pass

-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         }
     )

-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1334,14 +1332,14 @@ async def chat_completed(
     )
     model = app.state.MODELS[model_id]

-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
             "session_id": data["session_id"],
         }
     )
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel):

 @app.post("/api/pipelines/add")
 async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
-
     r = None
     try:
         urlIdx = form_data.urlIdx
@@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel):

 @app.delete("/api/pipelines/delete")
 async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
-
     r = None
     try:
         urlIdx = form_data.urlIdx
@@ -1891,7 +1887,6 @@ async def get_pipeline_valves(
     models = await get_all_models()
     r = None
     try:
-
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]