chore: Rate limiter audit (#3965)

* chore: Rate limiter audit api/users

* Make requests required

* api/collections

* Remove checkRateLimit on FileOperation (now done at route level through rate limiter)

* auth rate limit

* Add metric logging when rate limit exceeded

* Refactor to shared configs

* test
This commit is contained in:
Tom Moor
2022-08-14 16:04:04 +01:00
committed by GitHub
parent 9338328a82
commit a326e0ee88
14 changed files with 367 additions and 282 deletions

View File

@@ -12,6 +12,7 @@ import {
NetworkError, NetworkError,
NotFoundError, NotFoundError,
OfflineError, OfflineError,
RateLimitExceededError,
RequestError, RequestError,
ServiceUnavailableError, ServiceUnavailableError,
UpdateRequiredError, UpdateRequiredError,
@@ -181,6 +182,12 @@ class ApiClient {
throw new ServiceUnavailableError(error.message); throw new ServiceUnavailableError(error.message);
} }
if (response.status === 429) {
throw new RateLimitExceededError(
`Too many requests, try again in a minute.`
);
}
throw new RequestError(`Error ${response.status}: ${error.message}`); throw new RequestError(`Error ${response.status}: ${error.message}`);
}; };

View File

@@ -12,6 +12,8 @@ export class OfflineError extends ExtendableError {}
export class ServiceUnavailableError extends ExtendableError {} export class ServiceUnavailableError extends ExtendableError {}
export class RateLimitExceededError extends ExtendableError {}
export class RequestError extends ExtendableError {} export class RequestError extends ExtendableError {}
export class UpdateRequiredError extends ExtendableError {} export class UpdateRequiredError extends ExtendableError {}

View File

@@ -22,6 +22,7 @@ export function initSentry(history: History) {
"NetworkError", "NetworkError",
"NotFoundError", "NotFoundError",
"OfflineError", "OfflineError",
"RateLimitExceededError",
"ServiceUnavailableError", "ServiceUnavailableError",
"UpdateRequiredError", "UpdateRequiredError",
"ChunkLoadError", "ChunkLoadError",

View File

@@ -1,7 +1,9 @@
import { RateLimiterRedis } from "rate-limiter-flexible"; import {
IRateLimiterStoreOptions,
RateLimiterRedis,
} from "rate-limiter-flexible";
import env from "@server/env"; import env from "@server/env";
import Redis from "@server/redis"; import Redis from "@server/redis";
import { RateLimiterConfig } from "@server/types";
export default class RateLimiter { export default class RateLimiter {
constructor() { constructor() {
@@ -22,7 +24,7 @@ export default class RateLimiter {
return this.rateLimiterMap.get(path) || this.defaultRateLimiter; return this.rateLimiterMap.get(path) || this.defaultRateLimiter;
} }
static setRateLimiter(path: string, config: RateLimiterConfig): void { static setRateLimiter(path: string, config: IRateLimiterStoreOptions): void {
const rateLimiter = new RateLimiterRedis(config); const rateLimiter = new RateLimiterRedis(config);
this.rateLimiterMap.set(path, rateLimiter); this.rateLimiterMap.set(path, rateLimiter);
} }
@@ -31,3 +33,29 @@ export default class RateLimiter {
return this.rateLimiterMap.has(path); return this.rateLimiterMap.has(path);
} }
} }
/**
* Re-useable configuration for rate limiter middleware.
*/
export const RateLimiterStrategy = {
/** Allows five requests per minute, per IP address */
FivePerMinute: {
duration: 60,
requests: 5,
},
/** Allows ten requests per minute, per IP address */
TenPerMinute: {
duration: 60,
requests: 10,
},
/** Allows ten requests per hour, per IP address */
TenPerHour: {
duration: 3600,
requests: 10,
},
/** Allows five requests per hour, per IP address */
FivePerHour: {
duration: 3600,
requests: 5,
},
};

View File

@@ -14,3 +14,15 @@ export default class MockRateLimiter {
return false; return false;
} }
} }
export const RateLimiterStrategy = new Proxy(
{},
{
get() {
return {
duration: 60,
requests: 10,
};
},
}
);

View File

