diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index bb211cfe4..b639d949c 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -1,4 +1,4 @@ -from flask import Flask, request, Response +from flask import Flask, request, Response, jsonify from flask_cors import CORS @@ -6,7 +6,10 @@ import requests import json -from config import OLLAMA_API_BASE_URL +from apps.web.models.users import Users +from constants import ERROR_MESSAGES +from utils import extract_token_from_auth_header +from config import OLLAMA_API_BASE_URL, OLLAMA_WEBUI_AUTH app = Flask(__name__) CORS( @@ -28,6 +31,21 @@ def proxy(path): data = request.get_data() headers = dict(request.headers) + if OLLAMA_WEBUI_AUTH: + if "Authorization" in headers: + token = extract_token_from_auth_header(headers["Authorization"]) + user = Users.get_user_by_token(token) + if user: + print(user) + pass + else: + return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 + else: + return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401 + + else: + pass + # Make a request to the target server target_response = requests.request( method=request.method, diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py new file mode 100644 index 000000000..27b857888 --- /dev/null +++ b/backend/apps/web/main.py @@ -0,0 +1,25 @@ +from fastapi import FastAPI, Request, Depends, HTTPException +from fastapi.middleware.cors import CORSMiddleware + +from apps.web.routers import auths +from config import OLLAMA_WEBUI_VERSION, OLLAMA_WEBUI_AUTH + +app = FastAPI() + +origins = ["*"] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +app.include_router(auths.router, prefix="/auths", tags=["auths"]) + + +@app.get("/") +async def get_status(): + return {"status": True, "version": OLLAMA_WEBUI_VERSION, "auth": OLLAMA_WEBUI_AUTH} diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py new file mode 100644 index 000000000..f00a09cb8 --- /dev/null +++ b/backend/apps/web/models/auths.py @@ -0,0 +1,102 @@ +from pydantic import BaseModel +from typing import List, Union, Optional +import time +import uuid + + +from apps.web.models.users import UserModel, Users +from utils import ( + verify_password, + get_password_hash, + bearer_scheme, + create_token, +) + +import config + +DB = config.DB + +#################### +# DB MODEL +#################### + + +class AuthModel(BaseModel): + id: str + email: str + password: str + active: bool = True + + +#################### +# Forms +#################### + + +class Token(BaseModel): + token: str + token_type: str + + +class UserResponse(BaseModel): + id: str + email: str + name: str + role: str + + +class SigninResponse(Token, UserResponse): + pass + + +class SigninForm(BaseModel): + email: str + password: str + + +class SignupForm(BaseModel): + name: str + email: str + password: str + + +class AuthsTable: + def __init__(self, db): + self.db = db + self.table = db.auths + + def insert_new_auth( + self, email: str, password: str, name: str, role: str = "user" + ) -> Optional[UserModel]: + print("insert_new_auth") + + id = str(uuid.uuid4()) + + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = self.table.insert_one(auth.model_dump()) + user = Users.insert_new_user(id, name, email, role) + + print(result, user) + if result and user: + return user + else: + return None + + def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: + print("authenticate_user") + + auth = self.table.find_one({"email": email, "active": True}) + + if auth: + if verify_password(password, auth["password"]): + user = self.db.users.find_one({"id": auth["id"]}) + return UserModel(**user) + else: + return None + else: + return None + + +Auths = AuthsTable(DB) diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py new file mode 100644 index 000000000..0ea682140 --- /dev/null +++ b/backend/apps/web/models/users.py @@ -0,0 +1,76 @@ +from pydantic import BaseModel +from typing import List, Union, Optional +from pymongo import ReturnDocument +import time + +from utils import decode_token +from config import DB + +#################### +# User DB Schema +#################### + + +class UserModel(BaseModel): + id: str + name: str + email: str + role: str = "user" + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class UsersTable: + def __init__(self, db): + self.db = db + self.table = db.users + + def insert_new_user( + self, id: str, name: str, email: str, role: str = "user" + ) -> Optional[UserModel]: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "created_at": int(time.time()), + } + ) + result = self.table.insert_one(user.model_dump()) + + if result: + return user + else: + return None + + def get_user_by_email(self, email: str) -> Optional[UserModel]: + user = self.table.find_one({"email": email}, {"_id": False}) + + if user: + return UserModel(**user) + else: + return None + + def get_user_by_token(self, token: str) -> Optional[UserModel]: + data = decode_token(token) + + if data != None and "email" in data: + return self.get_user_by_email(data["email"]) + else: + return None + + def get_users(self, skip: int = 0, limit: int = 50) -> Optional[UserModel]: + return [ + UserModel(**user) + for user in list(self.table.find({}, {"_id": False})) + .skip(skip) + .limit(limit) + ] + + +Users = UsersTable(DB) diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py new file mode 100644 index 000000000..9ed7b187d --- /dev/null +++ b/backend/apps/web/routers/auths.py @@ -0,0 +1,107 @@ +from fastapi import Response +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union + +from fastapi import APIRouter +from pydantic import BaseModel +import time +import uuid + +from constants import ERROR_MESSAGES +from utils import ( + get_password_hash, + bearer_scheme, + create_token, +) + +from apps.web.models.auths import ( + SigninForm, + SignupForm, + UserResponse, + SigninResponse, + Auths, +) +from apps.web.models.users import Users +import config + +router = APIRouter() + +DB = config.DB + + +############################ +# GetSessionUser +############################ + + +@router.get("/", response_model=UserResponse) +async def get_session_user(cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + if user: + return { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + ) + + +############################ +# SignIn +############################ + + +@router.post("/signin", response_model=SigninResponse) +async def signin(form_data: SigninForm): + user = Auths.authenticate_user(form_data.email.lower(), form_data.password) + if user: + token = create_token(data={"email": user.email}) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + else: + raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT()) + + +############################ +# SignUp +############################ + + +@router.post("/signup", response_model=SigninResponse) +async def signup(form_data: SignupForm): + if not Users.get_user_by_email(form_data.email.lower()): + try: + hashed = get_password_hash(form_data.password) + user = Auths.insert_new_auth(form_data.email, hashed, form_data.name) + + if user: + token = create_token(data={"email": user.email}) + # response.set_cookie(key='token', value=token, httponly=True) + + return { + "token": token, + "token_type": "Bearer", + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + else: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + except Exception as err: + raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT()) diff --git a/backend/config.py b/backend/config.py index 2a33818d9..3b67ef7f8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,11 +1,22 @@ -import sys -import os from dotenv import load_dotenv, find_dotenv +from pymongo import MongoClient + +from secrets import token_bytes +from base64 import b64encode +import os load_dotenv(find_dotenv()) +#################################### +# ENV (dev,test,prod) +#################################### + ENV = os.environ.get("ENV", "dev") +#################################### +# OLLAMA_API_BASE_URL +#################################### + OLLAMA_API_BASE_URL = os.environ.get( "OLLAMA_API_BASE_URL", "http://localhost:11434/api" ) @@ -13,3 +24,42 @@ OLLAMA_API_BASE_URL = os.environ.get( if ENV == "prod": if OLLAMA_API_BASE_URL == "/ollama/api": OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" + +#################################### +# OLLAMA_WEBUI_VERSION +#################################### + +OLLAMA_WEBUI_VERSION = os.environ.get("OLLAMA_WEBUI_VERSION", "v1.0.0-alpha.9") + +#################################### +# OLLAMA_WEBUI_AUTH +#################################### + +OLLAMA_WEBUI_AUTH = ( + True if os.environ.get("OLLAMA_WEBUI_AUTH", "TRUE") == "TRUE" else False +) + + +if OLLAMA_WEBUI_AUTH: + #################################### + # OLLAMA_WEBUI_DB + #################################### + + OLLAMA_WEBUI_DB_URL = os.environ.get( + "OLLAMA_WEBUI_DB_URL", "mongodb://root:root@localhost:27017/" + ) + + DB_CLIENT = MongoClient(f"{OLLAMA_WEBUI_DB_URL}?authSource=admin") + DB = DB_CLIENT["ollama-webui"] + + #################################### + # OLLAMA_WEBUI_JWT_SECRET_KEY + #################################### + + OLLAMA_WEBUI_JWT_SECRET_KEY = os.environ.get( + "OLLAMA_WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t" + ) + + if ENV == "prod": + if OLLAMA_WEBUI_JWT_SECRET_KEY == "": + OLLAMA_WEBUI_JWT_SECRET_KEY = str(b64encode(token_bytes(32)).decode()) diff --git a/backend/constants.py b/backend/constants.py new file mode 100644 index 000000000..62103c305 --- /dev/null +++ b/backend/constants.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class MESSAGES(str, Enum): + DEFAULT = lambda msg="": f"{msg if msg else ''}" + + +class ERROR_MESSAGES(str, Enum): + DEFAULT = lambda err="": f"Something went wrong :/\n{err if err else ''}" + UNAUTHORIZED = "401 Unauthorized" + NOT_FOUND = "We could not find what you're looking for :/" + USER_NOT_FOUND = "We could not find what you're looking for :/" + MALICIOUS = "Unusual activities detected, please try again in a few minutes." diff --git a/backend/main.py b/backend/main.py index 2851df20f..24bad0c92 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,16 +1,14 @@ -import time -import sys - from fastapi import FastAPI, Request from fastapi.staticfiles import StaticFiles - from fastapi import HTTPException -from starlette.exceptions import HTTPException as StarletteHTTPException - from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware +from starlette.exceptions import HTTPException as StarletteHTTPException from apps.ollama.main import app as ollama_app +from apps.web.main import app as webui_app + +import time class SPAStaticFiles(StaticFiles): @@ -47,5 +45,6 @@ async def check_url(request: Request, call_next): return response +app.mount("/api/v1", webui_app) app.mount("/ollama/api", WSGIMiddleware(ollama_app)) app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") diff --git a/backend/utils.py b/backend/utils.py new file mode 100644 index 000000000..96f40ddc3 --- /dev/null +++ b/backend/utils.py @@ -0,0 +1,68 @@ +from fastapi.security import HTTPBasicCredentials, HTTPBearer +from pydantic import BaseModel +from typing import Union, Optional + +from passlib.context import CryptContext +from datetime import datetime, timedelta +import requests +import jwt + +import config + +JWT_SECRET_KEY = config.OLLAMA_WEBUI_JWT_SECRET_KEY +ALGORITHM = "HS256" + +############## +# Auth Utils +############## + +bearer_scheme = HTTPBearer() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def verify_password(plain_password, hashed_password): + return ( + pwd_context.verify(plain_password, hashed_password) if hashed_password else None + ) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str: + payload = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + payload.update({"exp": expire}) + + encoded_jwt = jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_token(token: str) -> Optional[dict]: + try: + decoded = jwt.decode(token, JWT_SECRET_KEY, options={"verify_signature": False}) + return decoded + except Exception as e: + return None + + +def extract_token_from_auth_header(auth_header: str): + return auth_header[len("Bearer ") :] + + +def verify_token(request): + try: + bearer = request.headers["authorization"] + if bearer: + token = bearer[len("Bearer ") :] + decoded = jwt.decode( + token, JWT_SECRET_KEY, options={"verify_signature": False} + ) + return decoded + else: + return None + except Exception as e: + return None diff --git a/compose.yaml b/compose.yaml index b50363542..1c1db1ada 100644 --- a/compose.yaml +++ b/compose.yaml @@ -22,6 +22,15 @@ services: restart: unless-stopped image: ollama/ollama:latest + ollama-webui-db: + image: mongo + container_name: ollama-webui-db + restart: always + # Make sure to change the username/password! + environment: + MONGO_INITDB_ROOT_USERNAME: root + MONGO_INITDB_ROOT_PASSWORD: example + ollama-webui: build: context: . @@ -32,10 +41,12 @@ services: container_name: ollama-webui depends_on: - ollama + - ollama-webui-db ports: - 3000:8080 environment: - "OLLAMA_API_BASE_URL=http://ollama:11434/api" + - "OLLAMA_WEBUI_DB_URL=mongodb://root:example@ollama-webui-db:27017/" extra_hosts: - host.docker.internal:host-gateway restart: unless-stopped diff --git a/src/app.css b/src/app.css index 7afbeafe8..27d2d7ef5 100644 --- a/src/app.css +++ b/src/app.css @@ -4,8 +4,13 @@ font-display: swap; } +@font-face { + font-family: 'Mona Sans'; + src: url('/assets/fonts/Mona-Sans.woff2'); + font-display: swap; +} + html { - @apply bg-gray-800; word-break: break-word; } diff --git a/src/lib/components/chat/SettingsModal.svelte b/src/lib/components/chat/SettingsModal.svelte index 80ee7fde3..61916e8e2 100644 --- a/src/lib/components/chat/SettingsModal.svelte +++ b/src/lib/components/chat/SettingsModal.svelte @@ -2,9 +2,10 @@ import sha256 from 'js-sha256'; import Modal from '../common/Modal.svelte'; - import { WEB_UI_VERSION, API_BASE_URL as BUILD_TIME_API_BASE_URL } from '$lib/constants'; + import { WEB_UI_VERSION, OLLAMA_API_BASE_URL as BUILD_TIME_API_BASE_URL } from '$lib/constants'; import toast from 'svelte-french-toast'; import { onMount } from 'svelte'; + import { config, user } from '$lib/stores'; export let show = false; export let saveSettings: Function; @@ -119,7 +120,8 @@ const res = await fetch(`${API_BASE_URL}/pull`, { method: 'POST', headers: { - 'Content-Type': 'text/event-stream' + 'Content-Type': 'text/event-stream', + ...($user && { Authorization: `Bearer ${localStorage.token}` }) }, body: JSON.stringify({ name: modelTag @@ -175,7 +177,8 @@ const res = await fetch(`${API_BASE_URL}/delete`, { method: 'DELETE', headers: { - 'Content-Type': 'text/event-stream' + 'Content-Type': 'text/event-stream', + ...($user && { Authorization: `Bearer ${localStorage.token}` }) }, body: JSON.stringify({ name: deleteModelTag @@ -992,7 +995,7 @@
Ollama Web UI Version
- {WEB_UI_VERSION} + {$config && $config.version ? $config.version : WEB_UI_VERSION}
diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index 4db018d3d..b7a6f4f07 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -1,4 +1,6 @@ + +{#if $config !== undefined} + +{/if} diff --git a/src/routes/+page.svelte b/src/routes/(app)/+page.svelte similarity index 98% rename from src/routes/+page.svelte rename to src/routes/(app)/+page.svelte index 56be3a1f1..d91ed5858 100644 --- a/src/routes/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -10,17 +10,17 @@ import 'katex/dist/katex.min.css'; import toast from 'svelte-french-toast'; - import { API_BASE_URL as BUILD_TIME_API_BASE_URL } from '$lib/constants'; + import { OLLAMA_API_BASE_URL as BUILD_TIME_API_BASE_URL } from '$lib/constants'; import { onMount, tick } from 'svelte'; import Navbar from '$lib/components/layout/Navbar.svelte'; import SettingsModal from '$lib/components/chat/SettingsModal.svelte'; import Suggestions from '$lib/components/chat/Suggestions.svelte'; + import { user } from '$lib/stores'; let API_BASE_URL = BUILD_TIME_API_BASE_URL; let db; - // let selectedModel = ''; let selectedModels = ['']; let settings = { system: null, @@ -619,7 +619,8 @@ headers: { Accept: 'application/json', 'Content-Type': 'application/json', - ...(settings.authHeader && { Authorization: settings.authHeader }) + ...(settings.authHeader && { Authorization: settings.authHeader }), + ...($user && { Authorization: `Bearer ${localStorage.token}` }) } }) .then(async (res) => { @@ -628,7 +629,11 @@ }) .catch((error) => { console.log(error); - toast.error('Server connection failed'); + if ('detail' in error) { + toast.error(error.detail); + } else { + toast.error('Server connection failed'); + } return null; }); @@ -687,13 +692,6 @@ } }) ); - - // if (selectedModel.includes('gpt-')) { - // await sendPromptOpenAI(userPrompt, parentId); - // } else { - // await sendPromptOllama(userPrompt, parentId); - // } - console.log(history); }; @@ -724,7 +722,8 @@ method: 'POST', headers: { 'Content-Type': 'text/event-stream', - ...(settings.authHeader && { Authorization: settings.authHeader }) + ...(settings.authHeader && { Authorization: settings.authHeader }), + ...($user && { Authorization: `Bearer ${localStorage.token}` }) }, body: JSON.stringify({ model: model, @@ -779,6 +778,8 @@ responseMessage.content += data.response; messages = messages; } + } else if ('detail' in data) { + throw data; } else { responseMessage.done = true; responseMessage.context = data.context; @@ -791,6 +792,10 @@ } } catch (error) { console.log(error); + if ('detail' in error) { + toast.error(error.detail); + } + break; } if (autoScroll) { @@ -817,7 +822,7 @@ window.scrollTo({ top: document.body.scrollHeight }); } - if (messages.length == 2) { + if (messages.length == 2 && messages.at(1).content !== '') { await generateChatTitle(chatId, userPrompt); } }; @@ -1034,7 +1039,8 @@ method: 'POST', headers: { 'Content-Type': 'text/event-stream', - ...(settings.authHeader && { Authorization: settings.authHeader }) + ...(settings.authHeader && { Authorization: settings.authHeader }), + ...($user && { Authorization: `Bearer ${localStorage.token}` }) }, body: JSON.stringify({ model: selectedModels[0], @@ -1047,6 +1053,9 @@ return res.json(); }) .catch((error) => { + if ('detail' in error) { + toast.error(error.detail); + } console.log(error); return null; }); diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte new file mode 100644 index 000000000..e69de29bb diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index e7678f9f4..744909fc9 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -1,13 +1,71 @@ Ollama - - + +{#if $config !== undefined && loaded} + +{/if} diff --git a/src/routes/auth/+page.svelte b/src/routes/auth/+page.svelte new file mode 100644 index 000000000..d9ebf381a --- /dev/null +++ b/src/routes/auth/+page.svelte @@ -0,0 +1,1091 @@ + + +{#if $config && $config.auth} +
+
+
+ +
+
+
+ +
+ + +
+
+
{ + if (mode === 'signin') { + signInHandler(); + } else { + signUpHandler(); + } + }} + > +
+ {mode === 'signin' ? 'Sign in' : 'Sign up'} to Ollama Web UI +
+ +
+ +
+ {#if mode === 'signup'} +
+
Name
+ +
+ {/if} + +
+
Email
+ +
+ +
+
Password
+ +
+
+ +
+ + +
+ {mode === 'signin' ? `Don't have an account?` : `Already have an account?`} + + +
+
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +{/if} diff --git a/static/assets/fonts/Mona-Sans.woff2 b/static/assets/fonts/Mona-Sans.woff2 new file mode 100644 index 000000000..d88d5ff27 Binary files /dev/null and b/static/assets/fonts/Mona-Sans.woff2 differ