feat: Claude requires roles to alternate between user and assistant, so insert a placeholder assistant message between two consecutive user messages

This commit is contained in:
butterfly 2024-04-07 18:02:31 +08:00
parent 768decde93
commit 86b5c55855
2 changed files with 20 additions and 153 deletions

View File

@ -69,31 +69,21 @@ const ClaudeMapper = {
system: "user",
} as const;
// Separate string entries, one per identifier. The previous form held a single
// comma-joined string ("claude-2, claude-instant-1"), so a lookup for either
// name on its own could never match.
// NOTE(review): presumably Claude model identifiers; usage is not visible in this chunk — confirm.
const keys = ["claude-2", "claude-instant-1"];
export class ClaudeApi implements LLMApi {
/**
 * Logs the raw response object and returns the text extracted from it.
 *
 * @param res - Response payload; typed `any`, so its shape is not checked here.
 * @returns The value of `res.completion`.
 */
extractMessage(res: any) {
console.log("[Response] claude response: ", res);
return res.completion;
// NOTE(review): unreachable — the `return` above always exits first. This line
// reads the first content block's text instead; confirm which field is intended.
return res?.content?.[0]?.text;
}
async chatComplete(options: ChatOptions): Promise<void> {
const ClaudeMapper: Record<RequestMessage["role"], string> = {
assistant: "Assistant",
user: "Human",
system: "Human",
};
async chat(options: ChatOptions): Promise<void> {
const visionModel = isVisionModel(options.config.model);
const accessStore = useAccessStore.getState();
const shouldStream = !!options.config.stream;
const prompt = options.messages
.map((v) => ({
role: ClaudeMapper[v.role] ?? "Human",
content: v.content,
}))
.map((v) => `\n\n${v.role}: ${v.content}`)
.join("");
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
@ -102,142 +92,28 @@ export class ClaudeApi implements LLMApi {
},
};
const requestBody: ChatRequest = {
prompt,
stream: shouldStream,
const messages = [...options.messages];
model: modelConfig.model,
max_tokens_to_sample: modelConfig.max_tokens,
temperature: modelConfig.temperature,
top_p: modelConfig.top_p,
// top_k: modelConfig.top_k,
top_k: 5,
};
const keys = ["system", "user"];
const path = this.path(Anthropic.ChatPath1);
// Claude requires roles to alternate between "user" and "assistant", so insert a placeholder assistant message between two consecutive "system"/"user" messages
for (let i = 0; i < messages.length - 1; i++) {
const message = messages[i];
const nextMessage = messages[i + 1];
const controller = new AbortController();
options.onController?.(controller);
const payload = {
method: "POST",
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
"Content-Type": "application/json",
// Accept: "application/json",
"x-api-key": accessStore.anthropicApiKey,
"anthropic-version": accessStore.anthropicApiVersion,
Authorization: getAuthKey(accessStore.anthropicApiKey),
if (keys.includes(message.role) && keys.includes(nextMessage.role)) {
messages[i] = [
message,
{
role: "assistant",
content: ";",
},
// mode: "no-cors" as RequestMode,
};
if (shouldStream) {
try {
const context = {
text: "",
finished: false,
};
const finish = () => {
if (!context.finished) {
options.onFinish(context.text);
context.finished = true;
}
};
controller.signal.onabort = finish;
fetchEventSource(path, {
...payload,
async onopen(res) {
const contentType = res.headers.get("content-type");
console.log("response content type: ", contentType);
if (contentType?.startsWith("text/plain")) {
context.text = await res.clone().text();
return finish();
}
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [context.text];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
context.text = responseTexts.join("\n\n");
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || context.finished) {
return finish();
}
const chunk = msg.data;
try {
const chunkJson = JSON.parse(chunk) as ChatStreamResponse;
const delta = chunkJson.completion;
if (delta) {
context.text += delta;
options.onUpdate?.(context.text, delta);
}
} catch (e) {
console.error("[Request] parse error", chunk, msg);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
},
openWhenHidden: true,
});
} catch (e) {
console.error("failed to chat", e);
options.onError?.(e as Error);
}
} else {
try {
controller.signal.onabort = () => options.onFinish("");
const res = await fetch(path, payload);
const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
} catch (e) {
console.error("failed to chat", e);
options.onError?.(e as Error);
] as any;
}
}
}
async chat(options: ChatOptions): Promise<void> {
const visionModel = isVisionModel(options.config.model);
const accessStore = useAccessStore.getState();
const shouldStream = !!options.config.stream;
const prompt = options.messages
const prompt = messages
.flat()
.filter((v) => {
if (!v.content) return false;
if (typeof v.content === "string" && !v.content.trim()) return false;
@ -285,14 +161,6 @@ export class ClaudeApi implements LLMApi {
};
});
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
},
};
const requestBody: AnthropicChatRequest = {
messages: prompt,
stream: shouldStream,

View File

@ -496,7 +496,6 @@ export const useChatStore = createPersistStore(
tokenCount += estimateTokenLength(getMessageTextContent(msg));
reversedRecentMessages.push(msg);
}
// concat all messages
const recentMessages = [
...systemPrompts,