From cc50cc10e64644caa62113069e0247cb75ef502e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 21 Feb 2024 18:36:40 -0800 Subject: [PATCH] feat: sd frontend integration --- example.env => .env.example | 2 + README.md | 2 +- backend/apps/images/main.py | 35 ++++--- src/lib/apis/images/index.ts | 99 +++++++++++++++++++ src/lib/components/chat/Messages.svelte | 1 + .../chat/Messages/ResponseMessage.svelte | 99 ++++++++++++++++--- .../components/chat/Settings/Images.svelte | 9 +- 7 files changed, 218 insertions(+), 29 deletions(-) rename example.env => .env.example (83%) diff --git a/example.env b/.env.example similarity index 83% rename from example.env rename to .env.example index 4a4fdaa6c..de763f31c 100644 --- a/example.env +++ b/.env.example @@ -5,6 +5,8 @@ OLLAMA_API_BASE_URL='http://localhost:11434/api' OPENAI_API_BASE_URL='' OPENAI_API_KEY='' +# AUTOMATIC1111_BASE_URL="http://localhost:7860" + # DO NOT TRACK SCARF_NO_ANALYTICS=true DO_NOT_TRACK=true \ No newline at end of file diff --git a/README.md b/README.md index ef18a0acc..7645418ad 100644 --- a/README.md +++ b/README.md @@ -283,7 +283,7 @@ git clone https://github.com/open-webui/open-webui.git cd open-webui/ # Copying required .env file -cp -RPp example.env .env +cp -RPp .env.example .env # Building Frontend Using Node npm i diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py index c80903c14..4539f8066 100644 --- a/backend/apps/images/main.py +++ b/backend/apps/images/main.py @@ -33,7 +33,7 @@ app.add_middleware( ) app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL -app.state.ENABLED = False +app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != "" @app.get("/enabled", response_model=bool) @@ -129,20 +129,33 @@ def generate_image( form_data: GenerateImageForm, user=Depends(get_current_user), ): - if form_data.model: - set_model_handler(form_data.model) - width, height = tuple(map(int, form_data.size.split("x"))) + print(form_data) - r = requests.get( - url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", - json={ + try: + if form_data.model: + set_model_handler(form_data.model) + + width, height = tuple(map(int, form_data.size.split("x"))) + + data = { "prompt": form_data.prompt, - "negative_prompt": form_data.negative_prompt, "batch_size": form_data.n, "width": width, "height": height, - }, - ) + } - return r.json() + if form_data.negative_prompt != None: + data["negative_prompt"] = form_data.negative_prompt + + print(data) + + r = requests.post( + url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", + json=data, + ) + + return r.json() + except Exception as e: + print(e) + raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e)) diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts index 63bc04a95..b25499d64 100644 --- a/src/lib/apis/images/index.ts +++ b/src/lib/apis/images/index.ts @@ -1,5 +1,69 @@ import { IMAGES_API_BASE_URL } from '$lib/constants'; +export const getImageGenerationEnabledStatus = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const toggleImageGenerationEnabledStatus = async (token: string = '') => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getAUTOMATIC1111Url = async (token: string = '') => { let error = null; @@ -165,3 +229,38 @@ export const updateDefaultDiffusionModel = async (token: string = '', model: str return res.model; }; + +export const imageGenerations = async (token: string = '', prompt: string) => { + let error = null; + + const res = await fetch(`${IMAGES_API_BASE_URL}/generations`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + prompt: prompt + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte index b02ba1166..071f715c5 100644 --- a/src/lib/components/chat/Messages.svelte +++ b/src/lib/components/chat/Messages.svelte @@ -11,6 +11,7 @@ import ResponseMessage from './Messages/ResponseMessage.svelte'; import Placeholder from './Messages/Placeholder.svelte'; import Spinner from '../common/Spinner.svelte'; + import { imageGenerations } from '$lib/apis/images'; export let chatId = ''; export let sendPrompt: Function; diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 182a66b3d..d2925b049 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -16,6 +16,7 @@ import { synthesizeOpenAISpeech } from '$lib/apis/openai'; import { extractSentences } from '$lib/utils'; + import { imageGenerations } from '$lib/apis/images'; export let modelfiles = []; export let message; @@ -43,6 +44,8 @@ let loadingSpeech = false; + let generatingImage = false; + $: tokens = marked.lexer(message.content); const renderer = new marked.Renderer(); @@ -267,6 +270,21 @@ renderStyling(); }; + const generateImage = async (message) => { + generatingImage = true; + const res = await imageGenerations(localStorage.token, message.content); + console.log(res); + + if (res) { + message.files = res.images.map((image) => ({ + type: 'image', + url: `data:image/png;base64,${image}` + })); + } + + generatingImage = false; + }; + onMount(async () => { await tick(); renderStyling(); @@ -295,6 +313,18 @@ {#if message.content === ''} {:else} + {#if message.files} +
+ {#each message.files as file} +
+ {#if file.type === 'image'} + input + {/if} +
+ {/each} +
+ {/if} +
@@ -601,23 +631,62 @@ ? 'visible' : 'invisible group-hover:visible'} p-1 rounded dark:hover:text-white hover:text-black transition" on:click={() => { - // generateImage + if (!generatingImage) { + generateImage(message); + } }} > - - - + {#if generatingImage} + + {:else} + + + + {/if} {/if} diff --git a/src/lib/components/chat/Settings/Images.svelte b/src/lib/components/chat/Settings/Images.svelte index f9e8df5e8..d09d12e33 100644 --- a/src/lib/components/chat/Settings/Images.svelte +++ b/src/lib/components/chat/Settings/Images.svelte @@ -2,14 +2,17 @@ import toast from 'svelte-french-toast'; import { createEventDispatcher, onMount } from 'svelte'; - import { user } from '$lib/stores'; + import { config, user } from '$lib/stores'; import { getAUTOMATIC1111Url, getDefaultDiffusionModel, getDiffusionModels, + getImageGenerationEnabledStatus, + toggleImageGenerationEnabledStatus, updateAUTOMATIC1111Url, updateDefaultDiffusionModel } from '$lib/apis/images'; + import { getBackendConfig } from '$lib/apis'; const dispatch = createEventDispatcher(); export let saveSettings: Function; @@ -42,11 +45,13 @@ }; const toggleImageGeneration = async () => { - enableImageGeneration = !enableImageGeneration; + enableImageGeneration = await toggleImageGenerationEnabledStatus(localStorage.token); + config.set(await getBackendConfig(localStorage.token)); }; onMount(async () => { if ($user.role === 'admin') { + enableImageGeneration = await getImageGenerationEnabledStatus(localStorage.token); AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token); if (AUTOMATIC1111_BASE_URL) {