feat: chat autoscroll (#6)

This commit is contained in:
Kirjava 2024-07-24 14:47:48 +01:00 committed by GitHub
parent f4987a4ecd
commit df25c678d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 71 additions and 7 deletions

View File

@ -1,6 +1,5 @@
import type { Message } from 'ai'; import type { Message } from 'ai';
import type { LegacyRef } from 'react'; import React, { type LegacyRef, type RefCallback } from 'react';
import React from 'react';
import { ClientOnly } from 'remix-utils/client-only'; import { ClientOnly } from 'remix-utils/client-only';
import { classNames } from '../../utils/classNames'; import { classNames } from '../../utils/classNames';
import { IconButton } from '../ui/IconButton'; import { IconButton } from '../ui/IconButton';
@ -10,6 +9,8 @@ import { SendButton } from './SendButton.client';
interface BaseChatProps { interface BaseChatProps {
textareaRef?: LegacyRef<HTMLTextAreaElement> | undefined; textareaRef?: LegacyRef<HTMLTextAreaElement> | undefined;
messageRef?: RefCallback<HTMLDivElement> | undefined;
scrollRef?: RefCallback<HTMLDivElement> | undefined;
chatStarted?: boolean; chatStarted?: boolean;
isStreaming?: boolean; isStreaming?: boolean;
messages?: Message[]; messages?: Message[];
@ -30,6 +31,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
( (
{ {
textareaRef, textareaRef,
messageRef,
scrollRef,
chatStarted = false, chatStarted = false,
isStreaming = false, isStreaming = false,
enhancingPrompt = false, enhancingPrompt = false,
@ -47,7 +50,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
return ( return (
<div ref={ref} className="relative flex h-full w-full overflow-hidden "> <div ref={ref} className="relative flex h-full w-full overflow-hidden ">
<div className="flex overflow-scroll w-full h-full"> <div ref={scrollRef} className="flex overflow-scroll w-full h-full">
<div id="chat" className="flex flex-col w-full h-full px-6"> <div id="chat" className="flex flex-col w-full h-full px-6">
{!chatStarted && ( {!chatStarted && (
<div id="intro" className="mt-[20vh] mb-14 max-w-3xl mx-auto"> <div id="intro" className="mt-[20vh] mb-14 max-w-3xl mx-auto">
@ -71,6 +74,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
{() => { {() => {
return chatStarted ? ( return chatStarted ? (
<Messages <Messages
ref={messageRef}
className="flex flex-col w-full flex-1 max-w-3xl px-4 pb-10 mx-auto z-1" className="flex flex-col w-full flex-1 max-w-3xl px-4 pb-10 mx-auto z-1"
messages={messages} messages={messages}
isStreaming={isStreaming} isStreaming={isStreaming}

View File

@ -1,7 +1,7 @@
import { useChat } from 'ai/react'; import { useChat } from 'ai/react';
import { useAnimate } from 'framer-motion'; import { useAnimate } from 'framer-motion';
import { useEffect, useRef, useState } from 'react'; import { useEffect, useRef, useState } from 'react';
import { useMessageParser, usePromptEnhancer } from '../../lib/hooks'; import { useMessageParser, usePromptEnhancer, useSnapScroll } from '../../lib/hooks';
import { chatStore } from '../../lib/stores/chat'; import { chatStore } from '../../lib/stores/chat';
import { workbenchStore } from '../../lib/stores/workbench'; import { workbenchStore } from '../../lib/stores/workbench';
import { cubicEasingFn } from '../../utils/easings'; import { cubicEasingFn } from '../../utils/easings';
@ -87,6 +87,8 @@ export function Chat() {
textareaRef.current?.blur(); textareaRef.current?.blur();
}; };
const [messageRef, scrollRef] = useSnapScroll();
return ( return (
<BaseChat <BaseChat
ref={animationScope} ref={animationScope}
@ -97,6 +99,8 @@ export function Chat() {
enhancingPrompt={enhancingPrompt} enhancingPrompt={enhancingPrompt}
promptEnhanced={promptEnhanced} promptEnhanced={promptEnhanced}
sendMessage={sendMessage} sendMessage={sendMessage}
messageRef={messageRef}
scrollRef={scrollRef}
handleInputChange={handleInputChange} handleInputChange={handleInputChange}
handleStop={abort} handleStop={abort}
messages={messages.map((message, i) => { messages={messages.map((message, i) => {

View File

@ -2,6 +2,7 @@ import type { Message } from 'ai';
import { classNames } from '../../utils/classNames'; import { classNames } from '../../utils/classNames';
import { AssistantMessage } from './AssistantMessage'; import { AssistantMessage } from './AssistantMessage';
import { UserMessage } from './UserMessage'; import { UserMessage } from './UserMessage';
import React from 'react';
interface MessagesProps { interface MessagesProps {
id?: string; id?: string;
@ -10,11 +11,11 @@ interface MessagesProps {
messages?: Message[]; messages?: Message[];
} }
export function Messages(props: MessagesProps) { export const Messages = React.forwardRef<HTMLDivElement, MessagesProps>((props: MessagesProps, ref) => {
const { id, isStreaming = false, messages = [] } = props; const { id, isStreaming = false, messages = [] } = props;
return ( return (
<div id={id} className={props.className}> <div id={id} ref={ref} className={props.className}>
{messages.length > 0 {messages.length > 0
? messages.map((message, i) => { ? messages.map((message, i) => {
const { role, content } = message; const { role, content } = message;
@ -61,4 +62,4 @@ export function Messages(props: MessagesProps) {
{isStreaming && <div className="text-center w-full i-svg-spinners:3-dots-fade text-4xl mt-4"></div>} {isStreaming && <div className="text-center w-full i-svg-spinners:3-dots-fade text-4xl mt-4"></div>}
</div> </div>
); );
} });

View File

@ -1,2 +1,3 @@
export * from './useMessageParser'; export * from './useMessageParser';
export * from './usePromptEnhancer'; export * from './usePromptEnhancer';
export * from './useSnapScroll';

View File

@ -0,0 +1,54 @@
import { useRef, useCallback } from 'react';
export function useSnapScroll() {
const autoScrollRef = useRef(true);
const scrollNodeRef = useRef<HTMLDivElement>();
const onScrollRef = useRef<() => void>();
const observerRef = useRef<ResizeObserver>();
const messageRef = useCallback((node: HTMLDivElement | null) => {
if (node) {
const observer = new ResizeObserver(() => {
if (autoScrollRef.current) {
if (scrollNodeRef.current) {
const { scrollHeight, clientHeight } = scrollNodeRef.current;
const scrollTarget = scrollHeight - clientHeight;
scrollNodeRef.current.scrollTo({
top: scrollTarget,
});
}
}
});
observer.observe(node);
} else {
observerRef.current?.disconnect();
observerRef.current = undefined;
}
}, []);
const scrollRef = useCallback((node: HTMLDivElement | null) => {
if (node) {
onScrollRef.current = () => {
const { scrollTop, scrollHeight, clientHeight } = node;
const scrollTarget = scrollHeight - clientHeight;
autoScrollRef.current = Math.abs(scrollTop - scrollTarget) <= 10;
};
node.addEventListener('scroll', onScrollRef.current);
scrollNodeRef.current = node;
} else {
if (onScrollRef.current) {
scrollNodeRef.current?.removeEventListener('scroll', onScrollRef.current);
}
scrollNodeRef.current = undefined;
onScrollRef.current = undefined;
}
}, []);
return [messageRef, scrollRef];
}