diff --git a/packages/bolt/app/components/chat/BaseChat.tsx b/packages/bolt/app/components/chat/BaseChat.tsx index 9bad3478..ada6235f 100644 --- a/packages/bolt/app/components/chat/BaseChat.tsx +++ b/packages/bolt/app/components/chat/BaseChat.tsx @@ -1,6 +1,5 @@ import type { Message } from 'ai'; -import type { LegacyRef } from 'react'; -import React from 'react'; +import React, { type LegacyRef, type RefCallback } from 'react'; import { ClientOnly } from 'remix-utils/client-only'; import { classNames } from '../../utils/classNames'; import { IconButton } from '../ui/IconButton'; @@ -10,6 +9,8 @@ import { SendButton } from './SendButton.client'; interface BaseChatProps { textareaRef?: LegacyRef<HTMLTextAreaElement> | undefined; + messageRef?: RefCallback<HTMLDivElement> | undefined; + scrollRef?: RefCallback<HTMLDivElement> | undefined; chatStarted?: boolean; isStreaming?: boolean; messages?: Message[]; @@ -30,6 +31,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>( ( { textareaRef, + messageRef, + scrollRef, chatStarted = false, isStreaming = false, enhancingPrompt = false, @@ -47,7 +50,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>( return ( <div ref={ref} className="relative flex h-full w-full overflow-hidden "> - <div className="flex overflow-scroll w-full h-full"> + <div ref={scrollRef} className="flex overflow-scroll w-full h-full"> <div id="chat" className="flex flex-col w-full h-full px-6"> {!chatStarted && ( <div id="intro" className="mt-[20vh] mb-14 max-w-3xl mx-auto"> @@ -71,6 +74,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>( {() => { return chatStarted ? ( <Messages + ref={messageRef} className="flex flex-col w-full flex-1 max-w-3xl px-4 pb-10 mx-auto z-1" messages={messages} isStreaming={isStreaming} diff --git a/packages/bolt/app/components/chat/Chat.client.tsx b/packages/bolt/app/components/chat/Chat.client.tsx index 854a1fd0..7aada172 100644 --- a/packages/bolt/app/components/chat/Chat.client.tsx +++ b/packages/bolt/app/components/chat/Chat.client.tsx @@ -1,7 +1,7 @@ import { useChat } from 'ai/react'; import { useAnimate } from 'framer-motion'; import { useEffect, useRef, useState } from 'react'; -import { useMessageParser, usePromptEnhancer } from '../../lib/hooks'; +import { useMessageParser, usePromptEnhancer, useSnapScroll } from '../../lib/hooks'; import { chatStore } from '../../lib/stores/chat'; import { workbenchStore } from '../../lib/stores/workbench'; import { cubicEasingFn } from '../../utils/easings'; @@ -87,6 +87,8 @@ export function Chat() { textareaRef.current?.blur(); }; + const [messageRef, scrollRef] = useSnapScroll(); + return ( <BaseChat ref={animationScope} @@ -97,6 +99,8 @@ export function Chat() { enhancingPrompt={enhancingPrompt} promptEnhanced={promptEnhanced} sendMessage={sendMessage} + messageRef={messageRef} + scrollRef={scrollRef} handleInputChange={handleInputChange} handleStop={abort} messages={messages.map((message, i) => { diff --git a/packages/bolt/app/components/chat/Messages.client.tsx b/packages/bolt/app/components/chat/Messages.client.tsx index a0e5921a..3783ea17 100644 --- a/packages/bolt/app/components/chat/Messages.client.tsx +++ b/packages/bolt/app/components/chat/Messages.client.tsx @@ -2,6 +2,7 @@ import type { Message } from 'ai'; import { classNames } from '../../utils/classNames'; import { AssistantMessage } from './AssistantMessage'; import { UserMessage } from './UserMessage'; +import React from 'react'; interface MessagesProps { id?: string; @@ -10,11 +11,11 @@ interface MessagesProps { messages?: Message[]; } -export function Messages(props: MessagesProps) { +export const Messages = React.forwardRef<HTMLDivElement, MessagesProps>((props: MessagesProps, ref) => { const { id, isStreaming = false, messages = [] } = props; return ( - <div id={id} className={props.className}> + <div id={id} ref={ref} className={props.className}> {messages.length > 0 ? messages.map((message, i) => { const { role, content } = message; @@ -61,4 +62,4 @@ export function Messages(props: MessagesProps) { {isStreaming && <div className="text-center w-full i-svg-spinners:3-dots-fade text-4xl mt-4"></div>} </div> ); -} +}); diff --git a/packages/bolt/app/lib/hooks/index.ts b/packages/bolt/app/lib/hooks/index.ts index d2548368..9837e12d 100644 --- a/packages/bolt/app/lib/hooks/index.ts +++ b/packages/bolt/app/lib/hooks/index.ts @@ -1,2 +1,3 @@ export * from './useMessageParser'; export * from './usePromptEnhancer'; +export * from './useSnapScroll'; diff --git a/packages/bolt/app/lib/hooks/useSnapScroll.ts b/packages/bolt/app/lib/hooks/useSnapScroll.ts new file mode 100644 index 00000000..65e229f9 --- /dev/null +++ b/packages/bolt/app/lib/hooks/useSnapScroll.ts @@ -0,0 +1,54 @@ +import { useRef, useCallback } from 'react'; + +export function useSnapScroll() { + const autoScrollRef = useRef(true); + const scrollNodeRef = useRef<HTMLDivElement>(); + const onScrollRef = useRef<() => void>(); + const observerRef = useRef<ResizeObserver>(); + + const messageRef = useCallback((node: HTMLDivElement | null) => { + if (node) { + const observer = new ResizeObserver(() => { + if (autoScrollRef.current) { + if (scrollNodeRef.current) { + const { scrollHeight, clientHeight } = scrollNodeRef.current; + const scrollTarget = scrollHeight - clientHeight; + + scrollNodeRef.current.scrollTo({ + top: scrollTarget, + }); + } + } + }); + + observer.observe(node); + } else { + observerRef.current?.disconnect(); + observerRef.current = undefined; + } + }, []); + + const scrollRef = useCallback((node: HTMLDivElement | null) => { + if (node) { + onScrollRef.current = () => { + const { scrollTop, scrollHeight, clientHeight } = node; + const scrollTarget = scrollHeight - clientHeight; + + autoScrollRef.current = Math.abs(scrollTop - scrollTarget) <= 10; + }; + + node.addEventListener('scroll', onScrollRef.current); + + scrollNodeRef.current = node; + } else { + if (onScrollRef.current) { + scrollNodeRef.current?.removeEventListener('scroll', onScrollRef.current); + } + + scrollNodeRef.current = undefined; + onScrollRef.current = undefined; + } + }, []); + + return [messageRef, scrollRef]; +}