From ce85400817e91b55efd836074e3adc36c95afa76 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 22 Oct 2024 22:55:34 -0700 Subject: [PATCH] refac: feedback --- .../open_webui/apps/webui/models/feedbacks.py | 118 +++++++-- .../apps/webui/routers/evaluations.py | 89 +++++++ .../af906e964978_add_feedback_table.py | 6 + src/lib/apis/evaluations/index.ts | 152 ++++++++++++ src/lib/components/admin/Evaluations.svelte | 224 +++++++++++++----- .../admin/Evaluations/FeedbackMenu.svelte | 46 ++++ .../components/chat/Messages/Message.svelte | 1 + .../Messages/MultiResponseMessages.svelte | 1 + .../chat/Messages/ResponseMessage.svelte | 106 +++++++-- src/routes/(app)/admin/+page.svelte | 142 ++++++++--- 10 files changed, 753 insertions(+), 132 deletions(-) create mode 100644 src/lib/components/admin/Evaluations/FeedbackMenu.svelte diff --git a/backend/open_webui/apps/webui/models/feedbacks.py b/backend/open_webui/apps/webui/models/feedbacks.py index 76410f417..7255b48fd 100644 --- a/backend/open_webui/apps/webui/models/feedbacks.py +++ b/backend/open_webui/apps/webui/models/feedbacks.py @@ -23,9 +23,11 @@ class Feedback(Base): __tablename__ = "feedback" id = Column(Text, primary_key=True) user_id = Column(Text) + version = Column(BigInteger, default=0) type = Column(Text) data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) + snapshot = Column(JSON, nullable=True) created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -33,9 +35,11 @@ class Feedback(Base): class FeedbackModel(BaseModel): id: str user_id: str + version: int type: str data: Optional[dict] = None meta: Optional[dict] = None + snapshot: Optional[dict] = None created_at: int updated_at: int @@ -47,30 +51,44 @@ class FeedbackModel(BaseModel): #################### +class FeedbackResponse(BaseModel): + id: str + user_id: str + version: int + type: str + data: Optional[dict] = None + meta: Optional[dict] = None + created_at: int + updated_at: int + + class RatingData(BaseModel): - rating: str - comment: str - model_config = ConfigDict(extra="allow") - - -class VoteData(BaseModel): - rating: str - model_id: str - model_ids: list[str] + rating: Optional[str | int] = None + model_id: Optional[str] = None + sibling_model_ids: Optional[list[str]] = None + reason: Optional[str] = None + comment: Optional[str] = None model_config = ConfigDict(extra="allow") class MetaData(BaseModel): - chat: Optional[dict] = None + arena: Optional[bool] = None + chat_id: Optional[str] = None message_id: Optional[str] = None tags: Optional[list[str]] = None model_config = ConfigDict(extra="allow") +class SnapshotData(BaseModel): + chat: Optional[dict] = None + model_config = ConfigDict(extra="allow") + + class FeedbackForm(BaseModel): type: str - data: Optional[RatingData | VoteData] = None + data: Optional[RatingData] = None meta: Optional[dict] = None + snapshot: Optional[SnapshotData] = None model_config = ConfigDict(extra="allow") @@ -84,10 +102,10 @@ class FeedbackTable: **{ "id": id, "user_id": user_id, - "type": form_data.type, - "data": form_data.data, - "meta": form_data.meta, + "version": 0, + **form_data.model_dump(), "created_at": int(time.time()), + "updated_at": int(time.time()), } ) try: @@ -113,6 +131,25 @@ class FeedbackTable: except Exception: return None + def get_feedback_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[FeedbackModel]: + try: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + if not feedback: + return None + return FeedbackModel.model_validate(feedback) + except Exception: + return None + + def get_all_feedbacks(self) -> list[FeedbackModel]: + with get_db() as db: + return [ + FeedbackModel.model_validate(feedback) + for feedback in db.query(Feedback).all() + ] + def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]: with get_db() as db: return [ @@ -136,9 +173,31 @@ class FeedbackTable: return None if form_data.data: - feedback.data = form_data.data + feedback.data = form_data.data.model_dump() if form_data.meta: feedback.meta = form_data.meta + if form_data.snapshot: + feedback.snapshot = form_data.snapshot.model_dump() + + feedback.updated_at = int(time.time()) + + db.commit() + return FeedbackModel.model_validate(feedback) + + def update_feedback_by_id_and_user_id( + self, id: str, user_id: str, form_data: FeedbackForm + ) -> Optional[FeedbackModel]: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + if not feedback: + return None + + if form_data.data: + feedback.data = form_data.data.model_dump() + if form_data.meta: + feedback.meta = form_data.meta + if form_data.snapshot: + feedback.snapshot = form_data.snapshot.model_dump() feedback.updated_at = int(time.time()) @@ -154,5 +213,34 @@ class FeedbackTable: db.commit() return True + def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool: + with get_db() as db: + feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() + if not feedback: + return False + db.delete(feedback) + db.commit() + return True + + def delete_feedbacks_by_user_id(self, user_id: str) -> bool: + with get_db() as db: + feedbacks = db.query(Feedback).filter_by(user_id=user_id).all() + if not feedbacks: + return False + for feedback in feedbacks: + db.delete(feedback) + db.commit() + return True + + def delete_all_feedbacks(self) -> bool: + with get_db() as db: + feedbacks = db.query(Feedback).all() + if not feedbacks: + return False + for feedback in feedbacks: + db.delete(feedback) + db.commit() + return True + Feedbacks = FeedbackTable() diff --git a/backend/open_webui/apps/webui/routers/evaluations.py b/backend/open_webui/apps/webui/routers/evaluations.py index b40953e49..f0e9236dc 100644 --- a/backend/open_webui/apps/webui/routers/evaluations.py +++ b/backend/open_webui/apps/webui/routers/evaluations.py @@ -3,6 +3,12 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request from pydantic import BaseModel +from open_webui.apps.webui.models.feedbacks import ( + FeedbackModel, + FeedbackForm, + Feedbacks, +) + from open_webui.constants import ERROR_MESSAGES from open_webui.utils.utils import get_admin_user, get_verified_user @@ -47,3 +53,86 @@ async def update_config( "ENABLE_EVALUATION_ARENA_MODELS": config.ENABLE_EVALUATION_ARENA_MODELS, "EVALUATION_ARENA_MODELS": config.EVALUATION_ARENA_MODELS, } + + +@router.get("/feedbacks", response_model=list[FeedbackModel]) +async def get_feedbacks(user=Depends(get_verified_user)): + feedbacks = Feedbacks.get_feedbacks_by_user_id(user.id) + return feedbacks + + +@router.delete("/feedbacks", response_model=bool) +async def delete_feedbacks(user=Depends(get_verified_user)): + success = Feedbacks.delete_feedbacks_by_user_id(user.id) + return success + + +@router.delete("/feedbacks/all") +async def delete_all_feedbacks(user=Depends(get_admin_user)): + success = Feedbacks.delete_all_feedbacks() + return success + + +@router.get("/feedbacks/all", response_model=list[FeedbackModel]) +async def get_all_feedbacks(user=Depends(get_admin_user)): + feedbacks = Feedbacks.get_all_feedbacks() + return feedbacks + + +@router.post("/feedback", response_model=FeedbackModel) +async def create_feedback( + request: Request, + form_data: FeedbackForm, + user=Depends(get_verified_user), +): + feedback = Feedbacks.insert_new_feedback(user_id=user.id, form_data=form_data) + if not feedback: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + return feedback + + +@router.get("/feedback/{id}", response_model=FeedbackModel) +async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): + feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) + + if not feedback: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return feedback + + +@router.post("/feedback/{id}", response_model=FeedbackModel) +async def update_feedback_by_id( + id: str, form_data: FeedbackForm, user=Depends(get_verified_user) +): + feedback = Feedbacks.update_feedback_by_id_and_user_id( + id=id, user_id=user.id, form_data=form_data + ) + + if not feedback: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return feedback + + +@router.delete("/feedback/{id}") +async def delete_feedback_by_id(id: str, user=Depends(get_verified_user)): + if user.role == "admin": + success = Feedbacks.delete_feedback_by_id(id=id) + else: + success = Feedbacks.delete_feedback_by_id_and_user_id(id=id, user_id=user.id) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + return success diff --git a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py index 8119c9396..9116aa388 100644 --- a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py +++ b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py @@ -26,11 +26,17 @@ def upgrade(): sa.Column( "user_id", sa.Text(), nullable=True ), # ID of the user providing the feedback (TEXT type) + sa.Column( + "version", sa.BigInteger(), default=0 + ), # Version of feedback (BIGINT type) sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type) sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type) sa.Column( "meta", sa.JSON(), nullable=True ), # Metadata for feedback (JSON type) + sa.Column( + "snapshot", sa.JSON(), nullable=True + ), # snapshot data for feedback (JSON type) sa.Column( "created_at", sa.BigInteger(), nullable=False ), # Feedback creation timestamp (BIGINT representing epoch) diff --git a/src/lib/apis/evaluations/index.ts b/src/lib/apis/evaluations/index.ts index 21130bd17..854b3abb8 100644 --- a/src/lib/apis/evaluations/index.ts +++ b/src/lib/apis/evaluations/index.ts @@ -61,3 +61,155 @@ export const updateConfig = async (token: string, config: object) => { return res; }; + +export const getAllFeedbacks = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedbacks/all`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const createNewFeedback = async (token: string, feedback: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...feedback + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getFeedbackById = async (token: string, feedbackId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateFeedbackById = async (token: string, feedbackId: string, feedback: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...feedback + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteFeedbackById = async (token: string, feedbackId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/evaluations/feedback/${feedbackId}`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Evaluations.svelte b/src/lib/components/admin/Evaluations.svelte index 7fe3c5756..c9b678ffa 100644 --- a/src/lib/components/admin/Evaluations.svelte +++ b/src/lib/components/admin/Evaluations.svelte @@ -2,19 +2,24 @@ import { onMount, getContext } from 'svelte'; import { models } from '$lib/stores'; + import GarbageBin from '../icons/GarbageBin.svelte'; + import FeedbackMenu from './Evaluations/FeedbackMenu.svelte'; + import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte'; + import { getAllFeedbacks } from '$lib/apis/evaluations'; const i18n = getContext('i18n'); let rankedModels = []; + let feedbacks = []; + let loaded = false; - - onMount(() => { - loaded = true; - + onMount(async () => { + feedbacks = await getAllFeedbacks(localStorage.token); rankedModels = $models .filter((m) => m?.owned_by !== 'arena' && (m?.info?.meta?.hidden ?? false) !== true) .map((model) => { return { ...model, + ranking: '-', rating: '-', stats: { won: '-', @@ -34,11 +39,13 @@ // If both ratings are '-', sort alphabetically (by 'name') return a.name.localeCompare(b.name); }); + + loaded = true; }); {#if loaded} -
+
{$i18n.t('Leaderboard')} @@ -52,75 +59,164 @@
- - + {$i18n.t('No models found')} + + {:else} +
- - - - + + + + + + - - - - - - {#each rankedModels as model (model.id)} - - - - - - - - - + + - {/each} - -
- {$i18n.t('Model')} - - {$i18n.t('Rating')} - - {$i18n.t('Won')} -
+ {$i18n.t('RK')} + + {$i18n.t('Model')} + + {$i18n.t('Rating')} + + {$i18n.t('Won')} + - {$i18n.t('Draw')} - - {$i18n.t('Lost')} -
-
-
- {model.name} -
- -
- {model.name} -
-
-
- {model.rating} - {model.stats.won} - {model.stats.draw} - - {model.stats.lost} - + {$i18n.t('Draw')} + + {$i18n.t('Lost')} +
+ + + {#each rankedModels as model (model.id)} + + +
+ {model.ranking} +
+ + +
+
+ {model.name} +
+ +
+ {model.name} +
+
+ + + {model.rating} + + + + {model.stats.won} + + + + {model.stats.draw} + + + + {model.stats.lost} + + + {/each} + + + {/if}
-
+
- {$i18n.t('Rating History')} + {$i18n.t('Feedback History')}
+
+ {#if (feedbacks ?? []).length === 0} +
+ {$i18n.t('No feedbacks found')} +
+ {:else} + + + + + + + + + + + + + + + + {#each feedbacks as feedback (feedback.id)} + + + + + + + + + + + {/each} + +
+ {$i18n.t('Models')} + + {$i18n.t('Result')} + + {$i18n.t('User')} + + {$i18n.t('Created At')} +
+
+
+ {model.name} +
+
+ {model.name} +
+
+
+ {model.rating} + {model.stats.won} + {model.stats.draw} + + + + +
+ {/if} +
+
{/if} diff --git a/src/lib/components/admin/Evaluations/FeedbackMenu.svelte b/src/lib/components/admin/Evaluations/FeedbackMenu.svelte new file mode 100644 index 000000000..83defd804 --- /dev/null +++ b/src/lib/components/admin/Evaluations/FeedbackMenu.svelte @@ -0,0 +1,46 @@ + + + {}}> + + + + +
+ + { + dispatch('delete'); + show = false; + }} + > + +
{$i18n.t('Delete')}
+
+
+
+
diff --git a/src/lib/components/chat/Messages/Message.svelte b/src/lib/components/chat/Messages/Message.svelte index 742658390..283af001a 100644 --- a/src/lib/components/chat/Messages/Message.svelte +++ b/src/lib/components/chat/Messages/Message.svelte @@ -66,6 +66,7 @@ /> {:else if (history.messages[history.messages[messageId].parentId]?.models?.length ?? 1) === 1} { + console.log('Feedback', rating, annotation); + + const updatedMessage = { + ...message, + annotation: { + ...(message?.annotation ?? {}), + ...(rating !== null ? { rating: rating } : {}), + ...(annotation ? annotation : {}) + } + }; + + const chat = await getChatById(localStorage.token, chatId).catch((error) => { + toast.error(error); + }); + if (!chat) { + return; + } + + let feedbackItem = { + type: 'rating', + data: { + ...(updatedMessage?.annotation ? updatedMessage.annotation : {}), + model_id: message?.selectedModelId ?? message.model, + ...(history.messages[message.parentId].childrenIds.length > 1 + ? { + sibling_model_ids: history.messages[message.parentId].childrenIds + .filter((id) => id !== message.id) + .map((id) => history.messages[id]?.selectedModelId ?? history.messages[id].model) + } + : {}) + }, + meta: { + arena: message ? message.arena : false, + message_id: message.id, + chat_id: chatId + }, + snapshot: { + chat: chat + } + }; + + let feedback = null; + if (message?.feedbackId) { + feedback = await updateFeedbackById( + localStorage.token, + message.feedbackId, + feedbackItem + ).catch((error) => { + toast.error(error); + }); + } else { + feedback = await createNewFeedback(localStorage.token, feedbackItem).catch((error) => { + toast.error(error); + }); + + if (feedback) { + updatedMessage.feedbackId = feedback.id; + } + } + + console.log(updatedMessage); + dispatch('save', updatedMessage); + + await tick(); + + if (!annotation) { + showRateComment = true; + } + }; + $: if (!edit) { (async () => { await tick(); @@ -880,12 +957,13 @@