@@ -110,7 +110,9 @@ class Logger {
extra?: Extra, extra?: Extra,
request?: IncomingMessage request?: IncomingMessage
) { ) {
Metrics.increment("logger.error"); Metrics.increment("logger.error", {
name: error.name,
});
Tracing.setError(error); Tracing.setError(error);
if (env.SENTRY_DSN) { if (env.SENTRY_DSN) {

View File

@@ -3,10 +3,17 @@ import { defaults } from "lodash";
import RateLimiter from "@server/RateLimiter"; import RateLimiter from "@server/RateLimiter";
import env from "@server/env"; import env from "@server/env";
import { RateLimitExceededError } from "@server/errors"; import { RateLimitExceededError } from "@server/errors";
import Metrics from "@server/logging/metrics";
import Redis from "@server/redis"; import Redis from "@server/redis";
import { RateLimiterConfig } from "@server/types";
export function rateLimiter() { /**
* Middleware that limits the number of requests per IP address that are allowed
* within a window. Should only be applied once to a server do not use on
* individual routes.
*
* @returns The middleware function.
*/
export function defaultRateLimiter() {
return async function rateLimiterMiddleware(ctx: Context, next: Next) { return async function rateLimiterMiddleware(ctx: Context, next: Next) {
if (!env.RATE_LIMITER_ENABLED) { if (!env.RATE_LIMITER_ENABLED) {
return next(); return next();
@@ -28,6 +35,10 @@ export function rateLimiter() {
`${new Date(Date.now() + rateLimiterRes.msBeforeNext)}` `${new Date(Date.now() + rateLimiterRes.msBeforeNext)}`
); );
Metrics.increment("rate_limit.exceeded", {
path: ctx.path,
});
throw RateLimitExceededError(); throw RateLimitExceededError();
} }
@@ -35,7 +46,20 @@ export function rateLimiter() {
}; };
} }
export function registerRateLimiter(config: RateLimiterConfig) { type RateLimiterConfig = {
/** The window for which this rate limiter is considered (defaults to 60s) */
duration?: number;
/** The number of requests per IP address that are allowed within the window */
requests: number;
};
/**
* Middleware that limits the number of requests per IP address that are allowed
* within a window, overrides default middleware when used on a route.
*
* @returns The middleware function.
*/
export function rateLimiter(config: RateLimiterConfig) {
return async function registerRateLimiterMiddleware( return async function registerRateLimiterMiddleware(
ctx: Context, ctx: Context,
next: Next next: Next
@@ -47,11 +71,18 @@ export function registerRateLimiter(config: RateLimiterConfig) {
if (!RateLimiter.hasRateLimiter(ctx.path)) { if (!RateLimiter.hasRateLimiter(ctx.path)) {
RateLimiter.setRateLimiter( RateLimiter.setRateLimiter(
ctx.path, ctx.path,
defaults(config, { defaults(
duration: env.RATE_LIMITER_DURATION_WINDOW, {
keyPrefix: RateLimiter.RATE_LIMITER_REDIS_KEY_PREFIX, ...config,
storeClient: Redis.defaultClient, points: config.requests,
}) },
{
duration: 60,
points: env.RATE_LIMITER_REQUESTS,
keyPrefix: RateLimiter.RATE_LIMITER_REDIS_KEY_PREFIX,
storeClient: Redis.defaultClient,
}
)
); );
} }

View File

@@ -1,4 +1,3 @@
import { subHours } from "date-fns";
import { Op, WhereOptions } from "sequelize"; import { Op, WhereOptions } from "sequelize";
import { import {
ForeignKey, ForeignKey,
@@ -8,9 +7,7 @@ import {
BelongsTo, BelongsTo,
Table, Table,
DataType, DataType,
AfterValidate,
} from "sequelize-typescript"; } from "sequelize-typescript";
import { RateLimitExceededError } from "@server/errors";
import { deleteFromS3, getFileByKey } from "@server/utils/s3"; import { deleteFromS3, getFileByKey } from "@server/utils/s3";
import Collection from "./Collection"; import Collection from "./Collection";
import Team from "./Team"; import Team from "./Team";
@@ -53,15 +50,13 @@ export enum FileOperationState {
@Table({ tableName: "file_operations", modelName: "file_operation" }) @Table({ tableName: "file_operations", modelName: "file_operation" })
@Fix @Fix
class FileOperation extends IdModel { class FileOperation extends IdModel {
@Column(DataType.ENUM("import", "export")) @Column(DataType.ENUM(...Object.values(FileOperationType)))
type: FileOperationType; type: FileOperationType;
@Column(DataType.STRING) @Column(DataType.STRING)
format: FileOperationFormat; format: FileOperationFormat;
@Column( @Column(DataType.ENUM(...Object.values(FileOperationState)))
DataType.ENUM("creating", "uploading", "complete", "error", "expired")
)
state: FileOperationState; state: FileOperationState;
@Column @Column
@@ -93,21 +88,6 @@ class FileOperation extends IdModel {
await deleteFromS3(model.key); await deleteFromS3(model.key);
} }
@AfterValidate
static async checkRateLimit(model: FileOperation) {
const count = await this.countExportsAfterDateTime(
model.teamId,
subHours(new Date(), 12),
{
type: model.type,
}
);
if (count >= 12) {
throw RateLimitExceededError();
}
}
// associations // associations
@BelongsTo(() => User, "userId") @BelongsTo(() => User, "userId")

View File

@@ -4,12 +4,13 @@ import Router from "koa-router";
import { Sequelize, Op, WhereOptions } from "sequelize"; import { Sequelize, Op, WhereOptions } from "sequelize";
import { randomElement } from "@shared/random"; import { randomElement } from "@shared/random";
import { colorPalette } from "@shared/utils/collections"; import { colorPalette } from "@shared/utils/collections";
import { RateLimiterStrategy } from "@server/RateLimiter";
import collectionExporter from "@server/commands/collectionExporter"; import collectionExporter from "@server/commands/collectionExporter";
import teamUpdater from "@server/commands/teamUpdater"; import teamUpdater from "@server/commands/teamUpdater";
import { sequelize } from "@server/database/sequelize"; import { sequelize } from "@server/database/sequelize";
import { ValidationError } from "@server/errors"; import { ValidationError } from "@server/errors";
import auth from "@server/middlewares/authentication"; import auth from "@server/middlewares/authentication";
import { rateLimiter } from "@server/middlewares/rateLimiter";
import { import {
Collection, Collection,
CollectionUser, CollectionUser,
@@ -143,54 +144,59 @@ router.post("collections.info", auth(), async (ctx) => {
}; };
}); });
router.post("collections.import", auth(), async (ctx) => { router.post(
const { attachmentId, format = FileOperationFormat.MarkdownZip } = ctx.body; "collections.import",
assertUuid(attachmentId, "attachmentId is required"); auth(),
rateLimiter(RateLimiterStrategy.TenPerHour),
async (ctx) => {
const { attachmentId, format = FileOperationFormat.MarkdownZip } = ctx.body;
assertUuid(attachmentId, "attachmentId is required");
const { user } = ctx.state; const { user } = ctx.state;
authorize(user, "importCollection", user.team); authorize(user, "importCollection", user.team);
const attachment = await Attachment.findByPk(attachmentId); const attachment = await Attachment.findByPk(attachmentId);
authorize(user, "read", attachment); authorize(user, "read", attachment);
assertIn(format, Object.values(FileOperationFormat), "Invalid format"); assertIn(format, Object.values(FileOperationFormat), "Invalid format");
await sequelize.transaction(async (transaction) => { await sequelize.transaction(async (transaction) => {
const fileOperation = await FileOperation.create( const fileOperation = await FileOperation.create(
{ {
type: FileOperationType.Import,
state: FileOperationState.Creating,
format,
size: attachment.size,
key: attachment.key,
userId: user.id,
teamId: user.teamId,
},
{
transaction,
}
);
await Event.create(
{
name: "fileOperations.create",
teamId: user.teamId,
actorId: user.id,
modelId: fileOperation.id,
data: {
type: FileOperationType.Import, type: FileOperationType.Import,
state: FileOperationState.Creating,
format,
size: attachment.size,
key: attachment.key,
userId: user.id,
teamId: user.teamId,
}, },
}, {
{ transaction,
transaction, }
} );
);
});
ctx.body = { await Event.create(
success: true, {
}; name: "fileOperations.create",
}); teamId: user.teamId,
actorId: user.id,
modelId: fileOperation.id,
data: {
type: FileOperationType.Import,
},
},
{
transaction,
}
);
});
ctx.body = {
success: true,
};
}
);
router.post("collections.add_group", auth(), async (ctx) => { router.post("collections.add_group", auth(), async (ctx) => {
const { id, groupId, permission = "read_write" } = ctx.body; const { id, groupId, permission = "read_write" } = ctx.body;
@@ -485,57 +491,67 @@ router.post("collections.memberships", auth(), pagination(), async (ctx) => {
}; };
}); });
router.post("collections.export", auth(), async (ctx) => { router.post(
const { id } = ctx.body; "collections.export",
assertUuid(id, "id is required"); auth(),
const { user } = ctx.state; rateLimiter(RateLimiterStrategy.TenPerHour),
const team = await Team.findByPk(user.teamId); async (ctx) => {
authorize(user, "createExport", team); const { id } = ctx.body;
assertUuid(id, "id is required");
const { user } = ctx.state;
const team = await Team.findByPk(user.teamId);
authorize(user, "createExport", team);
const collection = await Collection.scope({ const collection = await Collection.scope({
method: ["withMembership", user.id], method: ["withMembership", user.id],
}).findByPk(id); }).findByPk(id);
authorize(user, "read", collection); authorize(user, "read", collection);
const fileOperation = await sequelize.transaction(async (transaction) => { const fileOperation = await sequelize.transaction(async (transaction) => {
return collectionExporter({ return collectionExporter({
collection, collection,
user, user,
team, team,
ip: ctx.request.ip, ip: ctx.request.ip,
transaction, transaction,
});
}); });
});
ctx.body = { ctx.body = {
success: true, success: true,
data: { data: {
fileOperation: presentFileOperation(fileOperation), fileOperation: presentFileOperation(fileOperation),
}, },
}; };
}); }
);
router.post("collections.export_all", auth(), async (ctx) => { router.post(
const { user } = ctx.state; "collections.export_all",
const team = await Team.findByPk(user.teamId); auth(),
authorize(user, "createExport", team); rateLimiter(RateLimiterStrategy.TenPerHour),
async (ctx) => {
const { user } = ctx.state;
const team = await Team.findByPk(user.teamId);
authorize(user, "createExport", team);
const fileOperation = await sequelize.transaction(async (transaction) => { const fileOperation = await sequelize.transaction(async (transaction) => {
return collectionExporter({ return collectionExporter({
user, user,
team, team,
ip: ctx.request.ip, ip: ctx.request.ip,
transaction, transaction,
});
}); });
});
ctx.body = { ctx.body = {
success: true, success: true,
data: { data: {
fileOperation: presentFileOperation(fileOperation), fileOperation: presentFileOperation(fileOperation),
}, },
}; };
}); }
);
router.post("collections.update", auth(), async (ctx) => { router.post("collections.update", auth(), async (ctx) => {
const { const {

View File

@@ -5,7 +5,7 @@ import env from "@server/env";
import { NotFoundError } from "@server/errors"; import { NotFoundError } from "@server/errors";
import errorHandling from "@server/middlewares/errorHandling"; import errorHandling from "@server/middlewares/errorHandling";
import methodOverride from "@server/middlewares/methodOverride"; import methodOverride from "@server/middlewares/methodOverride";
import { rateLimiter } from "@server/middlewares/rateLimiter"; import { defaultRateLimiter } from "@server/middlewares/rateLimiter";
import apiKeys from "./apiKeys"; import apiKeys from "./apiKeys";
import attachments from "./attachments"; import attachments from "./attachments";
import auth from "./auth"; import auth from "./auth";
@@ -81,7 +81,7 @@ router.post("*", (ctx) => {
ctx.throw(NotFoundError("Endpoint not found")); ctx.throw(NotFoundError("Endpoint not found"));
}); });
api.use(rateLimiter()); api.use(defaultRateLimiter());
// Router is embedded in a Koa application wrapper, because koa-router does not // Router is embedded in a Koa application wrapper, because koa-router does not
// allow middleware to catch any routes which were not explicitly defined. // allow middleware to catch any routes which were not explicitly defined.

View File

@@ -2,6 +2,7 @@ import crypto from "crypto";
import Router from "koa-router"; import Router from "koa-router";
import { Op, WhereOptions } from "sequelize"; import { Op, WhereOptions } from "sequelize";
import { UserValidation } from "@shared/validations"; import { UserValidation } from "@shared/validations";
import { RateLimiterStrategy } from "@server/RateLimiter";
import userDemoter from "@server/commands/userDemoter"; import userDemoter from "@server/commands/userDemoter";
import userDestroyer from "@server/commands/userDestroyer"; import userDestroyer from "@server/commands/userDestroyer";
import userInviter from "@server/commands/userInviter"; import userInviter from "@server/commands/userInviter";
@@ -13,6 +14,7 @@ import env from "@server/env";
import { ValidationError } from "@server/errors"; import { ValidationError } from "@server/errors";
import logger from "@server/logging/Logger"; import logger from "@server/logging/Logger";
import auth from "@server/middlewares/authentication"; import auth from "@server/middlewares/authentication";
import { rateLimiter } from "@server/middlewares/rateLimiter";
import { Event, User, Team } from "@server/models"; import { Event, User, Team } from "@server/models";
import { UserFlag, UserRole } from "@server/models/User"; import { UserFlag, UserRole } from "@server/models/User";
import { can, authorize } from "@server/policies"; import { can, authorize } from "@server/policies";
@@ -308,26 +310,31 @@ router.post("users.activate", auth(), async (ctx) => {
}; };
}); });
router.post("users.invite", auth(), async (ctx) => { router.post(
const { invites } = ctx.body; "users.invite",
assertArray(invites, "invites must be an array"); auth(),
const { user } = ctx.state; rateLimiter(RateLimiterStrategy.TenPerHour),
const team = await Team.findByPk(user.teamId); async (ctx) => {
authorize(user, "inviteUser", team); const { invites } = ctx.body;
assertArray(invites, "invites must be an array");
const { user } = ctx.state;
const team = await Team.findByPk(user.teamId);
authorize(user, "inviteUser", team);
const response = await userInviter({ const response = await userInviter({
user, user,
invites: invites.slice(0, UserValidation.maxInvitesPerRequest), invites: invites.slice(0, UserValidation.maxInvitesPerRequest),
ip: ctx.request.ip, ip: ctx.request.ip,
}); });
ctx.body = { ctx.body = {
data: { data: {
sent: response.sent, sent: response.sent,
users: response.users.map((user) => presentUser(user)), users: response.users.map((user) => presentUser(user)),
}, },
}; };
}); }
);
router.post("users.resendInvite", auth(), async (ctx) => { router.post("users.resendInvite", auth(), async (ctx) => {
const { id } = ctx.body; const { id } = ctx.body;
@@ -371,49 +378,59 @@ router.post("users.resendInvite", auth(), async (ctx) => {
}; };
}); });
router.post("users.requestDelete", auth(), async (ctx) => { router.post(
const { user } = ctx.state; "users.requestDelete",
authorize(user, "delete", user); auth(),
rateLimiter(RateLimiterStrategy.FivePerHour),
async (ctx) => {
const { user } = ctx.state;
authorize(user, "delete", user);
if (emailEnabled) { if (emailEnabled) {
await ConfirmUserDeleteEmail.schedule({ await ConfirmUserDeleteEmail.schedule({
to: user.email, to: user.email,
deleteConfirmationCode: user.deleteConfirmationCode, deleteConfirmationCode: user.deleteConfirmationCode,
});
}
ctx.body = {
success: true,
};
}
);
router.post(
"users.delete",
auth(),
rateLimiter(RateLimiterStrategy.FivePerHour),
async (ctx) => {
const { code = "" } = ctx.body;
const { user } = ctx.state;
authorize(user, "delete", user);
const deleteConfirmationCode = user.deleteConfirmationCode;
if (
emailEnabled &&
(code.length !== deleteConfirmationCode.length ||
!crypto.timingSafeEqual(
Buffer.from(code),
Buffer.from(deleteConfirmationCode)
))
) {
throw ValidationError("The confirmation code was incorrect");
}
await userDestroyer({
user,
actor: user,
ip: ctx.request.ip,
}); });
ctx.body = {
success: true,
};
} }
);
ctx.body = {
success: true,
};
});
router.post("users.delete", auth(), async (ctx) => {
const { code = "" } = ctx.body;
const { user } = ctx.state;
authorize(user, "delete", user);
const deleteConfirmationCode = user.deleteConfirmationCode;
if (
emailEnabled &&
(code.length !== deleteConfirmationCode.length ||
!crypto.timingSafeEqual(
Buffer.from(code),
Buffer.from(deleteConfirmationCode)
))
) {
throw ValidationError("The confirmation code was incorrect");
}
await userDestroyer({
user,
actor: user,
ip: ctx.request.ip,
});
ctx.body = {
success: true,
};
});
export default router; export default router;

View File

@@ -5,12 +5,15 @@ import bodyParser from "koa-body";
import Router from "koa-router"; import Router from "koa-router";
import { AuthenticationError } from "@server/errors"; import { AuthenticationError } from "@server/errors";
import auth from "@server/middlewares/authentication"; import auth from "@server/middlewares/authentication";
import { defaultRateLimiter } from "@server/middlewares/rateLimiter";
import { Collection, Team, View } from "@server/models"; import { Collection, Team, View } from "@server/models";
import providers from "./providers"; import providers from "./providers";
const app = new Koa(); const app = new Koa();
const router = new Router(); const router = new Router();
router.use(passport.initialize()); router.use(passport.initialize());
router.use(defaultRateLimiter());
// dynamically load available authentication provider routes // dynamically load available authentication provider routes
providers.forEach((provider) => { providers.forEach((provider) => {

View File

@@ -1,7 +1,7 @@
import { subMinutes } from "date-fns";
import Router from "koa-router"; import Router from "koa-router";
import { find } from "lodash"; import { find } from "lodash";
import { parseDomain } from "@shared/utils/domains"; import { parseDomain } from "@shared/utils/domains";
import { RateLimiterStrategy } from "@server/RateLimiter";
import InviteAcceptedEmail from "@server/emails/templates/InviteAcceptedEmail"; import InviteAcceptedEmail from "@server/emails/templates/InviteAcceptedEmail";
import SigninEmail from "@server/emails/templates/SigninEmail"; import SigninEmail from "@server/emails/templates/SigninEmail";
import WelcomeEmail from "@server/emails/templates/WelcomeEmail"; import WelcomeEmail from "@server/emails/templates/WelcomeEmail";
@@ -9,6 +9,7 @@ import env from "@server/env";
import { AuthorizationError } from "@server/errors"; import { AuthorizationError } from "@server/errors";
import errorHandling from "@server/middlewares/errorHandling"; import errorHandling from "@server/middlewares/errorHandling";
import methodOverride from "@server/middlewares/methodOverride"; import methodOverride from "@server/middlewares/methodOverride";
import { rateLimiter } from "@server/middlewares/rateLimiter";
import { User, Team } from "@server/models"; import { User, Team } from "@server/models";
import { signIn } from "@server/utils/authentication"; import { signIn } from "@server/utils/authentication";
import { getUserForEmailSigninToken } from "@server/utils/jwt"; import { getUserForEmailSigninToken } from "@server/utils/jwt";
@@ -23,102 +24,94 @@ export const config = {
router.use(methodOverride()); router.use(methodOverride());
router.post("email", errorHandling(), async (ctx) => { router.post(
const { email } = ctx.body; "email",
assertEmail(email, "email is required"); errorHandling(),
const users = await User.scope("withAuthentications").findAll({ rateLimiter(RateLimiterStrategy.TenPerHour),
where: { async (ctx) => {
email: email.toLowerCase(), const { email } = ctx.body;
}, assertEmail(email, "email is required");
}); const users = await User.scope("withAuthentications").findAll({
where: {
if (users.length) { email: email.toLowerCase(),
let team!: Team | null; },
const domain = parseDomain(ctx.request.hostname);
if (domain.custom) {
team = await Team.scope("withAuthenticationProviders").findOne({
where: {
domain: ctx.request.hostname,
},
});
} else if (env.SUBDOMAINS_ENABLED && domain.teamSubdomain) {
team = await Team.scope("withAuthenticationProviders").findOne({
where: {
subdomain: domain.teamSubdomain,
},
});
}
// If there are multiple users with this email address then give precedence
// to the one that is active on this subdomain/domain (if any)
let user = users.find((user) => team && user.teamId === team.id);
// A user was found for the email address, but they don't belong to the team
// that this subdomain belongs to, we load their team and allow the logic to
// continue
if (!user) {
user = users[0];
team = await Team.scope("withAuthenticationProviders").findByPk(
user.teamId
);
}
if (!team) {
team = await Team.scope("withAuthenticationProviders").findByPk(
user.teamId
);
}
if (!team) {
ctx.redirect(`/?notice=auth-error`);
return;
}
// If the user matches an email address associated with an SSO
// provider then just forward them directly to that sign-in page
if (user.authentications.length) {
const authProvider = find(team.authenticationProviders, {
id: user.authentications[0].authenticationProviderId,
});
ctx.body = {
redirect: `${team.url}/auth/${authProvider?.name}`,
};
return;
}
if (!team.emailSigninEnabled) {
throw AuthorizationError();
}
// basic rate limit of endpoint to prevent send email abuse
if (
user.lastSigninEmailSentAt &&
user.lastSigninEmailSentAt > subMinutes(new Date(), 2)
) {
ctx.body = {
redirect: `${team.url}?notice=email-auth-ratelimit`,
message: "Rate limit exceeded",
success: false,
};
return;
}
// send email to users registered address with a short-lived token
await SigninEmail.schedule({
to: user.email,
token: user.getEmailSigninToken(),
teamUrl: team.url,
}); });
user.lastSigninEmailSentAt = new Date();
await user.save();
}
// respond with success regardless of whether an email was sent if (users.length) {
ctx.body = { let team!: Team | null;
success: true, const domain = parseDomain(ctx.request.hostname);
};
}); if (domain.custom) {
team = await Team.scope("withAuthenticationProviders").findOne({
where: {
domain: ctx.request.hostname,
},
});
} else if (env.SUBDOMAINS_ENABLED && domain.teamSubdomain) {
team = await Team.scope("withAuthenticationProviders").findOne({
where: {
subdomain: domain.teamSubdomain,
},
});
}
// If there are multiple users with this email address then give precedence
// to the one that is active on this subdomain/domain (if any)
let user = users.find((user) => team && user.teamId === team.id);
// A user was found for the email address, but they don't belong to the team
// that this subdomain belongs to, we load their team and allow the logic to
// continue
if (!user) {
user = users[0];
team = await Team.scope("withAuthenticationProviders").findByPk(
user.teamId
);
}
if (!team) {
team = await Team.scope("withAuthenticationProviders").findByPk(
user.teamId
);
}
if (!team) {
ctx.redirect(`/?notice=auth-error`);
return;
}
// If the user matches an email address associated with an SSO
// provider then just forward them directly to that sign-in page
if (user.authentications.length) {
const authProvider = find(team.authenticationProviders, {
id: user.authentications[0].authenticationProviderId,
});
ctx.body = {
redirect: `${team.url}/auth/${authProvider?.name}`,
};
return;
}
if (!team.emailSigninEnabled) {
throw AuthorizationError();
}
// send email to users registered address with a short-lived token
await SigninEmail.schedule({
to: user.email,
token: user.getEmailSigninToken(),
teamUrl: team.url,
});
user.lastSigninEmailSentAt = new Date();
await user.save();
}
// respond with success regardless of whether an email was sent
ctx.body = {
success: true,
};
}
);
router.get("email.callback", async (ctx) => { router.get("email.callback", async (ctx) => {
const { token } = ctx.request.query; const { token } = ctx.request.query;

View File

@@ -1,5 +1,4 @@
import { Context } from "koa"; import { Context } from "koa";
import Redis from "@server/redis";
import { FileOperation, Team, User } from "./models"; import { FileOperation, Team, User } from "./models";
export enum AuthenticationTypes { export enum AuthenticationTypes {
@@ -298,9 +297,3 @@ export type Event =
| UserEvent | UserEvent
| ViewEvent | ViewEvent
| WebhookSubscriptionEvent; | WebhookSubscriptionEvent;
export type RateLimiterConfig = {
points: number;
duration: number;
storeClient: Redis;
};