From df25c678d138f4d36079c55780277cf27046b229 Mon Sep 17 00:00:00 2001 From: Kirjava Date: Wed, 24 Jul 2024 14:47:48 +0100 Subject: [PATCH] feat: chat autoscroll (#6) --- .../bolt/app/components/chat/BaseChat.tsx | 10 ++-- .../bolt/app/components/chat/Chat.client.tsx | 6 ++- .../app/components/chat/Messages.client.tsx | 7 +-- packages/bolt/app/lib/hooks/index.ts | 1 + packages/bolt/app/lib/hooks/useSnapScroll.ts | 54 +++++++++++++++++++ 5 files changed, 71 insertions(+), 7 deletions(-) create mode 100644 packages/bolt/app/lib/hooks/useSnapScroll.ts diff --git a/packages/bolt/app/components/chat/BaseChat.tsx b/packages/bolt/app/components/chat/BaseChat.tsx index 9bad347..ada6235 100644 --- a/packages/bolt/app/components/chat/BaseChat.tsx +++ b/packages/bolt/app/components/chat/BaseChat.tsx @@ -1,6 +1,5 @@ import type { Message } from 'ai'; -import type { LegacyRef } from 'react'; -import React from 'react'; +import React, { type LegacyRef, type RefCallback } from 'react'; import { ClientOnly } from 'remix-utils/client-only'; import { classNames } from '../../utils/classNames'; import { IconButton } from '../ui/IconButton'; @@ -10,6 +9,8 @@ import { SendButton } from './SendButton.client'; interface BaseChatProps { textareaRef?: LegacyRef | undefined; + messageRef?: RefCallback | undefined; + scrollRef?: RefCallback | undefined; chatStarted?: boolean; isStreaming?: boolean; messages?: Message[]; @@ -30,6 +31,8 @@ export const BaseChat = React.forwardRef( ( { textareaRef, + messageRef, + scrollRef, chatStarted = false, isStreaming = false, enhancingPrompt = false, @@ -47,7 +50,7 @@ export const BaseChat = React.forwardRef( return (
-
+
{!chatStarted && (
@@ -71,6 +74,7 @@ export const BaseChat = React.forwardRef( {() => { return chatStarted ? ( { diff --git a/packages/bolt/app/components/chat/Messages.client.tsx b/packages/bolt/app/components/chat/Messages.client.tsx index a0e5921..3783ea1 100644 --- a/packages/bolt/app/components/chat/Messages.client.tsx +++ b/packages/bolt/app/components/chat/Messages.client.tsx @@ -2,6 +2,7 @@ import type { Message } from 'ai'; import { classNames } from '../../utils/classNames'; import { AssistantMessage } from './AssistantMessage'; import { UserMessage } from './UserMessage'; +import React from 'react'; interface MessagesProps { id?: string; @@ -10,11 +11,11 @@ interface MessagesProps { messages?: Message[]; } -export function Messages(props: MessagesProps) { +export const Messages = React.forwardRef((props: MessagesProps, ref) => { const { id, isStreaming = false, messages = [] } = props; return ( -
+
{messages.length > 0 ? messages.map((message, i) => { const { role, content } = message; @@ -61,4 +62,4 @@ export function Messages(props: MessagesProps) { {isStreaming &&
}
); -} +}); diff --git a/packages/bolt/app/lib/hooks/index.ts b/packages/bolt/app/lib/hooks/index.ts index d254836..9837e12 100644 --- a/packages/bolt/app/lib/hooks/index.ts +++ b/packages/bolt/app/lib/hooks/index.ts @@ -1,2 +1,3 @@ export * from './useMessageParser'; export * from './usePromptEnhancer'; +export * from './useSnapScroll'; diff --git a/packages/bolt/app/lib/hooks/useSnapScroll.ts b/packages/bolt/app/lib/hooks/useSnapScroll.ts new file mode 100644 index 0000000..65e229f --- /dev/null +++ b/packages/bolt/app/lib/hooks/useSnapScroll.ts @@ -0,0 +1,54 @@ +import { useRef, useCallback } from 'react'; + +export function useSnapScroll() { + const autoScrollRef = useRef(true); + const scrollNodeRef = useRef(); + const onScrollRef = useRef<() => void>(); + const observerRef = useRef(); + + const messageRef = useCallback((node: HTMLDivElement | null) => { + if (node) { + const observer = new ResizeObserver(() => { + if (autoScrollRef.current) { + if (scrollNodeRef.current) { + const { scrollHeight, clientHeight } = scrollNodeRef.current; + const scrollTarget = scrollHeight - clientHeight; + + scrollNodeRef.current.scrollTo({ + top: scrollTarget, + }); + } + } + }); + + observer.observe(node); + } else { + observerRef.current?.disconnect(); + observerRef.current = undefined; + } + }, []); + + const scrollRef = useCallback((node: HTMLDivElement | null) => { + if (node) { + onScrollRef.current = () => { + const { scrollTop, scrollHeight, clientHeight } = node; + const scrollTarget = scrollHeight - clientHeight; + + autoScrollRef.current = Math.abs(scrollTop - scrollTarget) <= 10; + }; + + node.addEventListener('scroll', onScrollRef.current); + + scrollNodeRef.current = node; + } else { + if (onScrollRef.current) { + scrollNodeRef.current?.removeEventListener('scroll', onScrollRef.current); + } + + scrollNodeRef.current = undefined; + onScrollRef.current = undefined; + } + }, []); + + return [messageRef, scrollRef]; +}