"""Image-generation API that proxies an AUTOMATIC1111 server's ``/sdapi/v1`` HTTP API.

Admin endpoints read/modify runtime settings stored on ``app.state`` (base URL,
enabled flag, image size); user endpoints list models and run txt2img.
"""

import logging
import re
from typing import Optional

import requests
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel

from constants import ERROR_MESSAGES
from utils.utils import (
    get_current_user,
    get_admin_user,
)
from utils.misc import calculate_sha256
from config import AUTOMATIC1111_BASE_URL

log = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Runtime-mutable settings kept on app.state so the admin endpoints below can
# change them without a restart.
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
app.state.IMAGE_SIZE = "512x512"

# "<width>x<height>", e.g. "512x512". Applied with fullmatch so the whole
# string must match (``re.match`` with "$" would also accept a trailing "\n").
_IMAGE_SIZE_PATTERN = re.compile(r"\d+x\d+")


def _upstream_error(
    r: Optional[requests.Response], e: Exception
) -> HTTPException:
    """Build the HTTPException to raise when talking to AUTOMATIC1111 failed.

    Args:
        r: The last upstream response, or ``None`` if the request itself
            raised (connection refused, invalid URL, ...) and no response
            exists.
        e: The exception that was caught; rendered into the error detail.

    Returns:
        An HTTPException carrying the upstream status code when that status is
        an HTTP error (>= 400), otherwise 500 — e.g. no response at all, or a
        2xx response whose body could not be processed.
    """
    if r is not None and r.status_code >= 400:
        code = r.status_code
    else:
        code = status.HTTP_500_INTERNAL_SERVER_ERROR
    return HTTPException(status_code=code, detail=ERROR_MESSAGES.DEFAULT(e))


@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
    """Return whether image generation is currently enabled (admin only)."""
    return app.state.ENABLED


@app.get("/enabled/toggle", response_model=bool)
def toggle_enabled(request: Request, user=Depends(get_admin_user)):
    """Flip the enabled flag, provided the AUTOMATIC1111 base URL is reachable.

    A HEAD request to the base URL is used purely as a reachability probe:
    its status code is not inspected, only whether the request succeeds.
    Declared ``def`` (not ``async def``) so FastAPI runs the blocking
    ``requests`` call in its threadpool instead of on the event loop.

    Returns:
        The new value of the enabled flag.

    Raises:
        HTTPException: 500 if the HEAD request raises.
    """
    try:
        requests.head(app.state.AUTOMATIC1111_BASE_URL)
    except Exception as e:
        # The request itself failed, so there is no response status to forward.
        raise _upstream_error(None, e) from e
    app.state.ENABLED = not app.state.ENABLED
    return app.state.ENABLED


class UrlUpdateForm(BaseModel):
    url: str


@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
    """Return the AUTOMATIC1111 base URL in use (despite the function's name)."""
    return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}


@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    """Set the AUTOMATIC1111 base URL (admin only).

    An empty string restores the value from ``config``. Trailing slashes are
    removed so endpoint paths can be appended as ``f"{base}/sdapi/..."``.

    NOTE(review): ``app.state.ENABLED`` is only derived from the URL at
    startup and is not recomputed here — confirm that is intended.
    """
    if form_data.url == "":
        app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
    else:
        app.state.AUTOMATIC1111_BASE_URL = form_data.url.rstrip("/")

    return {
        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
        "status": True,
    }


class ImageSizeUpdateForm(BaseModel):
    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the image size used by ``/generations`` (admin only)."""
    return {"IMAGE_SIZE": app.state.IMAGE_SIZE}


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Set the image size used by ``/generations`` (admin only).

    Raises:
        HTTPException: 400 unless ``size`` is exactly ``<digits>x<digits>``.
    """
    if not _IMAGE_SIZE_PATTERN.fullmatch(form_data.size):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
        )

    app.state.IMAGE_SIZE = form_data.size
    return {
        "IMAGE_SIZE": app.state.IMAGE_SIZE,
        "status": True,
    }


@app.get("/models")
def get_models(user=Depends(get_current_user)):
    """Return AUTOMATIC1111's ``/sdapi/v1/sd-models`` JSON unchanged.

    Raises:
        HTTPException: if the request or JSON decoding fails (upstream error
            status forwarded, otherwise 500).
    """
    r = None
    try:
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
        return r.json()
    except Exception as e:
        raise _upstream_error(r, e) from e


@app.get("/models/default")
def get_default_model(user=Depends(get_admin_user)):
    """Return the server's active checkpoint as ``{"model": ...}`` (admin only).

    Reads ``sd_model_checkpoint`` from ``/sdapi/v1/options``. Declared ``def``
    so the blocking ``requests`` call runs in FastAPI's threadpool.

    Raises:
        HTTPException: if the request, JSON decoding or key lookup fails
            (upstream error status forwarded, otherwise 500).
    """
    r = None
    try:
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
        options = r.json()
        return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        raise _upstream_error(r, e) from e


class UpdateModelForm(BaseModel):
    model: str


def set_model_handler(model: str):
    """Make ``model`` the AUTOMATIC1111 server's active checkpoint.

    Fetches ``/sdapi/v1/options`` and, only if ``sd_model_checkpoint`` differs
    from ``model``, POSTs the full options dict back with that key replaced.
    Exceptions from ``requests``, JSON decoding, or a missing key propagate.

    Returns:
        The fetched options dict, with ``sd_model_checkpoint`` set to ``model``.
    """
    base_url = app.state.AUTOMATIC1111_BASE_URL
    options = requests.get(url=f"{base_url}/sdapi/v1/options").json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        # NOTE(review): the POST response is not checked, so a rejected
        # checkpoint change is not reported to the caller.
        requests.post(url=f"{base_url}/sdapi/v1/options", json=options)

    return options


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_current_user),
):
    """Set the server's active checkpoint and return the resulting options.

    NOTE(review): guarded by ``get_current_user``, not ``get_admin_user`` —
    any signed-in user can change the checkpoint. Confirm this is intended.
    """
    return set_model_handler(form_data.model)


class GenerateImageForm(BaseModel):
    model: Optional[str] = None
    prompt: str
    n: int = 1
    # Accepted from clients but not read by generate_image, which uses the
    # admin-configured app.state.IMAGE_SIZE instead.
    size: str = "512x512"
    negative_prompt: Optional[str] = None


@app.post("/generations")
def generate_image(
    form_data: GenerateImageForm,
    user=Depends(get_current_user),
):
    """Run AUTOMATIC1111 txt2img and return its JSON response unchanged.

    If ``form_data.model`` is set it is first made the server's active
    checkpoint via ``set_model_handler``. Width/height come from
    ``app.state.IMAGE_SIZE``; ``n`` is sent as ``batch_size``.

    Raises:
        HTTPException: on any failure — upstream error status forwarded when
            a txt2img response exists, otherwise 500 (e.g. the server is
            unreachable or the checkpoint switch failed).
    """
    log.debug("generate_image request: %s", form_data)
    r = None
    try:
        if form_data.model:
            set_model_handler(form_data.model)

        width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))

        data = {
            "prompt": form_data.prompt,
            "batch_size": form_data.n,
            "width": width,
            "height": height,
        }
        if form_data.negative_prompt is not None:
            data["negative_prompt"] = form_data.negative_prompt

        log.debug("txt2img payload: %s", data)
        r = requests.post(
            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
            json=data,
        )
        return r.json()
    except Exception as e:
        log.exception("Image generation failed")
        raise _upstream_error(r, e) from e