feat: add getClientApi method

This commit is contained in:
Dogtiti 2024-07-06 11:27:53 +08:00
parent 2d1f522aaf
commit 5e0657ce55
4 changed files with 26 additions and 42 deletions

View File

@ -200,3 +200,14 @@ export function getHeaders() {
return headers;
}
export function getClientApi(provider: ServiceProvider): ClientApi {
switch (provider) {
case ServiceProvider.Google:
return new ClientApi(ModelProvider.GeminiPro);
case ServiceProvider.Anthropic:
return new ClientApi(ModelProvider.Claude);
default:
return new ClientApi(ModelProvider.GPT);
}
}

View File

@ -36,13 +36,9 @@ import { toBlob, toPng } from "html-to-image";
import { DEFAULT_MASK_AVATAR } from "../store/mask";
import { prettyObject } from "../utils/format";
import {
EXPORT_MESSAGE_CLASS_NAME,
ModelProvider,
ServiceProvider,
} from "../constant";
import { EXPORT_MESSAGE_CLASS_NAME } from "../constant";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { type ClientApi, getClientApi } from "../client/api";
import { getMessageTextContent } from "../utils";
const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
@ -316,14 +312,7 @@ export function PreviewActions(props: {
const onRenderMsgs = (msgs: ChatMessage[]) => {
setShouldExport(false);
var api: ClientApi;
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(config.modelConfig.providerName);
api
.share(msgs)

View File

@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg";
import { getCSSVar, useMobileScreen } from "../utils";
import dynamic from "next/dynamic";
import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant";
import { Path, SlotID } from "../constant";
import { ErrorBoundary } from "./error";
import { getISOLang, getLang } from "../locales";
@ -27,7 +27,7 @@ import { SideBar } from "./sidebar";
import { useAppConfig } from "../store/config";
import { AuthPage } from "./auth";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { type ClientApi, getClientApi } from "../client/api";
import { useAccessStore } from "../store";
export function Loading(props: { noLogo?: boolean }) {
@ -170,14 +170,8 @@ function Screen() {
export function useLoadData() {
const config = useAppConfig();
var api: ClientApi;
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(config.modelConfig.providerName);
useEffect(() => {
(async () => {
const models = await api.llm.models();

View File

@ -15,7 +15,12 @@ import {
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
} from "../constant";
import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
import { getClientApi } from "../client/api";
import type {
ClientApi,
RequestMessage,
MultimodalContent,
} from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
@ -363,15 +368,7 @@ export const useChatStore = createPersistStore(
]);
});
var api: ClientApi;
if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(modelConfig.providerName);
// make request
api.llm.chat({
messages: sendMessages,
@ -547,14 +544,7 @@ export const useChatStore = createPersistStore(
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;
var api: ClientApi;
if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(modelConfig.providerName);
// remove error messages if any
const messages = session.messages;