This commit is contained in:
Timothy J. Baek 2024-07-01 19:37:54 -07:00
parent c7a9b5ccfa
commit e5895af7a0

View File

@ -212,6 +212,70 @@ origins = ["*"]
##################################
async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise "Model not found"
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
return task_model_id
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user, model
):
@ -373,51 +437,6 @@ async def get_function_call_response(
return None, None, False
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
return task_model_id
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
async def chat_completion_functions_handler(body, model, user):
skip_files = None
@ -579,25 +598,6 @@ async def chat_completion_files_handler(body):
}
async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise "Model not found"
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and any(