Merge pull request #1405 from frostming/patch-1

feat: fallback to openai compatible provider if url host doesn't match
This commit is contained in:
Mauricio Siu
2025-03-07 00:56:33 -06:00
committed by GitHub
3 changed files with 117 additions and 46 deletions

View File

@@ -45,21 +45,12 @@ const Schema = z.object({
type Schema = z.infer<typeof Schema>; type Schema = z.infer<typeof Schema>;
interface Model {
id: string;
object: string;
created: number;
owned_by: string;
}
interface Props { interface Props {
aiId?: string; aiId?: string;
} }
export const HandleAi = ({ aiId }: Props) => { export const HandleAi = ({ aiId }: Props) => {
const [models, setModels] = useState<Model[]>([]);
const utils = api.useUtils(); const utils = api.useUtils();
const [isLoadingModels, setIsLoadingModels] = useState(false);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
const { data, refetch } = api.ai.one.useQuery( const { data, refetch } = api.ai.one.useQuery(
@@ -73,6 +64,7 @@ export const HandleAi = ({ aiId }: Props) => {
const { mutateAsync, isLoading } = aiId const { mutateAsync, isLoading } = aiId
? api.ai.update.useMutation() ? api.ai.update.useMutation()
: api.ai.create.useMutation(); : api.ai.create.useMutation();
const form = useForm<Schema>({ const form = useForm<Schema>({
resolver: zodResolver(Schema), resolver: zodResolver(Schema),
defaultValues: { defaultValues: {
@@ -94,50 +86,33 @@ export const HandleAi = ({ aiId }: Props) => {
}); });
}, [aiId, form, data]); }, [aiId, form, data]);
const fetchModels = async (apiUrl: string, apiKey: string) => { const apiUrl = form.watch("apiUrl");
setIsLoadingModels(true); const apiKey = form.watch("apiKey");
setError(null);
try {
const response = await fetch(`${apiUrl}/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
if (!response.ok) {
throw new Error("Failed to fetch models");
}
const res = await response.json();
setModels(res.data);
// Set default model to gpt-4 if present const { data: models, isLoading: isLoadingServerModels } =
const defaultModel = res.data.find( api.ai.getModels.useQuery(
(model: Model) => model.id === "gpt-4", {
); apiUrl: apiUrl ?? "",
if (defaultModel) { apiKey: apiKey ?? "",
form.setValue("model", defaultModel.id); },
return defaultModel.id; {
} enabled: !!apiUrl && !!apiKey,
} catch (error) { onError: (error) => {
setError("Failed to fetch models. Please check your API URL and Key."); setError(`Failed to fetch models: ${error.message}`);
setModels([]); },
} finally { },
setIsLoadingModels(false); );
}
};
useEffect(() => { useEffect(() => {
const apiUrl = form.watch("apiUrl"); const apiUrl = form.watch("apiUrl");
const apiKey = form.watch("apiKey"); const apiKey = form.watch("apiKey");
if (apiUrl && apiKey) { if (apiUrl && apiKey) {
form.setValue("model", ""); form.setValue("model", "");
fetchModels(apiUrl, apiKey);
} }
}, [form.watch("apiUrl"), form.watch("apiKey")]); }, [form.watch("apiUrl"), form.watch("apiKey")]);
const onSubmit = async (data: Schema) => { const onSubmit = async (data: Schema) => {
try { try {
console.log("Form data:", data);
console.log("Current model value:", form.getValues("model"));
await mutateAsync({ await mutateAsync({
...data, ...data,
aiId: aiId || "", aiId: aiId || "",
@@ -148,8 +123,9 @@ export const HandleAi = ({ aiId }: Props) => {
refetch(); refetch();
setOpen(false); setOpen(false);
} catch (error) { } catch (error) {
console.error("Submit error:", error); toast.error("Failed to save AI settings", {
toast.error("Failed to save AI settings"); description: error instanceof Error ? error.message : "Unknown error",
});
} }
}; };
@@ -232,13 +208,13 @@ export const HandleAi = ({ aiId }: Props) => {
)} )}
/> />
{isLoadingModels && ( {isLoadingServerModels && (
<span className="text-sm text-muted-foreground"> <span className="text-sm text-muted-foreground">
Loading models... Loading models...
</span> </span>
)} )}
{!isLoadingModels && models.length > 0 && ( {!isLoadingServerModels && models && models.length > 0 && (
<FormField <FormField
control={form.control} control={form.control}
name="model" name="model"

View File

@@ -25,6 +25,10 @@ import {
addNewService, addNewService,
checkServiceAccess, checkServiceAccess,
} from "@dokploy/server/services/user"; } from "@dokploy/server/services/user";
import {
getProviderHeaders,
type Model,
} from "@dokploy/server/utils/ai/select-ai-provider";
import { TRPCError } from "@trpc/server"; import { TRPCError } from "@trpc/server";
import { z } from "zod"; import { z } from "zod";
@@ -41,6 +45,58 @@ export const aiRouter = createTRPCRouter({
} }
return aiSetting; return aiSetting;
}), }),
getModels: protectedProcedure
.input(z.object({ apiUrl: z.string().min(1), apiKey: z.string().min(1) }))
.query(async ({ input }) => {
try {
const headers = getProviderHeaders(input.apiUrl, input.apiKey);
const response = await fetch(`${input.apiUrl}/models`, { headers });
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Failed to fetch models: ${errorText}`);
}
const res = await response.json();
if (Array.isArray(res)) {
return res.map((model) => ({
id: model.id || model.name,
object: "model",
created: Date.now(),
owned_by: "provider",
}));
}
if (res.models) {
return res.models.map((model: any) => ({
id: model.id || model.name,
object: "model",
created: Date.now(),
owned_by: "provider",
})) as Model[];
}
if (res.data) {
return res.data as Model[];
}
const possibleModels =
(Object.values(res).find(Array.isArray) as any[]) || [];
return possibleModels.map((model) => ({
id: model.id || model.name,
object: "model",
created: Date.now(),
owned_by: "provider",
})) as Model[];
} catch (error) {
throw new TRPCError({
code: "BAD_REQUEST",
message: error instanceof Error ? error?.message : `Error: ${error}`,
});
}
}),
create: adminProcedure.input(apiCreateAi).mutation(async ({ ctx, input }) => { create: adminProcedure.input(apiCreateAi).mutation(async ({ ctx, input }) => {
return await saveAiSettings(ctx.session.activeOrganizationId, input); return await saveAiSettings(ctx.session.activeOrganizationId, input);
}), }),

View File

@@ -17,7 +17,7 @@ function getProviderName(apiUrl: string) {
if (apiUrl.includes("localhost:11434") || apiUrl.includes("ollama")) if (apiUrl.includes("localhost:11434") || apiUrl.includes("ollama"))
return "ollama"; return "ollama";
if (apiUrl.includes("api.deepinfra.com")) return "deepinfra"; if (apiUrl.includes("api.deepinfra.com")) return "deepinfra";
throw new Error(`Unsupported AI provider for URL: ${apiUrl}`); return "custom";
} }
export function selectAIProvider(config: { apiUrl: string; apiKey: string }) { export function selectAIProvider(config: { apiUrl: string; apiKey: string }) {
@@ -67,7 +67,46 @@ export function selectAIProvider(config: { apiUrl: string; apiKey: string }) {
baseURL: config.apiUrl, baseURL: config.apiUrl,
apiKey: config.apiKey, apiKey: config.apiKey,
}); });
case "custom":
return createOpenAICompatible({
name: "custom",
baseURL: config.apiUrl,
headers: {
Authorization: `Bearer ${config.apiKey}`,
},
});
default: default:
throw new Error(`Unsupported AI provider: ${providerName}`); throw new Error(`Unsupported AI provider: ${providerName}`);
} }
} }
export const getProviderHeaders = (
apiUrl: string,
apiKey: string,
): Record<string, string> => {
// Anthropic
if (apiUrl.includes("anthropic")) {
return {
"x-api-key": apiKey,
"anthropic-version": "2023-06-01",
};
}
// Mistral
if (apiUrl.includes("mistral")) {
return {
Authorization: apiKey,
};
}
// Default (OpenAI style)
return {
Authorization: `Bearer ${apiKey}`,
};
};
export interface Model {
id: string;
object: string;
created: number;
owned_by: string;
}