Merge pull request #2637 from open-webui/pipelines

feat: pipeline valves
This commit is contained in:
Timothy Jaeryang Baek 2024-05-28 13:06:46 -07:00 committed by GitHub
commit cf1c8be85f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 514 additions and 20 deletions

View File

@ -12,6 +12,7 @@ import mimetypes
from fastapi import FastAPI, Request, Depends, status from fastapi import FastAPI, Request, Depends, status
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -123,15 +124,6 @@ app.state.MODELS = {}
origins = ["*"] origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Custom middleware to add security headers # Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware): # class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next): # async def dispatch(self, request: Request, call_next):
@ -253,6 +245,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
if "pipeline" in model if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter" and model["pipeline"]["type"] == "filter"
and ( and (
model["pipeline"]["pipelines"] == ["*"] model["pipeline"]["pipelines"] == ["*"]
@ -276,10 +269,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
except: except:
pass pass
print(sorted_filters)
for filter in sorted_filters: for filter in sorted_filters:
r = None
try: try:
urlIdx = filter["urlIdx"] urlIdx = filter["urlIdx"]
@ -289,11 +280,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
if key != "": if key != "":
headers = {"Authorization": f"Bearer {key}"} headers = {"Authorization": f"Bearer {key}"}
r = requests.post( r = requests.post(
f"{url}/filter", f"{url}/{filter['id']}/filter",
headers=headers, headers=headers,
json={ json={
"user": user, "user": user,
"model": filter["id"],
"body": data, "body": data,
}, },
) )
@ -303,7 +293,20 @@ class PipelineMiddleware(BaseHTTPMiddleware):
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
print(f"Connection error: {e}") print(f"Connection error: {e}")
pass
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
@ -328,6 +331,15 @@ class PipelineMiddleware(BaseHTTPMiddleware):
app.add_middleware(PipelineMiddleware) app.add_middleware(PipelineMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0: if len(app.state.MODELS) == 0:
@ -436,7 +448,7 @@ async def get_models(user=Depends(get_verified_user)):
models = [ models = [
model model
for model in models for model in models
if "pipeline" not in model or model["pipeline"]["type"] != "filter" if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
] ]
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
@ -452,6 +464,164 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models} return {"data": models}
@app.get("/api/pipelines")
async def get_pipelines(user=Depends(get_admin_user)):
models = await get_all_models()
pipelines = [model for model in models if "pipeline" in model]
return {"data": pipelines}
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(pipeline_id: str, user=Depends(get_admin_user)):
models = await get_all_models()
if pipeline_id in app.state.MODELS and "pipeline" in app.state.MODELS[pipeline_id]:
pipeline = app.state.MODELS[pipeline_id]
r = None
try:
urlIdx = pipeline["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.get(f"{url}/{pipeline['id']}/valves", headers=headers)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(
r.status_code if r is not None else status.HTTP_404_NOT_FOUND
),
detail=detail,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Pipeline not found",
)
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(pipeline_id: str, user=Depends(get_admin_user)):
models = await get_all_models()
if pipeline_id in app.state.MODELS and "pipeline" in app.state.MODELS[pipeline_id]:
pipeline = app.state.MODELS[pipeline_id]
r = None
try:
urlIdx = pipeline["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.get(f"{url}/{pipeline['id']}/valves/spec", headers=headers)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(
r.status_code if r is not None else status.HTTP_404_NOT_FOUND
),
detail=detail,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Pipeline not found",
)
@app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves(
pipeline_id: str, form_data: dict, user=Depends(get_admin_user)
):
models = await get_all_models()
if pipeline_id in app.state.MODELS and "pipeline" in app.state.MODELS[pipeline_id]:
pipeline = app.state.MODELS[pipeline_id]
r = None
try:
urlIdx = pipeline["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{pipeline['id']}/valves/update",
headers=headers,
json={**form_data},
)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(
r.status_code if r is not None else status.HTTP_404_NOT_FOUND
),
detail=detail,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Pipeline not found",
)
@app.get("/api/config") @app.get("/api/config")
async def get_app_config(): async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA # Checking and Handling the Absence of 'ui' in CONFIG_DATA

