From 5e0657ce556bbf04cce22bb451ff9349def6b04b Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Sat, 6 Jul 2024 11:27:53 +0800 Subject: [PATCH] feat: add getClientApi method --- app/client/api.ts | 11 +++++++++++ app/components/exporter.tsx | 17 +++-------------- app/components/home.tsx | 14 ++++---------- app/store/chat.ts | 26 ++++++++------------------ 4 files changed, 26 insertions(+), 42 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index 41ccbd8e1..cd9d72c15 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -200,3 +200,14 @@ export function getHeaders() { return headers; } + +export function getClientApi(provider: ServiceProvider): ClientApi { + switch (provider) { + case ServiceProvider.Google: + return new ClientApi(ModelProvider.GeminiPro); + case ServiceProvider.Anthropic: + return new ClientApi(ModelProvider.Claude); + default: + return new ClientApi(ModelProvider.GPT); + } +} diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 7281fc2f1..948807d4c 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -36,13 +36,9 @@ import { toBlob, toPng } from "html-to-image"; import { DEFAULT_MASK_AVATAR } from "../store/mask"; import { prettyObject } from "../utils/format"; -import { - EXPORT_MESSAGE_CLASS_NAME, - ModelProvider, - ServiceProvider, -} from "../constant"; +import { EXPORT_MESSAGE_CLASS_NAME } from "../constant"; import { getClientConfig } from "../config/client"; -import { ClientApi } from "../client/api"; +import { type ClientApi, getClientApi } from "../client/api"; import { getMessageTextContent } from "../utils"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { @@ -316,14 +312,7 @@ export function PreviewActions(props: { const onRenderMsgs = (msgs: ChatMessage[]) => { setShouldExport(false); - var api: ClientApi; - if (config.modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(config.modelConfig.providerName); api .share(msgs) diff --git a/app/components/home.tsx b/app/components/home.tsx index addb5e803..e127c65f8 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg"; import { getCSSVar, useMobileScreen } from "../utils"; import dynamic from "next/dynamic"; -import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant"; +import { Path, SlotID } from "../constant"; import { ErrorBoundary } from "./error"; import { getISOLang, getLang } from "../locales"; @@ -27,7 +27,7 @@ import { SideBar } from "./sidebar"; import { useAppConfig } from "../store/config"; import { AuthPage } from "./auth"; import { getClientConfig } from "../config/client"; -import { ClientApi } from "../client/api"; +import { type ClientApi, getClientApi } from "../client/api"; import { useAccessStore } from "../store"; export function Loading(props: { noLogo?: boolean }) { @@ -170,14 +170,8 @@ function Screen() { export function useLoadData() { const config = useAppConfig(); - var api: ClientApi; - if (config.modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(config.modelConfig.providerName); + useEffect(() => { (async () => { const models = await api.llm.models(); diff --git a/app/store/chat.ts b/app/store/chat.ts index 44d41830a..d14bd82d8 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -15,7 +15,12 @@ import { SUMMARIZE_MODEL, GEMINI_SUMMARIZE_MODEL, } from "../constant"; -import { ClientApi, RequestMessage, MultimodalContent } from "../client/api"; +import { getClientApi } from "../client/api"; +import type { + ClientApi, + RequestMessage, + MultimodalContent, +} from "../client/api"; import { ChatControllerPool } from "../client/controller"; import { prettyObject } from "../utils/format"; import { estimateTokenLength } from "../utils/token"; @@ -363,15 +368,7 @@ export const useChatStore = createPersistStore( ]); }); - var api: ClientApi; - if (modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else { - api = new ClientApi(ModelProvider.GPT); - } - + const api: ClientApi = getClientApi(modelConfig.providerName); // make request api.llm.chat({ messages: sendMessages, @@ -547,14 +544,7 @@ export const useChatStore = createPersistStore( const session = get().currentSession(); const modelConfig = session.mask.modelConfig; - var api: ClientApi; - if (modelConfig.providerName == ServiceProvider.Google) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (modelConfig.providerName == ServiceProvider.Anthropic) { - api = new ClientApi(ModelProvider.Claude); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(modelConfig.providerName); // remove error messages if any const messages = session.messages;