feat: 增加会话ID和消息记录功能

This commit is contained in:
zyh 2024-10-22 10:02:57 +00:00
parent f8c0f5eecf
commit 4064317a1a

View File

@ -18,7 +18,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
return error as Response;
}
const { messages } = await request.json<{ messages: Messages }>();
const { messages, sessionId } = await request.json<{ messages: Messages; sessionId: string }>();
const stream = new SwitchableStream();
@ -27,7 +27,11 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
toolChoice: 'none',
onFinish: async ({ text: content, finishReason }) => {
if (finishReason !== 'length') {
await recordTokenConsumption(userId, calculateTokensConsumed(messages, content));
const tokensConsumed = calculateTokensConsumed(messages, content);
await Promise.all([
recordTokenConsumption(userId, tokensConsumed),
recordChatHistory(userId, content, 'assistant', sessionId, tokensConsumed)
]);
return stream.close();
}
@ -48,6 +52,9 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
},
};
// 记录用户的消息
await recordChatHistory(userId, messages[messages.length - 1].content, 'user', sessionId, 0);
const result = await streamText(messages, context.cloudflare.env, options);
stream.switchSource(result.toAIStream());
@ -68,6 +75,21 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
}
}
async function recordChatHistory(userId: string, message: string, role: 'user' | 'assistant', sessionId: string, tokensUsed: number) {
try {
await db('chat_histories').insert({
user_id: userId,
message,
role,
session_id: sessionId,
tokens_used: tokensUsed,
});
} catch (error) {
console.error('Error recording chat history:', error);
// 这里我们只记录错误,不中断流程
}
}
async function recordTokenConsumption(userId: string, tokensConsumed: number) {
try {
await db.transaction(async (trx) => {