feat: support streaming for Gemini Pro (#3688)

* feat: support streaming for Gemini Pro

* feat: display texts smoothly

* chore: remove comments
This commit is contained in:
Fred Liang 2023-12-29 03:42:45 +08:00 committed by GitHub
parent 3ba598633c
commit 5cf58d9446
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 47 additions and 79 deletions

View File

@ -20,6 +20,7 @@ export class GeminiProApi implements LLMApi {
);
}
async chat(options: ChatOptions): Promise<void> {
const apiClient = this;
const messages = options.messages.map((v) => ({
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: [{ text: v.content }],
@ -61,8 +62,7 @@ export class GeminiProApi implements LLMApi {
console.log("[Request] google payload: ", requestPayload);
// todo: support stream later
const shouldStream = false;
const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
@ -82,13 +82,21 @@ export class GeminiProApi implements LLMApi {
if (shouldStream) {
let responseText = "";
let remainText = "";
let streamChatPath = chatPath.replace(
"generateContent",
"streamGenerateContent",
);
let finished = false;
const finish = () => {
finished = true;
options.onFinish(responseText + remainText);
};
// animate response to make it looks smooth
function animateResponseText() {
if (finished || controller.signal.aborted) {
responseText += remainText;
console.log("[Response Animation] finished");
finish();
return;
}
@ -105,87 +113,40 @@ export class GeminiProApi implements LLMApi {
// start animaion
animateResponseText();
fetch(streamChatPath, chatPayload)
.then((response) => {
const reader = response?.body?.getReader();
const decoder = new TextDecoder();
let partialData = "";
const finish = () => {
if (!finished) {
return reader?.read().then(function processText({
done,
value,
}): Promise<any> {
if (done) {
console.log("Stream complete");
// options.onFinish(responseText + remainText);
finished = true;
options.onFinish(responseText + remainText);
}
};
controller.signal.onabort = finish;
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
console.log(
"[OpenAI] request response content type: ",
contentType,
);
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
return Promise.resolve();
}
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
partialData += decoder.decode(value, { stream: true });
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
let data = JSON.parse(ensureProperEnding(partialData));
console.log(data);
let fetchText = apiClient.extractMessage(data[data.length - 1]);
console.log("[Response Animation] fetchText: ", fetchText);
remainText += fetchText;
} catch (error) {
// skip error message when parsing json
}
if (extraInfo) {
responseTexts.push(extraInfo);
}
responseText = responseTexts.join("\n\n");
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text) as {
choices: Array<{
delta: {
content: string;
};
}>;
};
const delta = json.choices[0]?.delta?.content;
if (delta) {
remainText += delta;
}
} catch (e) {
console.error("[Request] parse error", text);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
return reader.read().then(processText);
});
})
.catch((error) => {
console.error("Error:", error);
});
} else {
const res = await fetch(chatPath, chatPayload);
@ -220,3 +181,10 @@ export class GeminiProApi implements LLMApi {
return "/api/google/" + path;
}
}
function ensureProperEnding(str: string) {
if (str.startsWith("[") && !str.endsWith("]")) {
return str + "]";
}
return str;
}