mirror of
https://github.com/hexastack/hexabot
synced 2025-01-22 10:35:37 +00:00
fix: cors issues
This commit is contained in:
parent
c5b53302b5
commit
3ab7b56d29
@ -47,25 +47,10 @@ async function bootstrap() {
|
||||
// Retrieve the SettingService instance
|
||||
const settingService = app.get<SettingService>(SettingService);
|
||||
|
||||
// Fetch allowed domains from the settings collection
|
||||
const settingsAllowedDomains = await settingService.getAllowedDomains();
|
||||
|
||||
// Get allowed origins from .env configuration
|
||||
const configAllowedOrigins = config.security.cors.allowOrigins
|
||||
? config.security.cors.allowOrigins.map((origin) => origin.trim())
|
||||
: [];
|
||||
|
||||
// Combine both settings and config allowed domains
|
||||
const combinedAllowedDomains = [
|
||||
...settingsAllowedDomains,
|
||||
...configAllowedOrigins,
|
||||
];
|
||||
const allowedDomains = Array.from(new Set(combinedAllowedDomains));
|
||||
|
||||
// Enable CORS with the combined allowed domains
|
||||
app.enableCors({
|
||||
origin: (origin, callback) => {
|
||||
if (!origin || allowedDomains.includes(origin)) {
|
||||
if (!origin || settingService.isOriginAllowed(origin)) {
|
||||
callback(null, true);
|
||||
} else {
|
||||
callback(new Error('Not allowed by CORS'));
|
||||
|
@ -7,7 +7,7 @@
|
||||
*/
|
||||
|
||||
import { CACHE_MANAGER } from '@nestjs/cache-manager';
|
||||
import { Inject, Injectable } from '@nestjs/common';
|
||||
import { Inject, Injectable, OnModuleInit } from '@nestjs/common';
|
||||
import { OnEvent } from '@nestjs/event-emitter';
|
||||
import { InjectModel } from '@nestjs/mongoose';
|
||||
import { Cache } from 'cache-manager';
|
||||
@ -25,8 +25,16 @@ import { SettingRepository } from '../repositories/setting.repository';
|
||||
import { Setting } from '../schemas/setting.schema';
|
||||
import { SettingSeeder } from '../seeds/setting.seed';
|
||||
|
||||
//TODO : change to enum?
|
||||
type Channels = 'console-channel' | 'web-channel';
|
||||
|
||||
@Injectable()
|
||||
export class SettingService extends BaseService<Setting> {
|
||||
export class SettingService
|
||||
extends BaseService<Setting>
|
||||
implements OnModuleInit
|
||||
{
|
||||
private allowedOrigins: Map<Channels, Set<string>> = new Map();
|
||||
|
||||
constructor(
|
||||
readonly repository: SettingRepository,
|
||||
@Inject(CACHE_MANAGER) private readonly cacheManager: Cache,
|
||||
@ -35,6 +43,75 @@ export class SettingService extends BaseService<Setting> {
|
||||
@InjectModel(Setting.name) private settingModel: Model<Setting>,
|
||||
) {
|
||||
super(repository);
|
||||
const origins: Channels[] = ['console-channel', 'web-channel'];
|
||||
origins.forEach((channelType) => {
|
||||
this.allowedOrigins.set(channelType, new Set<string>());
|
||||
});
|
||||
}
|
||||
|
||||
async onModuleInit() {
|
||||
try {
|
||||
//TODO: refactor into initialize methods
|
||||
const webChannelAllowedDomains = await this.find({
|
||||
group: 'web_channel',
|
||||
label: 'allowed_domains',
|
||||
});
|
||||
const consoleChannelAllowedDomains = await this.find({
|
||||
group: 'console_channel',
|
||||
label: 'allowed_domains',
|
||||
});
|
||||
|
||||
webChannelAllowedDomains.forEach((webChannelSettings) => {
|
||||
(webChannelSettings.value.split(',') || []).forEach(
|
||||
(allowedDomain: string) => {
|
||||
this.allowedOrigins.get('web-channel').add(allowedDomain);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
consoleChannelAllowedDomains.forEach((consoleChannelSettings) => {
|
||||
(consoleChannelSettings.value.split(',') || []).forEach(
|
||||
(allowedDomain: string) => {
|
||||
this.allowedOrigins.get('console-channel').add(allowedDomain);
|
||||
},
|
||||
);
|
||||
});
|
||||
this.logger.log('allowed domains initialiazed successfully');
|
||||
} catch (error) {
|
||||
this.logger.error('Failed to initialiazed allowed domains', error);
|
||||
}
|
||||
}
|
||||
|
||||
@OnEvent('hook:web_channel:allowed_domains')
|
||||
handleUpdateWebChannelAllowedDomains(settings: Setting) {
|
||||
this.allowedOrigins.get('web-channel').clear();
|
||||
(settings.value.split(',') || []).forEach((allowedDomain: string) => {
|
||||
this.allowedOrigins.get('web-channel').add(allowedDomain);
|
||||
});
|
||||
}
|
||||
|
||||
@OnEvent('hook:console_channel:allowed_domains')
|
||||
handleUpdateConsoleChannelAllowedDomains(settings: Setting) {
|
||||
this.allowedOrigins.get('console-channel').clear();
|
||||
|
||||
(settings.value.split(',') || []).forEach((allowedDomain: string) => {
|
||||
this.allowedOrigins.get('console-channel').add(allowedDomain);
|
||||
});
|
||||
}
|
||||
|
||||
private isOriginAllowedConsoleChannel(requesterOrigin: string) {
|
||||
return this.allowedOrigins.get('console-channel').has(requesterOrigin);
|
||||
}
|
||||
|
||||
private isOriginAllowedWebChannel(requesterOrigin: string) {
|
||||
return this.allowedOrigins.get('web-channel').has(requesterOrigin);
|
||||
}
|
||||
|
||||
public isOriginAllowed(requesterOrigin: string) {
|
||||
return (
|
||||
this.isOriginAllowedConsoleChannel(requesterOrigin) ||
|
||||
this.isOriginAllowedWebChannel(requesterOrigin)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright © 2024 Hexastack. All rights reserved.
|
||||
* Copyright © 2025 Hexastack. All rights reserved.
|
||||
*
|
||||
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
|
||||
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
|
||||
@ -8,65 +8,72 @@
|
||||
|
||||
import util from 'util';
|
||||
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import type { ServerOptions } from 'socket.io';
|
||||
|
||||
import { config } from '@/config';
|
||||
import { SettingService } from '@/setting/services/setting.service';
|
||||
|
||||
export const buildWebSocketGatewayOptions = (): Partial<ServerOptions> => {
|
||||
const opts: Partial<ServerOptions> = {
|
||||
allowEIO3: true, // Allows support for Engine.io v3 clients.
|
||||
path: config.sockets.path,
|
||||
...(typeof config.sockets.serveClient !== 'undefined' && {
|
||||
serveClient: config.sockets.serveClient,
|
||||
}),
|
||||
...(config.sockets.beforeConnect && {
|
||||
allowRequest: (handshake, cb) => {
|
||||
try {
|
||||
const result = config.sockets.beforeConnect(handshake);
|
||||
return cb(null, result);
|
||||
} catch (e) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(
|
||||
`A socket was rejected via the config.sockets.beforeConnect function.\n` +
|
||||
`It attempted to connect with headers:\n` +
|
||||
`${util.inspect(handshake.headers, { depth: null })}\n` +
|
||||
`Details: ${e}`,
|
||||
);
|
||||
return cb(e, false);
|
||||
}
|
||||
},
|
||||
}),
|
||||
...(config.sockets.pingTimeout && {
|
||||
pingTimeout: config.sockets.pingTimeout,
|
||||
}),
|
||||
...(config.sockets.pingInterval && {
|
||||
pingInterval: config.sockets.pingInterval,
|
||||
}),
|
||||
...(config.sockets.maxHttpBufferSize && {
|
||||
maxHttpBufferSize: config.sockets.maxHttpBufferSize,
|
||||
}),
|
||||
...(config.sockets.transports && { transports: config.sockets.transports }),
|
||||
...(config.sockets.allowUpgrades && {
|
||||
allowUpgrades: config.sockets.allowUpgrades,
|
||||
}),
|
||||
...(config.sockets.cookie && { cookie: config.sockets.cookie }),
|
||||
...(config.sockets.onlyAllowOrigins && {
|
||||
@Injectable()
|
||||
export class WebSocketGatewayOptionsService {
|
||||
constructor(private readonly settingsService: SettingService) {}
|
||||
|
||||
buildWebSocketGatewayOptions(): Partial<ServerOptions> {
|
||||
const opts: Partial<ServerOptions> = {
|
||||
allowEIO3: true, // Allows support for Engine.io v3 clients.
|
||||
path: config.sockets.path,
|
||||
...(typeof config.sockets.serveClient !== 'undefined' && {
|
||||
serveClient: config.sockets.serveClient,
|
||||
}),
|
||||
...(config.sockets.beforeConnect && {
|
||||
allowRequest: (handshake, cb) => {
|
||||
try {
|
||||
const result = config.sockets.beforeConnect(handshake);
|
||||
return cb(null, result);
|
||||
} catch (e) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(
|
||||
`A socket was rejected via the config.sockets.beforeConnect function.\n` +
|
||||
`It attempted to connect with headers:\n` +
|
||||
`${util.inspect(handshake.headers, { depth: null })}\n` +
|
||||
`Details: ${e}`,
|
||||
);
|
||||
return cb(e, false);
|
||||
}
|
||||
},
|
||||
}),
|
||||
...(config.sockets.pingTimeout && {
|
||||
pingTimeout: config.sockets.pingTimeout,
|
||||
}),
|
||||
...(config.sockets.pingInterval && {
|
||||
pingInterval: config.sockets.pingInterval,
|
||||
}),
|
||||
...(config.sockets.maxHttpBufferSize && {
|
||||
maxHttpBufferSize: config.sockets.maxHttpBufferSize,
|
||||
}),
|
||||
...(config.sockets.transports && {
|
||||
transports: config.sockets.transports,
|
||||
}),
|
||||
...(config.sockets.allowUpgrades && {
|
||||
allowUpgrades: config.sockets.allowUpgrades,
|
||||
}),
|
||||
...(config.sockets.cookie && { cookie: config.sockets.cookie }),
|
||||
cors: {
|
||||
origin: (origin, cb) => {
|
||||
if (origin && config.sockets.onlyAllowOrigins.includes(origin)) {
|
||||
if (origin && this.settingsService.isOriginAllowed(origin)) {
|
||||
cb(null, true);
|
||||
} else {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(
|
||||
`A socket was rejected via the config.sockets.onlyAllowOrigins array.\n` +
|
||||
`A socket was rejected via the SettingsService.allowedOriginDomains array.\n` +
|
||||
`It attempted to connect with origin: ${origin}`,
|
||||
);
|
||||
cb(new Error('Origin not allowed'), false);
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
||||
return opts;
|
||||
};
|
||||
return opts;
|
||||
}
|
||||
}
|
||||
|
@ -35,25 +35,36 @@ import { config } from '@/config';
|
||||
import { LoggerService } from '@/logger/logger.service';
|
||||
import { getSessionStore } from '@/utils/constants/session-store';
|
||||
|
||||
import { OnModuleInit } from '@nestjs/common';
|
||||
import { IOIncomingMessage, IOMessagePipe } from './pipes/io-message.pipe';
|
||||
import { SocketEventDispatcherService } from './services/socket-event-dispatcher.service';
|
||||
import { Room } from './types';
|
||||
import { buildWebSocketGatewayOptions } from './utils/gateway-options';
|
||||
import { WebSocketGatewayOptionsService } from './utils/gateway-options';
|
||||
import { SocketRequest } from './utils/socket-request';
|
||||
import { SocketResponse } from './utils/socket-response';
|
||||
|
||||
@WebSocketGateway(buildWebSocketGatewayOptions())
|
||||
@WebSocketGateway()
|
||||
export class WebsocketGateway
|
||||
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
|
||||
implements
|
||||
OnGatewayInit,
|
||||
OnGatewayConnection,
|
||||
OnGatewayDisconnect,
|
||||
OnModuleInit
|
||||
{
|
||||
constructor(
|
||||
private readonly logger: LoggerService,
|
||||
private readonly eventEmitter: EventEmitter2,
|
||||
private readonly socketEventDispatcherService: SocketEventDispatcherService,
|
||||
private readonly gatewayOptionsService: WebSocketGatewayOptionsService,
|
||||
) {}
|
||||
|
||||
@WebSocketServer() io: Server;
|
||||
|
||||
onModuleInit() {
|
||||
const options = this.gatewayOptionsService.buildWebSocketGatewayOptions();
|
||||
this.io = new Server(options);
|
||||
}
|
||||
|
||||
broadcastMessageSent(message: OutgoingMessage): void {
|
||||
this.io.to(Room.MESSAGE).emit('message', {
|
||||
op: 'messageSent',
|
||||
|
@ -8,12 +8,20 @@
|
||||
|
||||
import { Global, Module } from '@nestjs/common';
|
||||
|
||||
import { SettingModule } from '@/setting/setting.module';
|
||||
|
||||
import { SocketEventDispatcherService } from './services/socket-event-dispatcher.service';
|
||||
import { WebSocketGatewayOptionsService } from './utils/gateway-options';
|
||||
import { WebsocketGateway } from './websocket.gateway';
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
providers: [WebsocketGateway, SocketEventDispatcherService],
|
||||
imports: [SettingModule],
|
||||
providers: [
|
||||
WebsocketGateway,
|
||||
SocketEventDispatcherService,
|
||||
WebSocketGatewayOptionsService,
|
||||
],
|
||||
exports: [WebsocketGateway],
|
||||
})
|
||||
export class WebsocketModule {}
|
||||
|
Loading…
Reference in New Issue
Block a user