feat(ai): improve model fetching and error handling

- Add server-side model fetching endpoint with flexible provider support
- Refactor client-side AI settings component to use new API query
- Implement dynamic header generation for different AI providers
- Enhance error handling and toast notifications
- Remove local model fetching logic in favor of server-side implementation
This commit is contained in:
Mauricio Siu
2025-03-07 00:55:11 -06:00
parent efd176451f
commit b8e5cae88f
4 changed files with 112 additions and 46 deletions

View File

@@ -45,21 +45,12 @@ const Schema = z.object({
type Schema = z.infer<typeof Schema>;
interface Model {
id: string;
object: string;
created: number;
owned_by: string;
}
interface Props {
aiId?: string;
}
export const HandleAi = ({ aiId }: Props) => {
const [models, setModels] = useState<Model[]>([]);
const utils = api.useUtils();
const [isLoadingModels, setIsLoadingModels] = useState(false);
const [error, setError] = useState<string | null>(null);
const [open, setOpen] = useState(false);
const { data, refetch } = api.ai.one.useQuery(
@@ -73,6 +64,7 @@ export const HandleAi = ({ aiId }: Props) => {
const { mutateAsync, isLoading } = aiId
? api.ai.update.useMutation()
: api.ai.create.useMutation();
const form = useForm<Schema>({
resolver: zodResolver(Schema),
defaultValues: {
@@ -94,50 +86,33 @@ export const HandleAi = ({ aiId }: Props) => {
});
}, [aiId, form, data]);
const fetchModels = async (apiUrl: string, apiKey: string) => {
setIsLoadingModels(true);
setError(null);
try {
const response = await fetch(`${apiUrl}/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
if (!response.ok) {
throw new Error("Failed to fetch models");
}
const res = await response.json();
setModels(res.data);
const apiUrl = form.watch("apiUrl");
const apiKey = form.watch("apiKey");
// Set default model to gpt-4 if present
const defaultModel = res.data.find(
(model: Model) => model.id === "gpt-4",
);
if (defaultModel) {
form.setValue("model", defaultModel.id);
return defaultModel.id;
}
} catch (error) {
setError("Failed to fetch models. Please check your API URL and Key.");
setModels([]);
} finally {
setIsLoadingModels(false);
}
};
const { data: models, isLoading: isLoadingServerModels } =
api.ai.getModels.useQuery(
{
apiUrl: apiUrl ?? "",
apiKey: apiKey ?? "",
},
{
enabled: !!apiUrl && !!apiKey,
onError: (error) => {
setError(`Failed to fetch models: ${error.message}`);
},
},
);
useEffect(() => {
const apiUrl = form.watch("apiUrl");
const apiKey = form.watch("apiKey");
if (apiUrl && apiKey) {
form.setValue("model", "");
fetchModels(apiUrl, apiKey);
}
}, [form.watch("apiUrl"), form.watch("apiKey")]);
const onSubmit = async (data: Schema) => {
try {
console.log("Form data:", data);
console.log("Current model value:", form.getValues("model"));
await mutateAsync({
...data,
aiId: aiId || "",
@@ -148,8 +123,9 @@ export const HandleAi = ({ aiId }: Props) => {
refetch();
setOpen(false);
} catch (error) {
console.error("Submit error:", error);
toast.error("Failed to save AI settings");
toast.error("Failed to save AI settings", {
description: error instanceof Error ? error.message : "Unknown error",
});
}
};
@@ -232,13 +208,13 @@ export const HandleAi = ({ aiId }: Props) => {
)}
/>
{isLoadingModels && (
{isLoadingServerModels && (
<span className="text-sm text-muted-foreground">
Loading models...
</span>
)}
{!isLoadingModels && models.length > 0 && (
{!isLoadingServerModels && models && models.length > 0 && (
<FormField
control={form.control}
name="model"

View File

@@ -25,6 +25,10 @@ import {
addNewService,
checkServiceAccess,
} from "@dokploy/server/services/user";
import {
getProviderHeaders,
type Model,
} from "@dokploy/server/utils/ai/select-ai-provider";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
@@ -41,6 +45,59 @@ export const aiRouter = createTRPCRouter({
}
return aiSetting;
}),
// Fetches the model list from an arbitrary AI provider and normalizes it to
// the shared Model shape, tolerating several payload formats: a bare array,
// `{ models: [...] }`, `{ data: [...] }`, or any array-valued field.
getModels: protectedProcedure
.input(z.object({ apiUrl: z.string().min(1), apiKey: z.string().min(1) }))
.query(async ({ input }) => {
try {
// Provider-specific auth headers (Anthropic / Mistral / OpenAI-style default).
const headers = getProviderHeaders(input.apiUrl, input.apiKey);
const response = await fetch(`${input.apiUrl}/models`, { headers });
if (!response.ok) {
// Surface the provider's raw error body so the client sees why it failed.
const errorText = await response.text();
throw new Error(`Failed to fetch models: ${errorText}`);
}
const res = await response.json();
// Case 1: provider returns a bare array of model objects.
if (Array.isArray(res)) {
return res.map((model) => ({
id: model.id || model.name,
object: "model",
created: Date.now(),
owned_by: "provider",
}));
}
// Case 2: list wrapped in a "models" field (some providers use this key).
if (res.models) {
return res.models.map((model: any) => ({
id: model.id || model.name,
object: "model",
created: Date.now(),
owned_by: "provider",
})) as Model[];
}
// Case 3: OpenAI-style `{ data: [...] }` — passed through unnormalized.
// NOTE(review): entries here are assumed to already match Model — verify.
if (res.data) {
return res.data as Model[];
}
// Fallback: take the first array-valued field of the response, if any,
// and normalize it; yields [] when nothing array-shaped is found.
const possibleModels =
(Object.values(res).find(Array.isArray) as any[]) || [];
return possibleModels.map((model) => ({
id: model.id || model.name,
object: "model",
created: Date.now(),
owned_by: "provider",
})) as Model[];
} catch (error) {
console.log("Error fetching models:", error);
// Re-throw as a tRPC BAD_REQUEST so the client query's onError fires
// with a readable message.
throw new TRPCError({
code: "BAD_REQUEST",
message: error instanceof Error ? error?.message : `Error: ${error}`,
});
}
}),
// Persists a new AI settings entry scoped to the caller's active
// organization. Admin-only; input is validated against apiCreateAi.
create: adminProcedure
	.input(apiCreateAi)
	.mutation(async ({ ctx, input }) => {
		const saved = await saveAiSettings(
			ctx.session.activeOrganizationId,
			input,
		);
		return saved;
	}),

View File

@@ -201,6 +201,8 @@ export const suggestVariants = async ({
return result;
}
console.log(object);
throw new TRPCError({
code: "NOT_FOUND",
message: "No suggestions found",

View File

@@ -74,8 +74,39 @@ export function selectAIProvider(config: { apiUrl: string; apiKey: string }) {
headers: {
Authorization: `Bearer ${config.apiKey}`,
},
)};
});
default:
throw new Error(`Unsupported AI provider: ${providerName}`);
}
}
/**
 * Build the HTTP headers required to call an AI provider's API
 * (e.g. its `/models` endpoint).
 *
 * Provider detection is a substring heuristic on the API URL
 * ("anthropic", "mistral"); anything unrecognized falls back to the
 * OpenAI-style `Authorization: Bearer <key>` header.
 *
 * @param apiUrl - Base URL of the provider's API.
 * @param apiKey - API key/secret for that provider.
 * @returns Header map suitable for passing to `fetch`.
 */
export const getProviderHeaders = (
	apiUrl: string,
	apiKey: string,
): Record<string, string> => {
	// Anthropic uses a custom key header plus a pinned API version.
	if (apiUrl.includes("anthropic")) {
		return {
			"x-api-key": apiKey,
			"anthropic-version": "2023-06-01",
		};
	}
	// Mistral: fixed — Mistral's API requires a standard bearer token
	// ("Authorization: Bearer <key>"); the previous bare-key header is
	// rejected with a 401.
	if (apiUrl.includes("mistral")) {
		return {
			Authorization: `Bearer ${apiKey}`,
		};
	}
	// Default (OpenAI-compatible providers).
	return {
		Authorization: `Bearer ${apiKey}`,
	};
};
/**
 * Normalized model descriptor returned by the AI `getModels` endpoint,
 * mirroring the shape of an entry in an OpenAI `/models` list response.
 */
export interface Model {
// Unique model identifier (e.g. "gpt-4").
id: string;
// Object type tag; "model" for entries synthesized server-side.
object: string;
// Creation timestamp; epoch milliseconds when synthesized server-side.
// NOTE(review): provider passthrough values may use other units — verify.
created: number;
// Owning organization/provider label ("provider" when synthesized).
owned_by: string;
}