diff --git a/package.json b/package.json index 2bba9997d..122a44a8f 100644 --- a/package.json +++ b/package.json @@ -196,7 +196,6 @@ "smooth-scroll-into-view-if-needed": "^1.1.32", "socket.io": "^2.4.0", "socket.io-redis": "^5.4.0", - "socketio-auth": "^0.1.1", "stoppable": "^1.1.0", "string-replace-to-array": "^1.0.3", "styled-components": "^5.2.3", diff --git a/server/logging/Logger.ts b/server/logging/Logger.ts index cd3899ac9..9bbc09e92 100644 --- a/server/logging/Logger.ts +++ b/server/logging/Logger.ts @@ -19,6 +19,7 @@ type LogCategory = | "processor" | "email" | "queue" + | "websockets" | "database" | "utils"; type Extra = Record; diff --git a/server/services/websockets.ts b/server/services/websockets.ts index b2369cfd5..4f12426ac 100644 --- a/server/services/websockets.ts +++ b/server/services/websockets.ts @@ -2,18 +2,25 @@ import http, { IncomingMessage } from "http"; import { Duplex } from "stream"; import invariant from "invariant"; import Koa from "koa"; -import IO from "socket.io"; +import { each, find } from "lodash"; +import IO, { Client } from "socket.io"; import socketRedisAdapter from "socket.io-redis"; -import SocketAuth from "socketio-auth"; import Logger from "@server/logging/Logger"; import Metrics from "@server/logging/metrics"; -import { Document, Collection, View } from "@server/models"; +import { Document, Collection, View, User } from "@server/models"; import { can } from "@server/policies"; import { getUserForJWT } from "@server/utils/jwt"; import { websocketQueue } from "../queues"; import WebsocketsProcessor from "../queues/processors/WebsocketsProcessor"; import Redis from "../redis"; +type SocketWithAuth = IO.Socket & { + auth: boolean; + client: Client & { + user?: User; + }; +}; + export default function init( app: Koa, server: http.Server, @@ -65,6 +72,9 @@ export default function init( Metrics.gaugePerInstance("websockets.count", 0); }); + // Forbid all unauthenticated connections + each(io.nsps, forbidConnections); + io.adapter( socketRedisAdapter({ pubClient: Redis.defaultClient, @@ -85,185 +95,53 @@ export default function init( } }); - io.on("connection", (socket) => { + io.on("connection", (socket: SocketWithAuth) => { Metrics.increment("websockets.connected"); Metrics.gaugePerInstance( "websockets.count", socket.client.conn.server.clientsCount ); - socket.on("disconnect", () => { + + socket.auth = false; + + socket.on("authentication", async function (data) { + try { + await authenticate(socket, data); + + Logger.debug("websockets", `Authenticated socket ${socket.id}`); + socket.auth = true; + + each(io.nsps, function (nsp) { + restoreConnection(nsp, socket); + }); + + void authenticated(io, socket); + } catch (err) { + Logger.error(`Authentication error socket ${socket.id}`, err); + socket.emit("unauthorized", { message: err.message }, function () { + socket.disconnect(); + }); + } + }); + + socket.on("disconnect", async () => { Metrics.increment("websockets.disconnected"); Metrics.gaugePerInstance( "websockets.count", socket.client.conn.server.clientsCount ); + await Redis.defaultClient.hdel(socket.id, "userId"); }); - }); - SocketAuth(io, { - authenticate: async (socket, data, callback) => { - const { token } = data; + setTimeout(function () { + // If the socket didn't authenticate after connection, disconnect it + if (!socket.auth) { + Logger.debug("websockets", `Disconnecting socket ${socket.id}`); - try { - const user = await getUserForJWT(token); - socket.client.user = user; - - // store the mapping between socket id and user id in redis - // so that it is accessible across multiple server nodes - await Redis.defaultClient.hset(socket.id, "userId", user.id); - return callback(null, true); - } catch (err) { - return callback(err, false); + // @ts-expect-error should be boolean + socket.disconnect("unauthorized"); } - }, - - postAuthenticate: async (socket) => { - const { user } = socket.client; - - // the rooms associated with the current team - // and user so we can send authenticated events - const rooms = [`team-${user.teamId}`, `user-${user.id}`]; - - // the rooms associated with collections this user - // has access to on connection. New collection subscriptions - // are managed from the client as needed through the 'join' event - const collectionIds: string[] = await user.collectionIds(); - - collectionIds.forEach((collectionId) => - rooms.push(`collection-${collectionId}`) - ); - - // join all of the rooms at once - socket.join(rooms); - - // allow the client to request to join rooms - socket.on("join", async (event) => { - // user is joining a collection channel, because their permissions have - // changed, granting them access. - if (event.collectionId) { - const collection = await Collection.scope({ - method: ["withMembership", user.id], - }).findByPk(event.collectionId); - - if (can(user, "read", collection)) { - socket.join(`collection-${event.collectionId}`, () => { - Metrics.increment("websockets.collections.join"); - }); - } - } - - // user is joining a document channel, because they have navigated to - // view a document. - if (event.documentId) { - const document = await Document.findByPk(event.documentId, { - userId: user.id, - }); - - if (can(user, "read", document)) { - const room = `document-${event.documentId}`; - await View.touch(event.documentId, user.id, event.isEditing); - const editing = await View.findRecentlyEditingByDocument( - event.documentId - ); - - socket.join(room, () => { - Metrics.increment("websockets.documents.join"); - - // let everyone else in the room know that a new user joined - io.to(room).emit("user.join", { - userId: user.id, - documentId: event.documentId, - isEditing: event.isEditing, - }); - - // let this user know who else is already present in the room - io.in(room).clients(async (err: Error, sockets: string[]) => { - if (err) { - Logger.error("Error getting clients for room", err, { - sockets, - }); - return; - } - - // because a single user can have multiple socket connections we - // need to make sure that only unique userIds are returned. A Map - // makes this easy. - const userIds = new Map(); - - for (const socketId of sockets) { - const userId = await Redis.defaultClient.hget( - socketId, - "userId" - ); - userIds.set(userId, userId); - } - - socket.emit("document.presence", { - documentId: event.documentId, - userIds: Array.from(userIds.keys()), - editingIds: editing.map((view) => view.userId), - }); - }); - }); - } - } - }); - - // allow the client to request to leave rooms - socket.on("leave", (event) => { - if (event.collectionId) { - socket.leave(`collection-${event.collectionId}`, () => { - Metrics.increment("websockets.collections.leave"); - }); - } - - if (event.documentId) { - const room = `document-${event.documentId}`; - - socket.leave(room, () => { - Metrics.increment("websockets.documents.leave"); - io.to(room).emit("user.leave", { - userId: user.id, - documentId: event.documentId, - }); - }); - } - }); - - socket.on("disconnecting", () => { - const rooms = Object.keys(socket.rooms); - - rooms.forEach((room) => { - if (room.startsWith("document-")) { - const documentId = room.replace("document-", ""); - io.to(room).emit("user.leave", { - userId: user.id, - documentId, - }); - } - }); - }); - - socket.on("presence", async (event) => { - Metrics.increment("websockets.presence"); - const room = `document-${event.documentId}`; - - if (event.documentId && socket.rooms[room]) { - const view = await View.touch( - event.documentId, - user.id, - event.isEditing - ); - - view.user = user; - io.to(room).emit("user.presence", { - userId: user.id, - documentId: event.documentId, - isEditing: event.isEditing, - }); - } - }); - }, + }, 1000); }); // Handle events from event queue that should be sent to the clients down ws @@ -277,3 +155,186 @@ export default function init( }); }); } + +async function authenticated(io: IO.Server, socket: SocketWithAuth) { + const { user } = socket.client; + if (!user) { + throw new Error("User not returned from auth"); + } + + // the rooms associated with the current team + // and user so we can send authenticated events + const rooms = [`team-${user.teamId}`, `user-${user.id}`]; + + // the rooms associated with collections this user + // has access to on connection. New collection subscriptions + // are managed from the client as needed through the 'join' event + const collectionIds: string[] = await user.collectionIds(); + + collectionIds.forEach((collectionId) => + rooms.push(`collection-${collectionId}`) + ); + + // join all of the rooms at once + socket.join(rooms); + + // allow the client to request to join rooms + socket.on("join", async (event) => { + // user is joining a collection channel, because their permissions have + // changed, granting them access. + if (event.collectionId) { + const collection = await Collection.scope({ + method: ["withMembership", user.id], + }).findByPk(event.collectionId); + + if (can(user, "read", collection)) { + socket.join(`collection-${event.collectionId}`, () => { + Metrics.increment("websockets.collections.join"); + }); + } + } + + // user is joining a document channel, because they have navigated to + // view a document. + if (event.documentId) { + const document = await Document.findByPk(event.documentId, { + userId: user.id, + }); + + if (can(user, "read", document)) { + const room = `document-${event.documentId}`; + await View.touch(event.documentId, user.id, event.isEditing); + const editing = await View.findRecentlyEditingByDocument( + event.documentId + ); + + socket.join(room, () => { + Metrics.increment("websockets.documents.join"); + + // let everyone else in the room know that a new user joined + io.to(room).emit("user.join", { + userId: user.id, + documentId: event.documentId, + isEditing: event.isEditing, + }); + + // let this user know who else is already present in the room + io.in(room).clients(async (err: Error, sockets: string[]) => { + if (err) { + Logger.error("Error getting clients for room", err, { + sockets, + }); + return; + } + + // because a single user can have multiple socket connections we + // need to make sure that only unique userIds are returned. A Map + // makes this easy. + const userIds = new Map(); + + for (const socketId of sockets) { + const userId = await Redis.defaultClient.hget(socketId, "userId"); + userIds.set(userId, userId); + } + + socket.emit("document.presence", { + documentId: event.documentId, + userIds: Array.from(userIds.keys()), + editingIds: editing.map((view) => view.userId), + }); + }); + }); + } + } + }); + + // allow the client to request to leave rooms + socket.on("leave", (event) => { + if (event.collectionId) { + socket.leave(`collection-${event.collectionId}`, () => { + Metrics.increment("websockets.collections.leave"); + }); + } + + if (event.documentId) { + const room = `document-${event.documentId}`; + + socket.leave(room, () => { + Metrics.increment("websockets.documents.leave"); + io.to(room).emit("user.leave", { + userId: user.id, + documentId: event.documentId, + }); + }); + } + }); + + socket.on("disconnecting", () => { + const rooms = Object.keys(socket.rooms); + + rooms.forEach((room) => { + if (room.startsWith("document-")) { + const documentId = room.replace("document-", ""); + io.to(room).emit("user.leave", { + userId: user.id, + documentId, + }); + } + }); + }); + + socket.on("presence", async (event) => { + Metrics.increment("websockets.presence"); + const room = `document-${event.documentId}`; + + if (event.documentId && socket.rooms[room]) { + const view = await View.touch(event.documentId, user.id, event.isEditing); + + view.user = user; + io.to(room).emit("user.presence", { + userId: user.id, + documentId: event.documentId, + isEditing: event.isEditing, + }); + } + }); +} + +/** + * Authenticate the socket with the given token, attach the user model for the + * duration of the session. + */ +async function authenticate(socket: SocketWithAuth, data: { token: string }) { + const { token } = data; + + const user = await getUserForJWT(token); + socket.client.user = user; + + // store the mapping between socket id and user id in redis so that it is + // accessible across multiple websocket servers + await Redis.defaultClient.hset(socket.id, "userId", user.id); +} + +/** + * Set a listener so connections from unauthenticated sockets are not + * considered when emitting to the namespace. The connections will be + * restored after authentication succeeds. + */ +function forbidConnections(nsp: IO.Namespace) { + nsp.on("connect", function (socket: SocketWithAuth) { + if (!socket.auth) { + Logger.debug("websockets", `removing socket from ${nsp.name}`); + delete nsp.connected[socket.id]; + } + }); +} + +/** + * If the socket attempted a connection before authentication, restore it. + */ +function restoreConnection(nsp: IO.Namespace, socket: IO.Socket) { + if (find(nsp.sockets, { id: socket.id })) { + Logger.debug("websockets", `restoring socket to ${nsp.name}`); + nsp.connected[socket.id] = socket; + } +} diff --git a/yarn.lock b/yarn.lock index 671ef34a6..6141d058d 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5973,7 +5973,7 @@ debug@4, debug@^4.0.1, debug@^4.1.0, debug@^4.1.1, debug@^4.3.1, debug@^4.3.2, d dependencies: ms "2.1.2" -debug@^2.1.3, debug@^2.2.0, debug@^2.3.3, debug@^2.6.1, debug@^2.6.3, debug@^2.6.8, debug@^2.6.9: +debug@^2.2.0, debug@^2.3.3, debug@^2.6.1, debug@^2.6.3, debug@^2.6.8, debug@^2.6.9: version "2.6.9" resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" integrity sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA== @@ -10282,7 +10282,7 @@ lodash.uniq@^4.5.0: resolved "https://registry.yarnpkg.com/lodash.uniq/-/lodash.uniq-4.5.0.tgz#d0225373aeb652adc1bc82e4945339a842754773" integrity sha1-0CJTc662Uq3BvILklFM5qEJ1R3M= -lodash@^4.0.1, lodash@^4.17.10, lodash@^4.17.11, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.20, lodash@^4.17.21, lodash@^4.17.4, lodash@^4.17.5, lodash@^4.7.0: +lodash@^4.0.1, lodash@^4.17.10, lodash@^4.17.11, lodash@^4.17.14, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.20, lodash@^4.17.21, lodash@^4.17.4, lodash@^4.7.0: version "4.17.21" resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -13578,14 +13578,6 @@ socket.io@^2.4.0: socket.io-client "2.4.0" socket.io-parser "~3.4.0" -socketio-auth@^0.1.1: - version "0.1.1" - resolved "https://registry.yarnpkg.com/socketio-auth/-/socketio-auth-0.1.1.tgz#03f1fdd9d9b5e10f0a0ea9502abadbc580015d71" - integrity sha512-TDM/yiA5tnDiJqn8fO5zHrvTaKmN4EK4Dci9RaJLO11LEEbC1/E7z352OFrIWg8d/rn+Nk666ks9RhjlkGILlA== - dependencies: - debug "^2.1.3" - lodash "^4.17.5" - sort-keys@^5.0.0: version "5.0.0" resolved "https://registry.yarnpkg.com/sort-keys/-/sort-keys-5.0.0.tgz#5d775f8ae93ecc29bc7312bbf3acac4e36e3c446"