Merge pull request #213 from thecodacus/code-streaming

feat(code-streaming): added code streaming to editor while AI is writing files
This commit is contained in:
Chris Mahoney 2024-11-12 12:41:58 -06:00 committed by GitHub
commit a081f8bec5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 76 additions and 26 deletions

View File

@ -24,7 +24,7 @@ const EXAMPLE_PROMPTS = [
{ text: 'How do I center a div?' }, { text: 'How do I center a div?' },
]; ];
const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))] const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))];
const ModelSelector = ({ model, setModel, provider, setProvider, modelList, providerList }) => { const ModelSelector = ({ model, setModel, provider, setProvider, modelList, providerList }) => {
return ( return (
@ -33,7 +33,7 @@ const ModelSelector = ({ model, setModel, provider, setProvider, modelList, prov
value={provider} value={provider}
onChange={(e) => { onChange={(e) => {
setProvider(e.target.value); setProvider(e.target.value);
const firstModel = [...modelList].find(m => m.provider == e.target.value); const firstModel = [...modelList].find((m) => m.provider == e.target.value);
setModel(firstModel ? firstModel.name : ''); setModel(firstModel ? firstModel.name : '');
}} }}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all" className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
@ -58,7 +58,9 @@ const ModelSelector = ({ model, setModel, provider, setProvider, modelList, prov
onChange={(e) => setModel(e.target.value)} onChange={(e) => setModel(e.target.value)}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all" className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
> >
{[...modelList].filter(e => e.provider == provider && e.name).map((modelOption) => ( {[...modelList]
.filter((e) => e.provider == provider && e.name)
.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}> <option key={modelOption.name} value={modelOption.name}>
{modelOption.label} {modelOption.label}
</option> </option>
@ -81,10 +83,10 @@ interface BaseChatProps {
enhancingPrompt?: boolean; enhancingPrompt?: boolean;
promptEnhanced?: boolean; promptEnhanced?: boolean;
input?: string; input?: string;
model: string; model?: string;
setModel: (model: string) => void; setModel?: (model: string) => void;
provider: string; provider?: string;
setProvider: (provider: string) => void; setProvider?: (provider: string) => void;
handleStop?: () => void; handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void; sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void; handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
@ -144,7 +146,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
expires: 30, // 30 days expires: 30, // 30 days
secure: true, // Only send over HTTPS secure: true, // Only send over HTTPS
sameSite: 'strict', // Protect against CSRF sameSite: 'strict', // Protect against CSRF
path: '/' // Accessible across the site path: '/', // Accessible across the site
}); });
} catch (error) { } catch (error) {
console.error('Error saving API keys to cookies:', error); console.error('Error saving API keys to cookies:', error);
@ -281,7 +283,9 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
</div> </div>
{input.length > 3 ? ( {input.length > 3 ? (
<div className="text-xs text-bolt-elements-textTertiary"> <div className="text-xs text-bolt-elements-textTertiary">
Use <kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Shift</kbd> + <kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Return</kbd> for a new line Use <kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Shift</kbd> +{' '}
<kbd className="kdb px-1.5 py-0.5 rounded bg-bolt-elements-background-depth-2">Return</kbd> for
a new line
</div> </div>
) : null} ) : null}
</div> </div>

View File

@ -36,6 +36,10 @@ const messageParser = new StreamingMessageParser({
workbenchStore.runAction(data); workbenchStore.runAction(data);
}, },
onActionStream: (data) => {
logger.trace('onActionStream', data.action);
workbenchStore.runAction(data, true);
},
}, },
}); });

View File

@ -77,7 +77,7 @@ export class ActionRunner {
}); });
} }
async runAction(data: ActionCallbackData) { async runAction(data: ActionCallbackData, isStreaming: boolean = false) {
const { actionId } = data; const { actionId } = data;
const action = this.actions.get()[actionId]; const action = this.actions.get()[actionId];
@ -88,19 +88,22 @@ export class ActionRunner {
if (action.executed) { if (action.executed) {
return; return;
} }
if (isStreaming && action.type !== 'file') {
return;
}
this.#updateAction(actionId, { ...action, ...data.action, executed: true }); this.#updateAction(actionId, { ...action, ...data.action, executed: !isStreaming });
this.#currentExecutionPromise = this.#currentExecutionPromise this.#currentExecutionPromise = this.#currentExecutionPromise
.then(() => { .then(() => {
return this.#executeAction(actionId); return this.#executeAction(actionId, isStreaming);
}) })
.catch((error) => { .catch((error) => {
console.error('Action failed:', error); console.error('Action failed:', error);
}); });
} }
async #executeAction(actionId: string) { async #executeAction(actionId: string, isStreaming: boolean = false) {
const action = this.actions.get()[actionId]; const action = this.actions.get()[actionId];
this.#updateAction(actionId, { status: 'running' }); this.#updateAction(actionId, { status: 'running' });
@ -121,7 +124,7 @@ export class ActionRunner {
} }
} }
this.#updateAction(actionId, { status: action.abortSignal.aborted ? 'aborted' : 'complete' }); this.#updateAction(actionId, { status: isStreaming ? 'running' : action.abortSignal.aborted ? 'aborted' : 'complete' });
} catch (error) { } catch (error) {
this.#updateAction(actionId, { status: 'failed', error: 'Action failed' }); this.#updateAction(actionId, { status: 'failed', error: 'Action failed' });
logger.error(`[${action.type}]:Action failed\n\n`, error); logger.error(`[${action.type}]:Action failed\n\n`, error);

View File

@ -28,6 +28,7 @@ export interface ParserCallbacks {
onArtifactOpen?: ArtifactCallback; onArtifactOpen?: ArtifactCallback;
onArtifactClose?: ArtifactCallback; onArtifactClose?: ArtifactCallback;
onActionOpen?: ActionCallback; onActionOpen?: ActionCallback;
onActionStream?: ActionCallback;
onActionClose?: ActionCallback; onActionClose?: ActionCallback;
} }
@ -118,6 +119,21 @@ export class StreamingMessageParser {
i = closeIndex + ARTIFACT_ACTION_TAG_CLOSE.length; i = closeIndex + ARTIFACT_ACTION_TAG_CLOSE.length;
} else { } else {
if ('type' in currentAction && currentAction.type === 'file') {
let content = input.slice(i);
this._options.callbacks?.onActionStream?.({
artifactId: currentArtifact.id,
messageId,
actionId: String(state.actionId - 1),
action: {
...currentAction as FileAction,
content,
filePath: currentAction.filePath,
},
});
}
break; break;
} }
} else { } else {

View File

@ -11,7 +11,8 @@ import { PreviewsStore } from './previews';
import { TerminalStore } from './terminal'; import { TerminalStore } from './terminal';
import JSZip from 'jszip'; import JSZip from 'jszip';
import { saveAs } from 'file-saver'; import { saveAs } from 'file-saver';
import { Octokit } from "@octokit/rest"; import { Octokit, type RestEndpointMethodTypes } from "@octokit/rest";
import * as nodePath from 'node:path';
import type { WebContainerProcess } from '@webcontainer/api'; import type { WebContainerProcess } from '@webcontainer/api';
export interface ArtifactState { export interface ArtifactState {
@ -267,7 +268,7 @@ export class WorkbenchStore {
artifact.runner.addAction(data); artifact.runner.addAction(data);
} }
async runAction(data: ActionCallbackData) { async runAction(data: ActionCallbackData, isStreaming: boolean = false) {
const { messageId } = data; const { messageId } = data;
const artifact = this.#getArtifact(messageId); const artifact = this.#getArtifact(messageId);
@ -275,9 +276,30 @@ export class WorkbenchStore {
if (!artifact) { if (!artifact) {
unreachable('Artifact not found'); unreachable('Artifact not found');
} }
if (data.action.type === 'file') {
let wc = await webcontainer
const fullPath = nodePath.join(wc.workdir, data.action.filePath);
if (this.selectedFile.value !== fullPath) {
this.setSelectedFile(fullPath);
}
if (this.currentView.value !== 'code') {
this.currentView.set('code');
}
const doc = this.#editorStore.documents.get()[fullPath];
if (!doc) {
await artifact.runner.runAction(data, isStreaming);
}
this.#editorStore.updateFile(fullPath, data.action.content);
if (!isStreaming) {
this.resetCurrentDocument();
await artifact.runner.runAction(data);
}
} else {
artifact.runner.runAction(data); artifact.runner.runAction(data);
} }
}
#getArtifact(id: string) { #getArtifact(id: string) {
const artifacts = this.artifacts.get(); const artifacts = this.artifacts.get();
@ -360,9 +382,10 @@ export class WorkbenchStore {
const octokit = new Octokit({ auth: githubToken }); const octokit = new Octokit({ auth: githubToken });
// Check if the repository already exists before creating it // Check if the repository already exists before creating it
let repo let repo: RestEndpointMethodTypes["repos"]["get"]["response"]['data']
try { try {
repo = await octokit.repos.get({ owner: owner, repo: repoName }); let resp = await octokit.repos.get({ owner: owner, repo: repoName });
repo = resp.data
} catch (error) { } catch (error) {
if (error instanceof Error && 'status' in error && error.status === 404) { if (error instanceof Error && 'status' in error && error.status === 404) {
// Repository doesn't exist, so create a new one // Repository doesn't exist, so create a new one

View File

@ -117,5 +117,5 @@
"resolutions": { "resolutions": {
"@typescript-eslint/utils": "^8.0.0-alpha.30" "@typescript-eslint/utils": "^8.0.0-alpha.30"
}, },
"packageManager": "pnpm@9.12.2+sha512.22721b3a11f81661ae1ec68ce1a7b879425a1ca5b991c975b074ac220b187ce56c708fe5db69f4c962c989452eee76c82877f4ee80f474cebd61ee13461b6228" "packageManager": "pnpm@9.4.0"
} }