Merge pull request #578 from thecodacus/context-optimization

feat(context optimization): Optimize LLM Context Management and File Handling
This commit is contained in:
Anirban Kar 2024-12-14 02:08:43 +05:30 committed by GitHub
commit 8c4397a19f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 119 additions and 10 deletions

View File

@ -92,6 +92,7 @@ export const ChatImpl = memo(
const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [uploadedFiles, setUploadedFiles] = useState<File[]>([]); // Move here
const [imageDataList, setImageDataList] = useState<string[]>([]); // Move here
const files = useStore(workbenchStore.files);
const { activeProviders } = useSettings();
const [model, setModel] = useState(() => {
@ -113,6 +114,7 @@ export const ChatImpl = memo(
api: '/api/chat',
body: {
apiKeys,
files,
},
onError: (error) => {
logger.error('Request failed\n\n', error);

View File

@ -3,6 +3,7 @@ import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, getModelList, MODEL_REGEX, PROVIDER_REGEX } from '~/utils/constants';
import ignore from 'ignore';
import type { IProviderSetting } from '~/types/model';
interface ToolResult<Name extends string, Args, Result> {
@ -23,6 +24,78 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
export interface File {
type: 'file';
content: string;
isBinary: boolean;
}
export interface Folder {
type: 'folder';
}
type Dirent = File | Folder;
export type FileMap = Record<string, Dirent | undefined>;
function simplifyBoltActions(input: string): string {
// Using regex to match boltAction tags that have type="file"
const regex = /(<boltAction[^>]*type="file"[^>]*>)([\s\S]*?)(<\/boltAction>)/g;
// Replace each matching occurrence
return input.replace(regex, (_0, openingTag, _2, closingTag) => {
return `${openingTag}\n ...\n ${closingTag}`;
});
}
// Common patterns to ignore, similar to .gitignore
const IGNORE_PATTERNS = [
'node_modules/**',
'.git/**',
'dist/**',
'build/**',
'.next/**',
'coverage/**',
'.cache/**',
'.vscode/**',
'.idea/**',
'**/*.log',
'**/.DS_Store',
'**/npm-debug.log*',
'**/yarn-debug.log*',
'**/yarn-error.log*',
'**/*lock.json',
'**/*lock.yml',
];
const ig = ignore().add(IGNORE_PATTERNS);
function createFilesContext(files: FileMap) {
let filePaths = Object.keys(files);
filePaths = filePaths.filter((x) => {
const relPath = x.replace('/home/project/', '');
return !ig.ignores(relPath);
});
const fileContexts = filePaths
.filter((x) => files[x] && files[x].type == 'file')
.map((path) => {
const dirent = files[path];
if (!dirent || dirent.type == 'folder') {
return '';
}
const codeWithLinesNumbers = dirent.content
.split('\n')
.map((v, i) => `${i + 1}|${v}`)
.join('\n');
return `<file path="${path}">\n${codeWithLinesNumbers}\n</file>`;
});
return `Below are the code files present in the webcontainer:\ncode format:\n<line number>|<line content>\n <codebase>${fileContexts.join('\n\n')}\n\n</codebase>`;
}
function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
const textContent = Array.isArray(message.content)
? message.content.find((item) => item.type === 'text')?.text || ''
@ -64,9 +137,10 @@ export async function streamText(props: {
env: Env;
options?: StreamingOptions;
apiKeys?: Record<string, string>;
files?: FileMap;
providerSettings?: Record<string, IProviderSetting>;
}) {
const { messages, env, options, apiKeys, providerSettings } = props;
const { messages, env, options, apiKeys, files, providerSettings } = props;
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER.name;
const MODEL_LIST = await getModelList(apiKeys || {}, providerSettings);
@ -80,6 +154,11 @@ export async function streamText(props: {
currentProvider = provider;
return { ...message, content };
} else if (message.role == 'assistant') {
let content = message.content;
content = simplifyBoltActions(content);
return { ...message, content };
}
@ -90,9 +169,17 @@ export async function streamText(props: {
const dynamicMaxTokens = modelDetails && modelDetails.maxTokenAllowed ? modelDetails.maxTokenAllowed : MAX_TOKENS;
let systemPrompt = getSystemPrompt();
let codeContext = '';
if (files) {
codeContext = createFilesContext(files);
systemPrompt = `${systemPrompt}\n\n ${codeContext}`;
}
return _streamText({
model: getModel(currentProvider, currentModel, env, apiKeys, providerSettings) as any,
system: getSystemPrompt(),
system: systemPrompt,
maxTokens: dynamicMaxTokens,
messages: convertToCoreMessages(processedMessages as any),
...options,

View File

@ -23,14 +23,14 @@ const messageParser = new StreamingMessageParser({
logger.trace('onActionOpen', data.action);
// we only add shell actions when when the close tag got parsed because only then we have the content
if (data.action.type !== 'shell') {
if (data.action.type === 'file') {
workbenchStore.addAction(data);
}
},
onActionClose: (data) => {
logger.trace('onActionClose', data.action);
if (data.action.type === 'shell') {
if (data.action.type !== 'file') {
workbenchStore.addAction(data);
}

View File

@ -262,9 +262,9 @@ export class WorkbenchStore {
this.artifacts.setKey(messageId, { ...artifact, ...state });
}
addAction(data: ActionCallbackData) {
this._addAction(data);
// this._addAction(data);
// this.addToExecutionQueue(()=>this._addAction(data))
this.addToExecutionQueue(() => this._addAction(data));
}
async _addAction(data: ActionCallbackData) {
const { messageId } = data;
@ -294,6 +294,12 @@ export class WorkbenchStore {
unreachable('Artifact not found');
}
const action = artifact.runner.actions.get()[data.actionId];
if (action.executed) {
return;
}
if (data.action.type === 'file') {
const wc = await webcontainer;
const fullPath = nodePath.join(wc.workdir, data.action.filePath);

View File

@ -30,9 +30,9 @@ function parseCookies(cookieHeader: string) {
}
async function chatAction({ context, request }: ActionFunctionArgs) {
const { messages } = await request.json<{
const { messages, files } = await request.json<{
messages: Messages;
model: string;
files: any;
}>();
const cookieHeader = request.headers.get('Cookie');
@ -64,13 +64,27 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
messages.push({ role: 'assistant', content });
messages.push({ role: 'user', content: CONTINUE_PROMPT });
const result = await streamText({ messages, env: context.cloudflare.env, options, apiKeys, providerSettings });
const result = await streamText({
messages,
env: context.cloudflare.env,
options,
apiKeys,
files,
providerSettings,
});
return stream.switchSource(result.toAIStream());
},
};
const result = await streamText({ messages, env: context.cloudflare.env, options, apiKeys, providerSettings });
const result = await streamText({
messages,
env: context.cloudflare.env,
options,
apiKeys,
files,
providerSettings,
});
stream.switchSource(result.toAIStream());