diff --git a/example.env b/.env.example
similarity index 83%
rename from example.env
rename to .env.example
index 4a4fdaa6c..de763f31c 100644
--- a/example.env
+++ b/.env.example
@@ -5,6 +5,8 @@ OLLAMA_API_BASE_URL='http://localhost:11434/api'
OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''
+# AUTOMATIC1111_BASE_URL="http://localhost:7860"
+
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true
\ No newline at end of file
diff --git a/README.md b/README.md
index ef18a0acc..7645418ad 100644
--- a/README.md
+++ b/README.md
@@ -283,7 +283,7 @@ git clone https://github.com/open-webui/open-webui.git
cd open-webui/
# Copying required .env file
-cp -RPp example.env .env
+cp -RPp .env.example .env
# Building Frontend Using Node
npm i
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index c80903c14..4539f8066 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -33,7 +33,7 @@ app.add_middleware(
)
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.ENABLED = False
+app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
@app.get("/enabled", response_model=bool)
@@ -129,20 +129,33 @@ def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
- if form_data.model:
- set_model_handler(form_data.model)
- width, height = tuple(map(int, form_data.size.split("x")))
+ print(form_data)
- r = requests.get(
- url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
- json={
+ try:
+ if form_data.model:
+ set_model_handler(form_data.model)
+
+ width, height = tuple(map(int, form_data.size.split("x")))
+
+ data = {
"prompt": form_data.prompt,
- "negative_prompt": form_data.negative_prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
- },
- )
+ }
- return r.json()
+ if form_data.negative_prompt != None:
+ data["negative_prompt"] = form_data.negative_prompt
+
+ print(data)
+
+ r = requests.post(
+ url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+ json=data,
+ )
+
+ return r.json()
+ except Exception as e:
+ print(e)
+ raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
diff --git a/src/lib/apis/images/index.ts b/src/lib/apis/images/index.ts
index 63bc04a95..b25499d64 100644
--- a/src/lib/apis/images/index.ts
+++ b/src/lib/apis/images/index.ts
@@ -1,5 +1,69 @@
import { IMAGES_API_BASE_URL } from '$lib/constants';
+export const getImageGenerationEnabledStatus = async (token: string = '') => {
+ let error = null;
+
+ const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
+ method: 'GET',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ ...(token && { authorization: `Bearer ${token}` })
+ }
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ if ('detail' in err) {
+ error = err.detail;
+ } else {
+ error = 'Server connection failed';
+ }
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
+export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
+ let error = null;
+
+ const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
+ method: 'GET',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ ...(token && { authorization: `Bearer ${token}` })
+ }
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ if ('detail' in err) {
+ error = err.detail;
+ } else {
+ error = 'Server connection failed';
+ }
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
export const getAUTOMATIC1111Url = async (token: string = '') => {
let error = null;
@@ -165,3 +229,38 @@ export const updateDefaultDiffusionModel = async (token: string = '', model: str
return res.model;
};
+
+export const imageGenerations = async (token: string = '', prompt: string) => {
+ let error = null;
+
+ const res = await fetch(`${IMAGES_API_BASE_URL}/generations`, {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ ...(token && { authorization: `Bearer ${token}` })
+ },
+ body: JSON.stringify({
+ prompt: prompt
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ if ('detail' in err) {
+ error = err.detail;
+ } else {
+ error = 'Server connection failed';
+ }
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
diff --git a/src/lib/components/chat/Messages.svelte b/src/lib/components/chat/Messages.svelte
index b02ba1166..071f715c5 100644
--- a/src/lib/components/chat/Messages.svelte
+++ b/src/lib/components/chat/Messages.svelte
@@ -11,6 +11,7 @@
import ResponseMessage from './Messages/ResponseMessage.svelte';
import Placeholder from './Messages/Placeholder.svelte';
import Spinner from '../common/Spinner.svelte';
+ import { imageGenerations } from '$lib/apis/images';
export let chatId = '';
export let sendPrompt: Function;
diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte
index 182a66b3d..d2925b049 100644
--- a/src/lib/components/chat/Messages/ResponseMessage.svelte
+++ b/src/lib/components/chat/Messages/ResponseMessage.svelte
@@ -16,6 +16,7 @@
import { synthesizeOpenAISpeech } from '$lib/apis/openai';
import { extractSentences } from '$lib/utils';
+ import { imageGenerations } from '$lib/apis/images';
export let modelfiles = [];
export let message;
@@ -43,6 +44,8 @@
let loadingSpeech = false;
+ let generatingImage = false;
+
$: tokens = marked.lexer(message.content);
const renderer = new marked.Renderer();
@@ -267,6 +270,21 @@
renderStyling();
};
+ const generateImage = async (message) => {
+ generatingImage = true;
+ const res = await imageGenerations(localStorage.token, message.content);
+ console.log(res);
+
+ if (res) {
+ message.files = res.images.map((image) => ({
+ type: 'image',
+ url: `data:image/png;base64,${image}`
+ }));
+ }
+
+ generatingImage = false;
+ };
+
onMount(async () => {
await tick();
renderStyling();
@@ -295,6 +313,18 @@
{#if message.content === ''}