Merge pull request #197 from Hexastack/196-issue-enhance-web-socket-connection-security

fix: enhance web-socket connection access
This commit is contained in:
Med Marrouchi 2024-10-12 06:44:44 +01:00 committed by GitHub
commit baf561ee7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 65 additions and 43 deletions

View File

@ -24,6 +24,7 @@
"test:cov": "jest --coverage --runInBand --detectOpenHandles --forceExit", "test:cov": "jest --coverage --runInBand --detectOpenHandles --forceExit",
"test:debug": "node --inspect-brk -r tsconfig-paths/register -r ts-node/register node_modules/.bin/jest --runInBand", "test:debug": "node --inspect-brk -r tsconfig-paths/register -r ts-node/register node_modules/.bin/jest --runInBand",
"test:e2e": "jest --config ./test/jest-e2e.json", "test:e2e": "jest --config ./test/jest-e2e.json",
"test:clear": "jest --clearCache",
"typecheck": "tsc --noEmit", "typecheck": "tsc --noEmit",
"reset": "npm install && npm run containers:restart", "reset": "npm install && npm run containers:restart",
"reset:hard": "npm clean-install && npm run containers:rebuild", "reset:hard": "npm clean-install && npm run containers:rebuild",

View File

@ -128,6 +128,9 @@ export class ChannelService {
); );
if (!req.session?.passport?.user?.id) { if (!req.session?.passport?.user?.id) {
setTimeout(() => {
req.socket.client.conn.close();
}, 300);
throw new UnauthorizedException( throw new UnauthorizedException(
'Only authenticated users are allowed to use this channel', 'Only authenticated users are allowed to use this channel',
); );

View File

@ -49,6 +49,8 @@ describe('WebsocketGateway', () => {
ioClient = io('http://localhost:3000', { ioClient = io('http://localhost:3000', {
autoConnect: false, autoConnect: false,
transports: ['websocket', 'polling'], transports: ['websocket', 'polling'],
// path: '/socket.io/?EIO=4&transport=websocket&channel=offline',
query: { EIO: '4', transport: 'websocket', channel: 'offline' },
}); });
app.listen(3000); app.listen(3000);

View File

@ -207,60 +207,70 @@ export class WebsocketGateway
// Handle session // Handle session
this.io.use((client, next) => { this.io.use((client, next) => {
this.logger.verbose('Client connected, attempting to load session.'); this.logger.verbose('Client connected, attempting to load session.');
if (client.request.headers.cookie) { try {
const cookies = cookie.parse(client.request.headers.cookie); const { searchParams } = new URL(`ws://localhost${client.request.url}`);
if (cookies && config.session.name in cookies) { if (client.request.headers.cookie) {
const sessionID = cookieParser.signedCookie( const cookies = cookie.parse(client.request.headers.cookie);
cookies[config.session.name], if (cookies && config.session.name in cookies) {
config.session.secret, const sessionID = cookieParser.signedCookie(
); cookies[config.session.name],
if (sessionID) { config.session.secret,
return this.loadSession(sessionID, (err, session) => { );
if (err) { if (sessionID) {
this.logger.warn( return this.loadSession(sessionID, (err, session) => {
'Unable to load session, creating a new one ...', if (err || !session) {
err, this.logger.warn(
); 'Unable to load session, creating a new one ...',
return this.createAndStoreSession(client, next); err,
} );
client.data.session = session; if (searchParams.get('channel') === 'offline') {
client.data.sessionID = sessionID; return this.createAndStoreSession(client, next);
next(); } else {
}); return next(new Error('Unauthorized: Unknown session ID'));
}
}
client.data.session = session;
client.data.sessionID = sessionID;
next();
});
} else {
return next(new Error('Unable to parse session ID from cookie'));
}
} }
} else if (searchParams.get('channel') === 'offline') {
return this.createAndStoreSession(client, next);
} else {
return next(new Error('Unauthorized to connect to WS'));
} }
} catch (e) {
this.logger.warn('Something unexpected happening');
return next(e);
} }
return this.createAndStoreSession(client, next);
}); });
} }
handleConnection(client: Socket, ..._args: any[]): void { handleConnection(client: Socket, ..._args: any[]): void {
const { sockets } = this.io.sockets; const { sockets } = this.io.sockets;
const handshake = client.handshake;
const { channel } = handshake.query;
this.logger.log(`Client id: ${client.id} connected`); this.logger.log(`Client id: ${client.id} connected`);
this.logger.debug(`Number of connected clients: ${sockets?.size}`); this.logger.debug(`Number of connected clients: ${sockets?.size}`);
this.eventEmitter.emit(`hook:websocket:connection`, client); this.eventEmitter.emit(`hook:websocket:connection`, client);
// @TODO : Revisit once we don't use anymore in frontend // @TODO : Revisit once we don't use anymore in frontend
if (!channel) { const response = new SocketResponse();
const response = new SocketResponse(); client.send(
client.send( response
response .setHeaders({
.setHeaders({ 'access-control-allow-origin':
'access-control-allow-origin': config.security.cors.allowOrigins.join(','),
config.security.cors.allowOrigins.join(','), vary: 'Origin',
vary: 'Origin', 'access-control-allow-credentials':
'access-control-allow-credentials': config.security.cors.allowCredentials.toString(),
config.security.cors.allowCredentials.toString(), })
}) .status(200)
.status(200) .json({
.json({ success: true,
success: true, }),
}), );
);
}
} }
async handleDisconnect(client: Socket): Promise<void> { async handleDisconnect(client: Socket): Promise<void> {

View File

@ -11,6 +11,7 @@ import UiChatWidget from "hexabot-widget/src/UiChatWidget";
import { usePathname } from "next/navigation"; import { usePathname } from "next/navigation";
import { getAvatarSrc } from "@/components/inbox/helpers/mapMessages"; import { getAvatarSrc } from "@/components/inbox/helpers/mapMessages";
import { useAuth } from "@/hooks/useAuth";
import { useConfig } from "@/hooks/useConfig"; import { useConfig } from "@/hooks/useConfig";
import i18n from "@/i18n/config"; import i18n from "@/i18n/config";
import { EntityType, RouterType } from "@/services/types"; import { EntityType, RouterType } from "@/services/types";
@ -20,9 +21,10 @@ import { ChatWidgetHeader } from "./ChatWidgetHeader";
export const ChatWidget = () => { export const ChatWidget = () => {
const pathname = usePathname(); const pathname = usePathname();
const { apiUrl } = useConfig(); const { apiUrl } = useConfig();
const { isAuthenticated } = useAuth();
const isVisualEditor = pathname === `/${RouterType.VISUAL_EDITOR}`; const isVisualEditor = pathname === `/${RouterType.VISUAL_EDITOR}`;
return ( return isAuthenticated ? (
<Box <Box
sx={{ sx={{
display: isVisualEditor ? "block" : "none", display: isVisualEditor ? "block" : "none",
@ -44,5 +46,5 @@ export const ChatWidget = () => {
)} )}
/> />
</Box> </Box>
); ) : null;
}; };

View File

@ -12,6 +12,7 @@ import { useMutation, useQuery, useQueryClient } from "react-query";
import { EntityType, TMutationOptions } from "@/services/types"; import { EntityType, TMutationOptions } from "@/services/types";
import { ILoginAttributes } from "@/types/auth/login.types"; import { ILoginAttributes } from "@/types/auth/login.types";
import { IUser, IUserAttributes, IUserStub } from "@/types/user.types"; import { IUser, IUserAttributes, IUserStub } from "@/types/user.types";
import { useSocket } from "@/websocket/socket-hooks";
import { useFind } from "../crud/useFind"; import { useFind } from "../crud/useFind";
import { useApiClient } from "../useApiClient"; import { useApiClient } from "../useApiClient";
@ -45,10 +46,13 @@ export const useLogout = (
>, >,
) => { ) => {
const { apiClient } = useApiClient(); const { apiClient } = useApiClient();
const { socket } = useSocket();
return useMutation({ return useMutation({
...options, ...options,
async mutationFn() { async mutationFn() {
socket?.disconnect();
return await apiClient.logout(); return await apiClient.logout();
}, },
onSuccess: () => {}, onSuccess: () => {},