This commit is contained in:
Timothy J. Baek 2024-06-24 11:17:18 -07:00
parent 74a4f642fd
commit 1c4e7f0324
2 changed files with 169 additions and 170 deletions

View File

@ -1,5 +1,6 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from apps.webui.routers import (
auths,
@ -17,6 +18,7 @@ from apps.webui.routers import (
)
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from config import (
WEBUI_BUILD_HASH,
@ -37,6 +39,14 @@ from config import (
AppConfig,
)
import inspect
import uuid
import time
import json
from typing import Iterator, Generator
from pydantic import BaseModel
app = FastAPI()
origins = ["*"]
@ -166,3 +176,152 @@ async def get_pipe_models():
)
return pipe_models
async def generate_function_chat_completion(form_data, user):
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe_id
)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await job()

View File

@ -43,7 +43,11 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app, get_pipe_models
from apps.webui.main import (
app as webui_app,
get_pipe_models,
generate_function_chat_completion,
)
from pydantic import BaseModel
@ -228,10 +232,7 @@ async def get_function_call_response(
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(payload, user=user)
else:
response = await generate_openai_chat_completion(payload, user=user)
response = await generate_chat_completions(form_data=payload, user=user)
content = None
@ -900,159 +901,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe = model.get("pipe")
if pipe:
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in webui_app.state.FUNCTIONS:
function_module, function_type, frontmatter = (
load_function_module_by_id(pipe_id)
)
webui_app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = webui_app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
pipe_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(
stream_content(), media_type="text/event-stream"
)
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await job()
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
@ -1334,10 +1183,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
@ -1397,10 +1243,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
@ -1464,10 +1307,7 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tools/completions")