fix: cors issues

This commit is contained in:
abdou6666 2025-01-20 14:34:25 +01:00
parent c5b53302b5
commit 3ab7b56d29
5 changed files with 156 additions and 68 deletions

View File

@ -47,25 +47,10 @@ async function bootstrap() {
// Retrieve the SettingService instance
const settingService = app.get<SettingService>(SettingService);
// Fetch allowed domains from the settings collection
const settingsAllowedDomains = await settingService.getAllowedDomains();
// Get allowed origins from .env configuration
const configAllowedOrigins = config.security.cors.allowOrigins
? config.security.cors.allowOrigins.map((origin) => origin.trim())
: [];
// Combine both settings and config allowed domains
const combinedAllowedDomains = [
...settingsAllowedDomains,
...configAllowedOrigins,
];
const allowedDomains = Array.from(new Set(combinedAllowedDomains));
// Enable CORS with the combined allowed domains
app.enableCors({
origin: (origin, callback) => {
if (!origin || allowedDomains.includes(origin)) {
if (!origin || settingService.isOriginAllowed(origin)) {
callback(null, true);
} else {
callback(new Error('Not allowed by CORS'));

View File

@ -7,7 +7,7 @@
*/
import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { Inject, Injectable } from '@nestjs/common';
import { Inject, Injectable, OnModuleInit } from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter';
import { InjectModel } from '@nestjs/mongoose';
import { Cache } from 'cache-manager';
@ -25,8 +25,16 @@ import { SettingRepository } from '../repositories/setting.repository';
import { Setting } from '../schemas/setting.schema';
import { SettingSeeder } from '../seeds/setting.seed';
//TODO : change to enum?
type Channels = 'console-channel' | 'web-channel';
@Injectable()
export class SettingService extends BaseService<Setting> {
export class SettingService
extends BaseService<Setting>
implements OnModuleInit
{
private allowedOrigins: Map<Channels, Set<string>> = new Map();
constructor(
readonly repository: SettingRepository,
@Inject(CACHE_MANAGER) private readonly cacheManager: Cache,
@ -35,6 +43,75 @@ export class SettingService extends BaseService<Setting> {
@InjectModel(Setting.name) private settingModel: Model<Setting>,
) {
super(repository);
const origins: Channels[] = ['console-channel', 'web-channel'];
origins.forEach((channelType) => {
this.allowedOrigins.set(channelType, new Set<string>());
});
}
async onModuleInit() {
try {
//TODO: refactor into initialize methods
const webChannelAllowedDomains = await this.find({
group: 'web_channel',
label: 'allowed_domains',
});
const consoleChannelAllowedDomains = await this.find({
group: 'console_channel',
label: 'allowed_domains',
});
webChannelAllowedDomains.forEach((webChannelSettings) => {
(webChannelSettings.value.split(',') || []).forEach(
(allowedDomain: string) => {
this.allowedOrigins.get('web-channel').add(allowedDomain);
},
);
});
consoleChannelAllowedDomains.forEach((consoleChannelSettings) => {
(consoleChannelSettings.value.split(',') || []).forEach(
(allowedDomain: string) => {
this.allowedOrigins.get('console-channel').add(allowedDomain);
},
);
});
this.logger.log('allowed domains initialiazed successfully');
} catch (error) {
this.logger.error('Failed to initialiazed allowed domains', error);
}
}
@OnEvent('hook:web_channel:allowed_domains')
handleUpdateWebChannelAllowedDomains(settings: Setting) {
this.allowedOrigins.get('web-channel').clear();
(settings.value.split(',') || []).forEach((allowedDomain: string) => {
this.allowedOrigins.get('web-channel').add(allowedDomain);
});
}
@OnEvent('hook:console_channel:allowed_domains')
handleUpdateConsoleChannelAllowedDomains(settings: Setting) {
this.allowedOrigins.get('console-channel').clear();
(settings.value.split(',') || []).forEach((allowedDomain: string) => {
this.allowedOrigins.get('console-channel').add(allowedDomain);
});
}
private isOriginAllowedConsoleChannel(requesterOrigin: string) {
return this.allowedOrigins.get('console-channel').has(requesterOrigin);
}
private isOriginAllowedWebChannel(requesterOrigin: string) {
return this.allowedOrigins.get('web-channel').has(requesterOrigin);
}
public isOriginAllowed(requesterOrigin: string) {
return (
this.isOriginAllowedConsoleChannel(requesterOrigin) ||
this.isOriginAllowedWebChannel(requesterOrigin)
);
}
/**

View File

@ -1,5 +1,5 @@
/*
* Copyright © 2024 Hexastack. All rights reserved.
* Copyright © 2025 Hexastack. All rights reserved.
*
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
@ -8,65 +8,72 @@
import util from 'util';
import { Injectable } from '@nestjs/common';
import type { ServerOptions } from 'socket.io';
import { config } from '@/config';
import { SettingService } from '@/setting/services/setting.service';
export const buildWebSocketGatewayOptions = (): Partial<ServerOptions> => {
const opts: Partial<ServerOptions> = {
allowEIO3: true, // Allows support for Engine.io v3 clients.
path: config.sockets.path,
...(typeof config.sockets.serveClient !== 'undefined' && {
serveClient: config.sockets.serveClient,
}),
...(config.sockets.beforeConnect && {
allowRequest: (handshake, cb) => {
try {
const result = config.sockets.beforeConnect(handshake);
return cb(null, result);
} catch (e) {
// eslint-disable-next-line no-console
console.log(
`A socket was rejected via the config.sockets.beforeConnect function.\n` +
`It attempted to connect with headers:\n` +
`${util.inspect(handshake.headers, { depth: null })}\n` +
`Details: ${e}`,
);
return cb(e, false);
}
},
}),
...(config.sockets.pingTimeout && {
pingTimeout: config.sockets.pingTimeout,
}),
...(config.sockets.pingInterval && {
pingInterval: config.sockets.pingInterval,
}),
...(config.sockets.maxHttpBufferSize && {
maxHttpBufferSize: config.sockets.maxHttpBufferSize,
}),
...(config.sockets.transports && { transports: config.sockets.transports }),
...(config.sockets.allowUpgrades && {
allowUpgrades: config.sockets.allowUpgrades,
}),
...(config.sockets.cookie && { cookie: config.sockets.cookie }),
...(config.sockets.onlyAllowOrigins && {
@Injectable()
export class WebSocketGatewayOptionsService {
constructor(private readonly settingsService: SettingService) {}
buildWebSocketGatewayOptions(): Partial<ServerOptions> {
const opts: Partial<ServerOptions> = {
allowEIO3: true, // Allows support for Engine.io v3 clients.
path: config.sockets.path,
...(typeof config.sockets.serveClient !== 'undefined' && {
serveClient: config.sockets.serveClient,
}),
...(config.sockets.beforeConnect && {
allowRequest: (handshake, cb) => {
try {
const result = config.sockets.beforeConnect(handshake);
return cb(null, result);
} catch (e) {
// eslint-disable-next-line no-console
console.log(
`A socket was rejected via the config.sockets.beforeConnect function.\n` +
`It attempted to connect with headers:\n` +
`${util.inspect(handshake.headers, { depth: null })}\n` +
`Details: ${e}`,
);
return cb(e, false);
}
},
}),
...(config.sockets.pingTimeout && {
pingTimeout: config.sockets.pingTimeout,
}),
...(config.sockets.pingInterval && {
pingInterval: config.sockets.pingInterval,
}),
...(config.sockets.maxHttpBufferSize && {
maxHttpBufferSize: config.sockets.maxHttpBufferSize,
}),
...(config.sockets.transports && {
transports: config.sockets.transports,
}),
...(config.sockets.allowUpgrades && {
allowUpgrades: config.sockets.allowUpgrades,
}),
...(config.sockets.cookie && { cookie: config.sockets.cookie }),
cors: {
origin: (origin, cb) => {
if (origin && config.sockets.onlyAllowOrigins.includes(origin)) {
if (origin && this.settingsService.isOriginAllowed(origin)) {
cb(null, true);
} else {
// eslint-disable-next-line no-console
console.log(
`A socket was rejected via the config.sockets.onlyAllowOrigins array.\n` +
`A socket was rejected via the SettingsService.allowedOriginDomains array.\n` +
`It attempted to connect with origin: ${origin}`,
);
cb(new Error('Origin not allowed'), false);
}
},
},
}),
};
};
return opts;
};
return opts;
}
}

View File

@ -35,25 +35,36 @@ import { config } from '@/config';
import { LoggerService } from '@/logger/logger.service';
import { getSessionStore } from '@/utils/constants/session-store';
import { OnModuleInit } from '@nestjs/common';
import { IOIncomingMessage, IOMessagePipe } from './pipes/io-message.pipe';
import { SocketEventDispatcherService } from './services/socket-event-dispatcher.service';
import { Room } from './types';
import { buildWebSocketGatewayOptions } from './utils/gateway-options';
import { WebSocketGatewayOptionsService } from './utils/gateway-options';
import { SocketRequest } from './utils/socket-request';
import { SocketResponse } from './utils/socket-response';
@WebSocketGateway(buildWebSocketGatewayOptions())
@WebSocketGateway()
export class WebsocketGateway
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
implements
OnGatewayInit,
OnGatewayConnection,
OnGatewayDisconnect,
OnModuleInit
{
constructor(
private readonly logger: LoggerService,
private readonly eventEmitter: EventEmitter2,
private readonly socketEventDispatcherService: SocketEventDispatcherService,
private readonly gatewayOptionsService: WebSocketGatewayOptionsService,
) {}
@WebSocketServer() io: Server;
onModuleInit() {
const options = this.gatewayOptionsService.buildWebSocketGatewayOptions();
this.io = new Server(options);
}
broadcastMessageSent(message: OutgoingMessage): void {
this.io.to(Room.MESSAGE).emit('message', {
op: 'messageSent',

View File

@ -8,12 +8,20 @@
import { Global, Module } from '@nestjs/common';
import { SettingModule } from '@/setting/setting.module';
import { SocketEventDispatcherService } from './services/socket-event-dispatcher.service';
import { WebSocketGatewayOptionsService } from './utils/gateway-options';
import { WebsocketGateway } from './websocket.gateway';
@Global()
@Module({
providers: [WebsocketGateway, SocketEventDispatcherService],
imports: [SettingModule],
providers: [
WebsocketGateway,
SocketEventDispatcherService,
WebSocketGatewayOptionsService,
],
exports: [WebsocketGateway],
})
export class WebsocketModule {}