From 3ab7b56d299cd4121f29f053108e28d78a11a409 Mon Sep 17 00:00:00 2001 From: abdou6666 Date: Mon, 20 Jan 2025 14:34:25 +0100 Subject: [PATCH] fix: cors issues --- api/src/main.ts | 17 +--- api/src/setting/services/setting.service.ts | 81 ++++++++++++++++- api/src/websocket/utils/gateway-options.ts | 99 +++++++++++---------- api/src/websocket/websocket.gateway.ts | 17 +++- api/src/websocket/websocket.module.ts | 10 ++- 5 files changed, 156 insertions(+), 68 deletions(-) diff --git a/api/src/main.ts b/api/src/main.ts index 48e61a04..1b0c7095 100644 --- a/api/src/main.ts +++ b/api/src/main.ts @@ -47,25 +47,10 @@ async function bootstrap() { // Retrieve the SettingService instance const settingService = app.get(SettingService); - // Fetch allowed domains from the settings collection - const settingsAllowedDomains = await settingService.getAllowedDomains(); - - // Get allowed origins from .env configuration - const configAllowedOrigins = config.security.cors.allowOrigins - ? config.security.cors.allowOrigins.map((origin) => origin.trim()) - : []; - - // Combine both settings and config allowed domains - const combinedAllowedDomains = [ - ...settingsAllowedDomains, - ...configAllowedOrigins, - ]; - const allowedDomains = Array.from(new Set(combinedAllowedDomains)); - // Enable CORS with the combined allowed domains app.enableCors({ origin: (origin, callback) => { - if (!origin || allowedDomains.includes(origin)) { + if (!origin || settingService.isOriginAllowed(origin)) { callback(null, true); } else { callback(new Error('Not allowed by CORS')); diff --git a/api/src/setting/services/setting.service.ts b/api/src/setting/services/setting.service.ts index 033bcd7b..829a3a4d 100644 --- a/api/src/setting/services/setting.service.ts +++ b/api/src/setting/services/setting.service.ts @@ -7,7 +7,7 @@ */ import { CACHE_MANAGER } from '@nestjs/cache-manager'; -import { Inject, Injectable } from '@nestjs/common'; +import { Inject, Injectable, OnModuleInit } from '@nestjs/common'; import { OnEvent } from '@nestjs/event-emitter'; import { InjectModel } from '@nestjs/mongoose'; import { Cache } from 'cache-manager'; @@ -25,8 +25,16 @@ import { SettingRepository } from '../repositories/setting.repository'; import { Setting } from '../schemas/setting.schema'; import { SettingSeeder } from '../seeds/setting.seed'; +//TODO : change to enum? +type Channels = 'console-channel' | 'web-channel'; + @Injectable() -export class SettingService extends BaseService { +export class SettingService + extends BaseService + implements OnModuleInit +{ + private allowedOrigins: Map> = new Map(); + constructor( readonly repository: SettingRepository, @Inject(CACHE_MANAGER) private readonly cacheManager: Cache, @@ -35,6 +43,75 @@ export class SettingService extends BaseService { @InjectModel(Setting.name) private settingModel: Model, ) { super(repository); + const origins: Channels[] = ['console-channel', 'web-channel']; + origins.forEach((channelType) => { + this.allowedOrigins.set(channelType, new Set()); + }); + } + + async onModuleInit() { + try { + //TODO: refactor into initialize methods + const webChannelAllowedDomains = await this.find({ + group: 'web_channel', + label: 'allowed_domains', + }); + const consoleChannelAllowedDomains = await this.find({ + group: 'console_channel', + label: 'allowed_domains', + }); + + webChannelAllowedDomains.forEach((webChannelSettings) => { + (webChannelSettings.value.split(',') || []).forEach( + (allowedDomain: string) => { + this.allowedOrigins.get('web-channel').add(allowedDomain); + }, + ); + }); + + consoleChannelAllowedDomains.forEach((consoleChannelSettings) => { + (consoleChannelSettings.value.split(',') || []).forEach( + (allowedDomain: string) => { + this.allowedOrigins.get('console-channel').add(allowedDomain); + }, + ); + }); + this.logger.log('allowed domains initialiazed successfully'); + } catch (error) { + this.logger.error('Failed to initialiazed allowed domains', error); + } + } + + @OnEvent('hook:web_channel:allowed_domains') + handleUpdateWebChannelAllowedDomains(settings: Setting) { + this.allowedOrigins.get('web-channel').clear(); + (settings.value.split(',') || []).forEach((allowedDomain: string) => { + this.allowedOrigins.get('web-channel').add(allowedDomain); + }); + } + + @OnEvent('hook:console_channel:allowed_domains') + handleUpdateConsoleChannelAllowedDomains(settings: Setting) { + this.allowedOrigins.get('console-channel').clear(); + + (settings.value.split(',') || []).forEach((allowedDomain: string) => { + this.allowedOrigins.get('console-channel').add(allowedDomain); + }); + } + + private isOriginAllowedConsoleChannel(requesterOrigin: string) { + return this.allowedOrigins.get('console-channel').has(requesterOrigin); + } + + private isOriginAllowedWebChannel(requesterOrigin: string) { + return this.allowedOrigins.get('web-channel').has(requesterOrigin); + } + + public isOriginAllowed(requesterOrigin: string) { + return ( + this.isOriginAllowedConsoleChannel(requesterOrigin) || + this.isOriginAllowedWebChannel(requesterOrigin) + ); } /** diff --git a/api/src/websocket/utils/gateway-options.ts b/api/src/websocket/utils/gateway-options.ts index 1b4e9ce3..dbe1431c 100644 --- a/api/src/websocket/utils/gateway-options.ts +++ b/api/src/websocket/utils/gateway-options.ts @@ -1,5 +1,5 @@ /* - * Copyright © 2024 Hexastack. All rights reserved. + * Copyright © 2025 Hexastack. All rights reserved. * * Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms: * 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission. @@ -8,65 +8,72 @@ import util from 'util'; +import { Injectable } from '@nestjs/common'; import type { ServerOptions } from 'socket.io'; import { config } from '@/config'; +import { SettingService } from '@/setting/services/setting.service'; -export const buildWebSocketGatewayOptions = (): Partial => { - const opts: Partial = { - allowEIO3: true, // Allows support for Engine.io v3 clients. - path: config.sockets.path, - ...(typeof config.sockets.serveClient !== 'undefined' && { - serveClient: config.sockets.serveClient, - }), - ...(config.sockets.beforeConnect && { - allowRequest: (handshake, cb) => { - try { - const result = config.sockets.beforeConnect(handshake); - return cb(null, result); - } catch (e) { - // eslint-disable-next-line no-console - console.log( - `A socket was rejected via the config.sockets.beforeConnect function.\n` + - `It attempted to connect with headers:\n` + - `${util.inspect(handshake.headers, { depth: null })}\n` + - `Details: ${e}`, - ); - return cb(e, false); - } - }, - }), - ...(config.sockets.pingTimeout && { - pingTimeout: config.sockets.pingTimeout, - }), - ...(config.sockets.pingInterval && { - pingInterval: config.sockets.pingInterval, - }), - ...(config.sockets.maxHttpBufferSize && { - maxHttpBufferSize: config.sockets.maxHttpBufferSize, - }), - ...(config.sockets.transports && { transports: config.sockets.transports }), - ...(config.sockets.allowUpgrades && { - allowUpgrades: config.sockets.allowUpgrades, - }), - ...(config.sockets.cookie && { cookie: config.sockets.cookie }), - ...(config.sockets.onlyAllowOrigins && { +@Injectable() +export class WebSocketGatewayOptionsService { + constructor(private readonly settingsService: SettingService) {} + + buildWebSocketGatewayOptions(): Partial { + const opts: Partial = { + allowEIO3: true, // Allows support for Engine.io v3 clients. + path: config.sockets.path, + ...(typeof config.sockets.serveClient !== 'undefined' && { + serveClient: config.sockets.serveClient, + }), + ...(config.sockets.beforeConnect && { + allowRequest: (handshake, cb) => { + try { + const result = config.sockets.beforeConnect(handshake); + return cb(null, result); + } catch (e) { + // eslint-disable-next-line no-console + console.log( + `A socket was rejected via the config.sockets.beforeConnect function.\n` + + `It attempted to connect with headers:\n` + + `${util.inspect(handshake.headers, { depth: null })}\n` + + `Details: ${e}`, + ); + return cb(e, false); + } + }, + }), + ...(config.sockets.pingTimeout && { + pingTimeout: config.sockets.pingTimeout, + }), + ...(config.sockets.pingInterval && { + pingInterval: config.sockets.pingInterval, + }), + ...(config.sockets.maxHttpBufferSize && { + maxHttpBufferSize: config.sockets.maxHttpBufferSize, + }), + ...(config.sockets.transports && { + transports: config.sockets.transports, + }), + ...(config.sockets.allowUpgrades && { + allowUpgrades: config.sockets.allowUpgrades, + }), + ...(config.sockets.cookie && { cookie: config.sockets.cookie }), cors: { origin: (origin, cb) => { - if (origin && config.sockets.onlyAllowOrigins.includes(origin)) { + if (origin && this.settingsService.isOriginAllowed(origin)) { cb(null, true); } else { // eslint-disable-next-line no-console console.log( - `A socket was rejected via the config.sockets.onlyAllowOrigins array.\n` + + `A socket was rejected via the SettingsService.allowedOriginDomains array.\n` + `It attempted to connect with origin: ${origin}`, ); cb(new Error('Origin not allowed'), false); } }, }, - }), - }; + }; - return opts; -}; + return opts; + } +} diff --git a/api/src/websocket/websocket.gateway.ts b/api/src/websocket/websocket.gateway.ts index 10c7c628..5e3c1cdc 100644 --- a/api/src/websocket/websocket.gateway.ts +++ b/api/src/websocket/websocket.gateway.ts @@ -35,25 +35,36 @@ import { config } from '@/config'; import { LoggerService } from '@/logger/logger.service'; import { getSessionStore } from '@/utils/constants/session-store'; +import { OnModuleInit } from '@nestjs/common'; import { IOIncomingMessage, IOMessagePipe } from './pipes/io-message.pipe'; import { SocketEventDispatcherService } from './services/socket-event-dispatcher.service'; import { Room } from './types'; -import { buildWebSocketGatewayOptions } from './utils/gateway-options'; +import { WebSocketGatewayOptionsService } from './utils/gateway-options'; import { SocketRequest } from './utils/socket-request'; import { SocketResponse } from './utils/socket-response'; -@WebSocketGateway(buildWebSocketGatewayOptions()) +@WebSocketGateway() export class WebsocketGateway - implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect + implements + OnGatewayInit, + OnGatewayConnection, + OnGatewayDisconnect, + OnModuleInit { constructor( private readonly logger: LoggerService, private readonly eventEmitter: EventEmitter2, private readonly socketEventDispatcherService: SocketEventDispatcherService, + private readonly gatewayOptionsService: WebSocketGatewayOptionsService, ) {} @WebSocketServer() io: Server; + onModuleInit() { + const options = this.gatewayOptionsService.buildWebSocketGatewayOptions(); + this.io = new Server(options); + } + broadcastMessageSent(message: OutgoingMessage): void { this.io.to(Room.MESSAGE).emit('message', { op: 'messageSent', diff --git a/api/src/websocket/websocket.module.ts b/api/src/websocket/websocket.module.ts index f2908330..b1c3f7e3 100644 --- a/api/src/websocket/websocket.module.ts +++ b/api/src/websocket/websocket.module.ts @@ -8,12 +8,20 @@ import { Global, Module } from '@nestjs/common'; +import { SettingModule } from '@/setting/setting.module'; + import { SocketEventDispatcherService } from './services/socket-event-dispatcher.service'; +import { WebSocketGatewayOptionsService } from './utils/gateway-options'; import { WebsocketGateway } from './websocket.gateway'; @Global() @Module({ - providers: [WebsocketGateway, SocketEventDispatcherService], + imports: [SettingModule], + providers: [ + WebsocketGateway, + SocketEventDispatcherService, + WebSocketGatewayOptionsService, + ], exports: [WebsocketGateway], }) export class WebsocketModule {}