View File

@ -49,6 +49,129 @@ export const getModels = async (token: string = '') => {
return models; return models;
}; };
export const getPipelines = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
let pipelines = res?.data ?? [];
return pipelines;
};
export const getPipelineValves = async (token: string = '', pipeline_id: string) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const getPipelineValvesSpec = async (token: string = '', pipeline_id: string) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updatePipelineValves = async (
token: string = '',
pipeline_id: string,
valves: object
) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/pipelines/${pipeline_id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify(valves)
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = err;
}
return null;
});
if (error) {
throw error;
}
return res;
};
export const getBackendConfig = async () => { export const getBackendConfig = async () => {
let error = null; let error = null;

View File

@ -0,0 +1,168 @@
<script lang="ts">
import { v4 as uuidv4 } from 'uuid';
import { getContext, onMount, tick } from 'svelte';
import type { Writable } from 'svelte/store';
import type { i18n as i18nType } from 'i18next';
import { stringify } from 'postcss';
import {
getPipelineValves,
getPipelineValvesSpec,
updatePipelineValves,
getPipelines
} from '$lib/apis';
import Spinner from '$lib/components/common/Spinner.svelte';
import { toast } from 'svelte-sonner';
const i18n: Writable<i18nType> = getContext('i18n');
export let saveHandler: Function;
let pipelines = null;
let valves = null;
let valves_spec = null;
let selectedPipelineIdx = null;
const updateHandler = async () => {
const pipeline = pipelines[selectedPipelineIdx];
if (pipeline && (pipeline?.pipeline?.valves ?? false)) {
const res = await updatePipelineValves(localStorage.token, pipeline.id, valves).catch(
(error) => {
toast.error(error);
}
);
if (res) {
toast.success('Valves updated successfully');
saveHandler();
}
} else {
toast.error('No valves to update');
}
};
onMount(async () => {
pipelines = await getPipelines(localStorage.token);
if (pipelines.length > 0) {
selectedPipelineIdx = 0;
}
});
</script>
<form
class="flex flex-col h-full justify-between space-y-3 text-sm"
on:submit|preventDefault={async () => {
updateHandler();
}}
>
<div class=" space-y-2 pr-1.5 overflow-y-scroll max-h-80 h-full">
{#if pipelines !== null && pipelines.length > 0}
<div class="flex w-full justify-between mb-2">
<div class=" self-center text-sm font-semibold">
{$i18n.t('Pipelines')}
</div>
</div>
<div class="space-y-1">
{#if pipelines.length > 0}
<div class="flex gap-2">
<div class="flex-1 pb-1">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={selectedPipelineIdx}
placeholder={$i18n.t('Select a pipeline')}
on:change={async () => {
await tick();
valves_spec = await getPipelineValvesSpec(
localStorage.token,
pipelines[selectedPipelineIdx].id
);
valves = await getPipelineValves(
localStorage.token,
pipelines[selectedPipelineIdx].id
);
}}
>
{#each pipelines as pipeline, idx}
<option value={idx} class="bg-gray-100 dark:bg-gray-700"
>{pipeline.name} ({pipeline.pipeline.type ?? 'pipe'})</option
>
{/each}
</select>
</div>
</div>
{/if}
<div class="text-sm font-medium">{$i18n.t('Valves')}</div>
<div class="space-y-1">
{#if pipelines[selectedPipelineIdx].pipeline.valves}
{#if valves}
{#each Object.keys(valves_spec.properties) as property, idx}
<div class=" py-0.5 w-full justify-between">
<div class="flex w-full justify-between">
<div class=" self-center text-xs font-medium">
{valves_spec.properties[property].title}
</div>
<button
class="p-1 px-3 text-xs flex rounded transition"
type="button"
on:click={() => {
valves[property] = (valves[property] ?? null) === null ? '' : null;
}}
>
{#if (valves[property] ?? null) === null}
<span class="ml-2 self-center"> {$i18n.t('None')} </span>
{:else}
<span class="ml-2 self-center"> {$i18n.t('Custom')} </span>
{/if}
</button>
</div>
{#if (valves[property] ?? null) !== null}
<div class="flex mt-0.5 space-x-2">
<div class=" flex-1">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={valves_spec.properties[property].title}
bind:value={valves[property]}
autocomplete="off"
/>
</div>
</div>
{/if}
</div>
{/each}
{:else}
<Spinner className="size-5" />
{/if}
{:else}
<div>No valves</div>
{/if}
</div>
</div>
{:else if pipelines !== null && pipelines.length === 0}
<div>Pipelines Not Detected</div>
{:else}
<div class="flex h-full justify-center">
<div class="my-auto">
<Spinner className="size-6" />
</div>
</div>
{/if}
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class=" px-4 py-2 bg-emerald-700 hover:bg-emerald-800 text-gray-100 transition rounded-lg"
type="submit"
>
Save
</button>
</div>
</form>

View File

@ -8,6 +8,7 @@
import Banners from '$lib/components/admin/Settings/Banners.svelte'; import Banners from '$lib/components/admin/Settings/Banners.svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import Pipelines from './Settings/Pipelines.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -149,33 +150,65 @@
</div> </div>
<div class=" self-center">{$i18n.t('Banners')}</div> <div class=" self-center">{$i18n.t('Banners')}</div>
</button> </button>
<button
class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
'pipelines'
? 'bg-gray-200 dark:bg-gray-700'
: ' hover:bg-gray-300 dark:hover:bg-gray-800'}"
on:click={() => {
selectedTab = 'pipelines';
}}
>
<div class=" self-center mr-2">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="currentColor"
class="size-4"
>
<path
d="M11.644 1.59a.75.75 0 0 1 .712 0l9.75 5.25a.75.75 0 0 1 0 1.32l-9.75 5.25a.75.75 0 0 1-.712 0l-9.75-5.25a.75.75 0 0 1 0-1.32l9.75-5.25Z"
/>
<path
d="m3.265 10.602 7.668 4.129a2.25 2.25 0 0 0 2.134 0l7.668-4.13 1.37.739a.75.75 0 0 1 0 1.32l-9.75 5.25a.75.75 0 0 1-.71 0l-9.75-5.25a.75.75 0 0 1 0-1.32l1.37-.738Z"
/>
<path
d="m10.933 19.231-7.668-4.13-1.37.739a.75.75 0 0 0 0 1.32l9.75 5.25c.221.12.489.12.71 0l9.75-5.25a.75.75 0 0 0 0-1.32l-1.37-.738-7.668 4.13a2.25 2.25 0 0 1-2.134-.001Z"
/>
</svg>
</div>
<div class=" self-center">{$i18n.t('Pipelines')}</div>
</button>
</div> </div>
<div class="flex-1 md:min-h-[380px]"> <div class="flex-1 md:min-h-[380px]">
{#if selectedTab === 'general'} {#if selectedTab === 'general'}
<General <General
saveHandler={() => { saveHandler={() => {
show = false;
toast.success($i18n.t('Settings saved successfully!')); toast.success($i18n.t('Settings saved successfully!'));
}} }}
/> />
{:else if selectedTab === 'users'} {:else if selectedTab === 'users'}
<Users <Users
saveHandler={() => { saveHandler={() => {
show = false;
toast.success($i18n.t('Settings saved successfully!')); toast.success($i18n.t('Settings saved successfully!'));
}} }}
/> />
{:else if selectedTab === 'db'} {:else if selectedTab === 'db'}
<Database <Database
saveHandler={() => { saveHandler={() => {
show = false;
toast.success($i18n.t('Settings saved successfully!')); toast.success($i18n.t('Settings saved successfully!'));
}} }}
/> />
{:else if selectedTab === 'banners'} {:else if selectedTab === 'banners'}
<Banners <Banners
saveHandler={() => { saveHandler={() => {
show = false; toast.success($i18n.t('Settings saved successfully!'));
}}
/>
{:else if selectedTab === 'pipelines'}
<Pipelines
saveHandler={() => {
toast.success($i18n.t('Settings saved successfully!')); toast.success($i18n.t('Settings saved successfully!'));
}} }}
/> />