feat(session): encrypt data and fix renewal (#38)

This commit is contained in:
Roberto Vidal 2024-08-19 17:39:37 +02:00 committed by GitHub
parent b939a0af2d
commit 44226db359
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 170 additions and 50 deletions

View File

@ -4,44 +4,59 @@ import { CLIENT_ID, CLIENT_ORIGIN } from '~/lib/constants';
import { request as doRequest } from '~/lib/fetch';
import { logger } from '~/utils/logger';
import type { Identity } from '~/lib/analytics';
import { decrypt, encrypt } from '~/lib/crypto';
const DEV_SESSION_SECRET = import.meta.env.DEV ? 'LZQMrERo3Ewn/AbpSYJ9aw==' : undefined;
const DEV_PAYLOAD_SECRET = import.meta.env.DEV ? '2zAyrhjcdFeXk0YEDzilMXbdrGAiR+8ACIUgFNfjLaI=' : undefined;
const TOKEN_KEY = 't';
const EXPIRES_KEY = 'e';
const USER_ID_KEY = 'u';
const SEGMENT_KEY = 's';
interface SessionData {
refresh: string;
expiresAt: number;
userId: string | null;
segmentWriteKey: string | null;
[TOKEN_KEY]: string;
[EXPIRES_KEY]: number;
[USER_ID_KEY]?: string;
[SEGMENT_KEY]?: string;
}
export async function isAuthenticated(request: Request, env: Env) {
const { session, sessionStorage } = await getSession(request, env);
const token = session.get('refresh');
const sessionData: SessionData | null = await decryptSessionData(env, session.get('d'));
const header = async (cookie: Promise<string>) => ({ headers: { 'Set-Cookie': await cookie } });
const destroy = () => header(sessionStorage.destroySession(session));
if (token == null) {
if (sessionData?.[TOKEN_KEY] == null) {
return { authenticated: false as const, response: await destroy() };
}
const expiresAt = session.get('expiresAt') ?? 0;
const expiresAt = sessionData[EXPIRES_KEY] ?? 0;
if (Date.now() < expiresAt) {
return { authenticated: true as const };
}
logger.debug('Renewing token');
let data: Awaited<ReturnType<typeof refreshToken>> | null = null;
try {
data = await refreshToken(token);
} catch {
data = await refreshToken(sessionData[TOKEN_KEY]);
} catch (error) {
// we can ignore the error here because it's handled below
logger.error(error);
}
if (data != null) {
const expiresAt = cookieExpiration(data.expires_in, data.created_at);
session.set('expiresAt', expiresAt);
const newSessionData = { ...sessionData, [EXPIRES_KEY]: expiresAt };
const encryptedData = await encryptSessionData(env, newSessionData);
session.set('d', encryptedData);
return { authenticated: true as const, response: await header(sessionStorage.commitSession(session)) };
} else {
@ -59,13 +74,15 @@ export async function createUserSession(
const expiresAt = cookieExpiration(tokens.expires_in, tokens.created_at);
session.set('refresh', tokens.refresh);
session.set('expiresAt', expiresAt);
const sessionData: SessionData = {
[TOKEN_KEY]: tokens.refresh,
[EXPIRES_KEY]: expiresAt,
[USER_ID_KEY]: identity?.userId ?? undefined,
[SEGMENT_KEY]: identity?.segmentWriteKey ?? undefined,
};
if (identity) {
session.set('userId', identity.userId ?? null);
session.set('segmentWriteKey', identity.segmentWriteKey ?? null);
}
const encryptedData = await encryptSessionData(env, sessionData);
session.set('d', encryptedData);
return {
headers: {
@ -77,7 +94,7 @@ export async function createUserSession(
}
function getSessionStorage(cloudflareEnv: Env) {
return createCookieSessionStorage<SessionData>({
return createCookieSessionStorage<{ d: string }>({
cookie: {
name: '__session',
httpOnly: true,
@ -91,7 +108,11 @@ function getSessionStorage(cloudflareEnv: Env) {
export async function logout(request: Request, env: Env) {
const { session, sessionStorage } = await getSession(request, env);
revokeToken(session.get('refresh'));
const sessionData = await decryptSessionData(env, session.get('d'));
if (sessionData) {
revokeToken(sessionData[TOKEN_KEY]);
}
return redirect('/login', {
headers: {
@ -106,7 +127,18 @@ export function validateAccessToken(access: string) {
return jwtPayload.bolt === true;
}
export async function getSession(request: Request, env: Env) {
export async function getSessionData(request: Request, env: Env) {
const { session } = await getSession(request, env);
const decrypted = await decryptSessionData(env, session.get('d'));
return {
userId: decrypted?.[USER_ID_KEY],
segmentWriteKey: decrypted?.[SEGMENT_KEY],
};
}
async function getSession(request: Request, env: Env) {
const sessionStorage = getSessionStorage(env);
const cookie = request.headers.get('Cookie');
@ -117,12 +149,15 @@ async function refreshToken(refresh: string): Promise<{ expires_in: number; crea
const response = await doRequest(`${CLIENT_ORIGIN}/oauth/token`, {
method: 'POST',
body: urlParams({ grant_type: 'refresh_token', client_id: CLIENT_ID, refresh_token: refresh }),
headers: {
'content-type': 'application/x-www-form-urlencoded',
},
});
const body = await response.json();
if (!response.ok) {
throw new Error(`Unable to refresh token\n${JSON.stringify(body)}`);
throw new Error(`Unable to refresh token\n${response.status} ${JSON.stringify(body)}`);
}
const { access_token: access } = body;
@ -151,6 +186,9 @@ async function revokeToken(refresh?: string) {
token_type_hint: 'refresh_token',
client_id: CLIENT_ID,
}),
headers: {
'content-type': 'application/x-www-form-urlencoded',
},
});
if (!response.ok) {
@ -171,3 +209,18 @@ function urlParams(data: Record<string, string>) {
return encoded;
}
async function decryptSessionData(env: Env, encryptedData?: string) {
const decryptedData = encryptedData ? await decrypt(payloadSecret(env), encryptedData) : undefined;
const sessionData: SessionData | null = JSON.parse(decryptedData ?? 'null');
return sessionData;
}
async function encryptSessionData(env: Env, sessionData: SessionData) {
return await encrypt(payloadSecret(env), JSON.stringify(sessionData));
}
function payloadSecret(env: Env) {
return DEV_PAYLOAD_SECRET || env.PAYLOAD_SECRET;
}

View File

@ -0,0 +1,58 @@
const encoder = new TextEncoder();
const decoder = new TextDecoder();
const IV_LENGTH = 16;
export async function encrypt(key: string, data: string) {
const iv = crypto.getRandomValues(new Uint8Array(IV_LENGTH));
const cryptoKey = await getKey(key);
const ciphertext = await crypto.subtle.encrypt(
{
name: 'AES-CBC',
iv,
},
cryptoKey,
encoder.encode(data),
);
const bundle = new Uint8Array(IV_LENGTH + ciphertext.byteLength);
bundle.set(new Uint8Array(ciphertext));
bundle.set(iv, ciphertext.byteLength);
return decodeBase64(bundle);
}
export async function decrypt(key: string, payload: string) {
const bundle = encodeBase64(payload);
const iv = new Uint8Array(bundle.buffer, bundle.byteLength - IV_LENGTH);
const ciphertext = new Uint8Array(bundle.buffer, 0, bundle.byteLength - IV_LENGTH);
const cryptoKey = await getKey(key);
const plaintext = await crypto.subtle.decrypt(
{
name: 'AES-CBC',
iv,
},
cryptoKey,
ciphertext,
);
return decoder.decode(plaintext);
}
async function getKey(key: string) {
return await crypto.subtle.importKey('raw', encodeBase64(key), { name: 'AES-CBC' }, false, ['encrypt', 'decrypt']);
}
function decodeBase64(encoded: Uint8Array) {
const byteChars = Array.from(encoded, (byte) => String.fromCodePoint(byte));
return btoa(byteChars.join(''));
}
function encodeBase64(data: string) {
return Uint8Array.from(atob(data), (ch) => ch.codePointAt(0)!);
}

View File

@ -1,12 +1,12 @@
import { json, type ActionFunctionArgs } from '@remix-run/cloudflare';
import { handleWithAuth } from '~/lib/.server/login';
import { getSession } from '~/lib/.server/sessions';
import { getSessionData } from '~/lib/.server/sessions';
import { sendEventInternal, type AnalyticsEvent } from '~/lib/analytics';
async function analyticsAction({ request, context }: ActionFunctionArgs) {
const event: AnalyticsEvent = await request.json();
const { session } = await getSession(request, context.cloudflare.env);
const { success, error } = await sendEventInternal(session.data, event);
const sessionData = await getSessionData(request, context.cloudflare.env);
const { success, error } = await sendEventInternal(sessionData, event);
if (!success) {
return json({ error }, { status: 500 });

View File

@ -4,7 +4,7 @@ import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts';
import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text';
import SwitchableStream from '~/lib/.server/llm/switchable-stream';
import { handleWithAuth } from '~/lib/.server/login';
import { getSession } from '~/lib/.server/sessions';
import { getSessionData } from '~/lib/.server/sessions';
import { AnalyticsAction, AnalyticsTrackEvent, sendEventInternal } from '~/lib/analytics';
export async function action(args: ActionFunctionArgs) {
@ -21,9 +21,9 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
toolChoice: 'none',
onFinish: async ({ text: content, finishReason, usage }) => {
if (finishReason !== 'length') {
const { session } = await getSession(request, context.cloudflare.env);
const sessionData = await getSessionData(request, context.cloudflare.env);
await sendEventInternal(session.data, {
await sendEventInternal(sessionData, {
action: AnalyticsAction.Track,
payload: {
event: AnalyticsTrackEvent.MessageComplete,

View File

@ -13,6 +13,9 @@ interface Logger {
let currentLevel: DebugLevel = import.meta.env.VITE_LOG_LEVEL ?? import.meta.env.DEV ? 'debug' : 'info';
const isWorker = 'HTMLRewriter' in globalThis;
const supportsColor = !isWorker;
export const logger: Logger = {
trace: (...messages: any[]) => log('trace', undefined, messages),
debug: (...messages: any[]) => log('debug', undefined, messages),
@ -44,7 +47,28 @@ function setLevel(level: DebugLevel) {
function log(level: DebugLevel, scope: string | undefined, messages: any[]) {
const levelOrder: DebugLevel[] = ['trace', 'debug', 'info', 'warn', 'error'];
if (levelOrder.indexOf(level) >= levelOrder.indexOf(currentLevel)) {
if (levelOrder.indexOf(level) < levelOrder.indexOf(currentLevel)) {
return;
}
const allMessages = messages.reduce((acc, current) => {
if (acc.endsWith('\n')) {
return acc + current;
}
if (!acc) {
return current;
}
return `${acc} ${current}`;
}, '');
if (!supportsColor) {
console.log(`[${level.toUpperCase()}]`, allMessages);
return;
}
const labelBackgroundColor = getColorForLevel(level);
const labelTextColor = level === 'warn' ? 'black' : 'white';
@ -57,22 +81,7 @@ function log(level: DebugLevel, scope: string | undefined, messages: any[]) {
styles.push('', scopeStyles);
}
console.log(
`%c${level.toUpperCase()}${scope ? `%c %c${scope}` : ''}`,
...styles,
messages.reduce((acc, current) => {
if (acc.endsWith('\n')) {
return acc + current;
}
if (!acc) {
return current;
}
return `${acc} ${current}`;
}, ''),
);
}
console.log(`%c${level.toUpperCase()}${scope ? `%c %c${scope}` : ''}`, ...styles, allMessages);
}
function getLabelStyles(color: string, textColor: string) {

View File

@ -1,5 +1,5 @@
interface Env {
ANTHROPIC_API_KEY: string;
SESSION_SECRET: string;
LOGIN_PASSWORD: string;
PAYLOAD_SECRET: string;
}