Timothy J. Baek 2024-06-24 11:17:18 -07:00
parent 74a4f642fd
commit 1c4e7f0324
2 changed files with 169 additions and 170 deletions

View File

@@ -1,5 +1,6 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from apps.webui.routers import (
    auths,
@@ -17,6 +18,7 @@ from apps.webui.routers import (
)
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from config import (
    WEBUI_BUILD_HASH,
@@ -37,6 +39,14 @@ from config import (
    AppConfig,
)
import inspect
import uuid
import time
import json
from typing import Iterator, Generator
from pydantic import BaseModel
app = FastAPI()
origins = ["*"]
@@ -166,3 +176,152 @@ async def get_pipe_models():
    )
    return pipe_models
async def generate_function_chat_completion(form_data, user):
    async def job():
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, sub_pipe_id = pipe_id.split(".", 1)
        print(pipe_id)

        # Check if function is already loaded
        if pipe_id not in app.state.FUNCTIONS:
            function_module, function_type, frontmatter = load_function_module_by_id(
                pipe_id
            )
            app.state.FUNCTIONS[pipe_id] = function_module
        else:
            function_module = app.state.FUNCTIONS[pipe_id]

        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(pipe_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        pipe = function_module.pipe

        # Get the signature of the function
        sig = inspect.signature(pipe)
        params = {"body": form_data}

        if "__user__" in sig.parameters:
            __user__ = {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            }

            try:
                if hasattr(function_module, "UserValves"):
                    __user__["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
                    )
            except Exception as e:
                print(e)

            params = {**params, "__user__": __user__}

        if form_data["stream"]:

            async def stream_content():
                # Run the pipe once and re-emit its output as OpenAI-style SSE chunks
                try:
                    if inspect.iscoroutinefunction(pipe):
                        res = await pipe(**params)
                    else:
                        res = pipe(**params)
                except Exception as e:
                    print(f"Error: {e}")
                    yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                    return

                if isinstance(res, str):
                    message = stream_message_template(form_data["model"], res)
                    yield f"data: {json.dumps(message)}\n\n"

                if isinstance(res, Iterator):
                    for line in res:
                        if isinstance(line, BaseModel):
                            line = line.model_dump_json()
                            line = f"data: {line}"

                        try:
                            line = line.decode("utf-8")
                        except Exception:
                            pass

                        if line.startswith("data:"):
                            yield f"{line}\n\n"
                        else:
                            line = stream_message_template(form_data["model"], line)
                            yield f"data: {json.dumps(line)}\n\n"

                if isinstance(res, str) or isinstance(res, Generator):
                    finish_message = {
                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": form_data["model"],
                        "choices": [
                            {
                                "index": 0,
                                "delta": {},
                                "logprobs": None,
                                "finish_reason": "stop",
                            }
                        ],
                    }
                    yield f"data: {json.dumps(finish_message)}\n\n"
                    yield "data: [DONE]"

            return StreamingResponse(stream_content(), media_type="text/event-stream")
        else:
            # Non-streaming: invoke the pipe once and normalize the result
            try:
                if inspect.iscoroutinefunction(pipe):
                    res = await pipe(**params)
                else:
                    res = pipe(**params)
            except Exception as e:
                print(f"Error: {e}")
                return {"error": {"detail": str(e)}}

            if isinstance(res, dict):
                return res
            elif isinstance(res, BaseModel):
                return res.model_dump()
            else:
                message = ""
                if isinstance(res, str):
                    message = res
                if isinstance(res, Generator):
                    for stream in res:
                        message = f"{message}{stream}"

                return {
                    "id": f"{form_data['model']}-{str(uuid.uuid4())}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": form_data["model"],
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": message,
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                }

    return await job()
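As a quick illustration of how the new helper behaves, here is a minimal usage sketch. The `SimpleNamespace` user and the `"my-pipe"` model id are hypothetical stand-ins; in practice the function is invoked from the root app's `generate_chat_completions` with a verified request user and a model id that matches a registered pipe function.

```python
# Hypothetical usage sketch; assumes a pipe function with id "my-pipe" is registered.
from types import SimpleNamespace


async def demo():
    user = SimpleNamespace(id="u1", email="u@example.com", name="Demo", role="user")
    form_data = {
        "model": "my-pipe",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }

    response = await generate_function_chat_completion(form_data, user)

    # With "stream": True the result is a StreamingResponse wrapping an SSE generator;
    # each item is a "data: {...}\n\n" line, terminated by "data: [DONE]".
    async for chunk in response.body_iterator:
        print(chunk, end="")
```

With `"stream": False` the same call instead returns a plain `chat.completion` dict (or the pipe's own dict/BaseModel result).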

View File

@@ -43,7 +43,11 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app, get_pipe_models
from apps.webui.main import (
    app as webui_app,
    get_pipe_models,
    generate_function_chat_completion,
)
from pydantic import BaseModel
@@ -228,10 +232,7 @@ async def get_function_call_response(
    response = None
    try:
        if model["owned_by"] == "ollama":
            response = await generate_ollama_chat_completion(payload, user=user)
        else:
            response = await generate_openai_chat_completion(payload, user=user)
        response = await generate_chat_completions(form_data=payload, user=user)
    content = None
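The branching on `model["owned_by"]` is gone here: the tool-call helper now hands its payload to `generate_chat_completions`, which selects the backend itself, so pipe-function models can serve these internal calls as well. Roughly, assuming `payload` is the OpenAI-style dict assembled earlier in this function (the variable names below are illustrative, not the actual ones):

```python
# Illustrative shape of the delegated call; the real payload is built earlier
# in get_function_call_response from the selected task model and the tools template.
payload = {
    "model": task_model_id,  # hypothetical name for the chosen task model id
    "messages": [{"role": "user", "content": prompt}],  # hypothetical prompt variable
    "stream": False,
}
response = await generate_chat_completions(form_data=payload, user=user)
```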
@@ -900,159 +901,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
    pipe = model.get("pipe")
    if pipe:
        return await generate_function_chat_completion(form_data, user=user)
        async def job():
            pipe_id = form_data["model"]
            if "." in pipe_id:
                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
            print(pipe_id)
            # Check if function is already loaded
            if pipe_id not in webui_app.state.FUNCTIONS:
                function_module, function_type, frontmatter = (
                    load_function_module_by_id(pipe_id)
                )
                webui_app.state.FUNCTIONS[pipe_id] = function_module
            else:
                function_module = webui_app.state.FUNCTIONS[pipe_id]
            if hasattr(function_module, "valves") and hasattr(
                function_module, "Valves"
            ):
                valves = Functions.get_function_valves_by_id(pipe_id)
                function_module.valves = function_module.Valves(
                    **(valves if valves else {})
                )
            pipe = function_module.pipe
            # Get the signature of the function
            sig = inspect.signature(pipe)
            params = {"body": form_data}
            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }
                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                pipe_id, user.id
                            )
                        )
                except Exception as e:
                    print(e)
                params = {**params, "__user__": __user__}
            if form_data["stream"]:
                async def stream_content():
                    try:
                        if inspect.iscoroutinefunction(pipe):
                            res = await pipe(**params)
                        else:
                            res = pipe(**params)
                    except Exception as e:
                        print(f"Error: {e}")
                        yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
                        return
                    if isinstance(res, str):
                        message = stream_message_template(form_data["model"], res)
                        yield f"data: {json.dumps(message)}\n\n"
                    if isinstance(res, Iterator):
                        for line in res:
                            if isinstance(line, BaseModel):
                                line = line.model_dump_json()
                                line = f"data: {line}"
                            try:
                                line = line.decode("utf-8")
                            except:
                                pass
                            if line.startswith("data:"):
                                yield f"{line}\n\n"
                            else:
                                line = stream_message_template(form_data["model"], line)
                                yield f"data: {json.dumps(line)}\n\n"
                    if isinstance(res, str) or isinstance(res, Generator):
                        finish_message = {
                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": form_data["model"],
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {},
                                    "logprobs": None,
                                    "finish_reason": "stop",
                                }
                            ],
                        }
                        yield f"data: {json.dumps(finish_message)}\n\n"
                        yield f"data: [DONE]"
                return StreamingResponse(
                    stream_content(), media_type="text/event-stream"
                )
            else:
                try:
                    if inspect.iscoroutinefunction(pipe):
                        res = await pipe(**params)
                    else:
                        res = pipe(**params)
                except Exception as e:
                    print(f"Error: {e}")
                    return {"error": {"detail": str(e)}}
                if inspect.iscoroutinefunction(pipe):
                    res = await pipe(**params)
                else:
                    res = pipe(**params)
                if isinstance(res, dict):
                    return res
                elif isinstance(res, BaseModel):
                    return res.model_dump()
                else:
                    message = ""
                    if isinstance(res, str):
                        message = res
                    if isinstance(res, Generator):
                        for stream in res:
                            message = f"{message}{stream}"
                    return {
                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": form_data["model"],
                        "choices": [
                            {
                                "index": 0,
                                "message": {
                                    "role": "assistant",
                                    "content": message,
                                },
                                "logprobs": None,
                                "finish_reason": "stop",
                            }
                        ],
                    }
        return await job()
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user) return await generate_ollama_chat_completion(form_data, user=user)
else: else:
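Taken together with the context lines above, the completion entry point now reduces to a three-way dispatch. The following is a paraphrase of the surrounding code, not a verbatim excerpt; the final OpenAI-compatible branch is assumed from the unchanged lines beyond the visible context:

```python
# Paraphrased dispatch inside generate_chat_completions after this commit.
pipe = model.get("pipe")
if pipe:
    # Model is backed by a custom pipe function loaded from the functions registry
    return await generate_function_chat_completion(form_data, user=user)

if model["owned_by"] == "ollama":
    return await generate_ollama_chat_completion(form_data, user=user)
else:
    # Assumed: OpenAI-compatible backends are handled past the visible context
    return await generate_openai_chat_completion(form_data, user=user)
```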
@@ -1334,10 +1183,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
            content={"detail": e.args[1]},
        )
    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)
    return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions") @app.post("/api/task/query/completions")
@ -1397,10 +1243,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if model["owned_by"] == "ollama": return await generate_chat_completions(form_data=payload, user=user)
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/emoji/completions") @app.post("/api/task/emoji/completions")
@ -1464,10 +1307,7 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if model["owned_by"] == "ollama": return await generate_chat_completions(form_data=payload, user=user)
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions") @app.post("/api/task/tools/completions")