Zod schemas for routes under /plugins (#6378)

* fix: schema for slack routes

* fix: slack.post

* fix: email
This commit is contained in:
Apoorv Mishra
2024-01-13 10:55:30 +05:30
committed by GitHub
parent 7e61a519f1
commit 3561b79d65
7 changed files with 186 additions and 77 deletions

View File

@@ -8,6 +8,15 @@ import { getTestServer } from "@server/test/support";
const server = getTestServer(); const server = getTestServer();
describe("email", () => { describe("email", () => {
it("should fail with status 400 bad request if email is invalid", async () => {
const res = await server.post("/auth/email", {
body: { email: "invalid" },
});
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.message).toEqual("email: Invalid email");
});
it("should require email param", async () => { it("should require email param", async () => {
const res = await server.post("/auth/email", { const res = await server.post("/auth/email", {
body: {}, body: {},

View File

@@ -1,5 +1,5 @@
import Router from "koa-router"; import Router from "koa-router";
import { Client, NotificationEventType } from "@shared/types"; import { NotificationEventType } from "@shared/types";
import { parseDomain } from "@shared/utils/domains"; import { parseDomain } from "@shared/utils/domains";
import InviteAcceptedEmail from "@server/emails/templates/InviteAcceptedEmail"; import InviteAcceptedEmail from "@server/emails/templates/InviteAcceptedEmail";
import SigninEmail from "@server/emails/templates/SigninEmail"; import SigninEmail from "@server/emails/templates/SigninEmail";
@@ -7,20 +7,22 @@ import WelcomeEmail from "@server/emails/templates/WelcomeEmail";
import env from "@server/env"; import env from "@server/env";
import { AuthorizationError } from "@server/errors"; import { AuthorizationError } from "@server/errors";
import { rateLimiter } from "@server/middlewares/rateLimiter"; import { rateLimiter } from "@server/middlewares/rateLimiter";
import validate from "@server/middlewares/validate";
import { User, Team } from "@server/models"; import { User, Team } from "@server/models";
import { APIContext } from "@server/types";
import { RateLimiterStrategy } from "@server/utils/RateLimiter"; import { RateLimiterStrategy } from "@server/utils/RateLimiter";
import { signIn } from "@server/utils/authentication"; import { signIn } from "@server/utils/authentication";
import { getUserForEmailSigninToken } from "@server/utils/jwt"; import { getUserForEmailSigninToken } from "@server/utils/jwt";
import { assertEmail, assertPresent } from "@server/validation"; import * as T from "./schema";
const router = new Router(); const router = new Router();
router.post( router.post(
"email", "email",
rateLimiter(RateLimiterStrategy.TenPerHour), rateLimiter(RateLimiterStrategy.TenPerHour),
async (ctx) => { validate(T.EmailSchema),
const { email, client } = ctx.request.body; async (ctx: APIContext<T.EmailReq>) => {
assertEmail(email, "email is required"); const { email, client } = ctx.input.body;
const domain = parseDomain(ctx.request.hostname); const domain = parseDomain(ctx.request.hostname);
@@ -71,7 +73,7 @@ router.post(
to: user.email, to: user.email,
token: user.getEmailSigninToken(), token: user.getEmailSigninToken(),
teamUrl: team.url, teamUrl: team.url,
client: client === Client.Desktop ? Client.Desktop : Client.Web, client,
}).schedule(); }).schedule();
user.lastSigninEmailSentAt = new Date(); user.lastSigninEmailSentAt = new Date();
@@ -84,52 +86,57 @@ router.post(
} }
); );
router.get("email.callback", async (ctx) => { router.get(
const { token, client } = ctx.request.query; "email.callback",
assertPresent(token, "token is required"); validate(T.EmailCallbackSchema),
async (ctx: APIContext<T.EmailCallbackReq>) => {
const { token, client } = ctx.input.query;
let user!: User; let user!: User;
try { try {
user = await getUserForEmailSigninToken(token as string); user = await getUserForEmailSigninToken(token as string);
} catch (err) { } catch (err) {
ctx.redirect(`/?notice=expired-token`); ctx.redirect(`/?notice=expired-token`);
return; return;
} }
if (!user.team.emailSigninEnabled) { if (!user.team.emailSigninEnabled) {
return ctx.redirect("/?notice=auth-error"); return ctx.redirect("/?notice=auth-error");
} }
if (user.isSuspended) { if (user.isSuspended) {
return ctx.redirect("/?notice=user-suspended"); return ctx.redirect("/?notice=user-suspended");
} }
if (user.isInvited) { if (user.isInvited) {
await new WelcomeEmail({ await new WelcomeEmail({
to: user.email, to: user.email,
teamUrl: user.team.url,
}).schedule();
const inviter = await user.$get("invitedBy");
if (inviter?.subscribedToEventType(NotificationEventType.InviteAccepted)) {
await new InviteAcceptedEmail({
to: inviter.email,
inviterId: inviter.id,
invitedName: user.name,
teamUrl: user.team.url, teamUrl: user.team.url,
}).schedule(); }).schedule();
}
}
// set cookies on response and redirect to team subdomain const inviter = await user.$get("invitedBy");
await signIn(ctx, "email", { if (
user, inviter?.subscribedToEventType(NotificationEventType.InviteAccepted)
team: user.team, ) {
isNewTeam: false, await new InviteAcceptedEmail({
isNewUser: false, to: inviter.email,
client: client === Client.Desktop ? Client.Desktop : Client.Web, inviterId: inviter.id,
}); invitedName: user.name,
}); teamUrl: user.team.url,
}).schedule();
}
}
// set cookies on response and redirect to team subdomain
await signIn(ctx, "email", {
user,
team: user.team,
isNewTeam: false,
isNewUser: false,
client,
});
}
);
export default router; export default router;

View File

@@ -0,0 +1,21 @@
import { z } from "zod";
import { Client } from "@shared/types";
import { BaseSchema } from "@server/routes/api/schema";
export const EmailSchema = BaseSchema.extend({
body: z.object({
email: z.string().email(),
client: z.nativeEnum(Client).default(Client.Web),
}),
});
export type EmailReq = z.infer<typeof EmailSchema>;
export const EmailCallbackSchema = BaseSchema.extend({
query: z.object({
token: z.string(),
client: z.nativeEnum(Client).default(Client.Web),
}),
});
export type EmailCallbackReq = z.infer<typeof EmailCallbackSchema>;

View File

@@ -0,0 +1,31 @@
import isEmpty from "lodash/isEmpty";
import { z } from "zod";
import { BaseSchema } from "@server/routes/api/schema";
export const SlackCommandsSchema = BaseSchema.extend({
query: z
.object({
code: z.string().nullish(),
state: z.string().uuid().nullish(),
error: z.string().nullish(),
})
.refine((req) => !(isEmpty(req.code) && isEmpty(req.error)), {
message: "one of code or error is required",
}),
});
export type SlackCommandsReq = z.infer<typeof SlackCommandsSchema>;
export const SlackPostSchema = BaseSchema.extend({
query: z
.object({
code: z.string().nullish(),
state: z.string().uuid().nullish(),
error: z.string().nullish(),
})
.refine((req) => !(isEmpty(req.code) && isEmpty(req.error)), {
message: "one of code or error is required",
}),
});
export type SlackPostReq = z.infer<typeof SlackPostSchema>;

View File

@@ -0,0 +1,39 @@
import { getTestServer } from "@server/test/support";
const server = getTestServer();
describe("#slack.commands", () => {
it("should fail with status 400 bad request if query param state is not a uuid", async () => {
const res = await server.get("/auth/slack.commands?state=123");
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.message).toEqual("state: Invalid uuid");
});
it("should fail with status 400 bad request when both code and error are missing in query params", async () => {
const res = await server.get(
"/auth/slack.commands?state=182d14d5-0dbd-4521-ac52-25484c25c96e"
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.message).toEqual("query: one of code or error is required");
});
});
describe("#slack.post", () => {
it("should fail with status 400 bad request if query param state is not a uuid", async () => {
const res = await server.get("/auth/slack.post?state=123");
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.message).toEqual("state: Invalid uuid");
});
it("should fail with status 400 bad request when both code and error are missing in query params", async () => {
const res = await server.get(
"/auth/slack.post?state=182d14d5-0dbd-4521-ac52-25484c25c96e"
);
const body = await res.json();
expect(res.status).toEqual(400);
expect(body.message).toEqual("query: one of code or error is required");
});
});

View File

@@ -9,6 +9,7 @@ import accountProvisioner from "@server/commands/accountProvisioner";
import env from "@server/env"; import env from "@server/env";
import auth from "@server/middlewares/authentication"; import auth from "@server/middlewares/authentication";
import passportMiddleware from "@server/middlewares/passport"; import passportMiddleware from "@server/middlewares/passport";
import validate from "@server/middlewares/validate";
import { import {
IntegrationAuthentication, IntegrationAuthentication,
Collection, Collection,
@@ -16,14 +17,14 @@ import {
Team, Team,
User, User,
} from "@server/models"; } from "@server/models";
import { AppContext, AuthenticationResult } from "@server/types"; import { APIContext, AuthenticationResult } from "@server/types";
import { import {
getClientFromContext, getClientFromContext,
getTeamFromContext, getTeamFromContext,
StateStore, StateStore,
} from "@server/utils/passport"; } from "@server/utils/passport";
import { assertPresent, assertUuid } from "@server/validation";
import * as Slack from "../slack"; import * as Slack from "../slack";
import * as T from "./schema";
type SlackProfile = Profile & { type SlackProfile = Profile & {
team: { team: {
@@ -132,10 +133,10 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
auth({ auth({
optional: true, optional: true,
}), }),
async (ctx: AppContext) => { validate(T.SlackCommandsSchema),
const { code, state, error } = ctx.request.query; async (ctx: APIContext<T.SlackCommandsReq>) => {
const { code, state: teamId, error } = ctx.input.query;
const { user } = ctx.state.auth; const { user } = ctx.state.auth;
assertPresent(code || error, "code is required");
if (error) { if (error) {
ctx.redirect(integrationSettingsPath(`slack?error=${error}`)); ctx.redirect(integrationSettingsPath(`slack?error=${error}`));
@@ -146,9 +147,9 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
// access authentication for subdomains. We must forward to the appropriate // access authentication for subdomains. We must forward to the appropriate
// subdomain to complete the oauth flow // subdomain to complete the oauth flow
if (!user) { if (!user) {
if (state) { if (teamId) {
try { try {
const team = await Team.findByPk(String(state), { const team = await Team.findByPk(teamId, {
rejectOnEmpty: true, rejectOnEmpty: true,
}); });
return redirectOnClient( return redirectOnClient(
@@ -168,7 +169,8 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
} }
const endpoint = `${env.URL}/auth/slack.commands`; const endpoint = `${env.URL}/auth/slack.commands`;
const data = await Slack.oauthAccess(String(code), endpoint); // validation middleware ensures that code is non-null at this point
const data = await Slack.oauthAccess(code!, endpoint);
const authentication = await IntegrationAuthentication.create({ const authentication = await IntegrationAuthentication.create({
service: IntegrationService.Slack, service: IntegrationService.Slack,
userId: user.id, userId: user.id,
@@ -195,14 +197,10 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
auth({ auth({
optional: true, optional: true,
}), }),
async (ctx: AppContext) => { validate(T.SlackPostSchema),
const { code, error, state } = ctx.request.query; async (ctx: APIContext<T.SlackPostReq>) => {
const { code, error, state: collectionId } = ctx.input.query;
const { user } = ctx.state.auth; const { user } = ctx.state.auth;
assertPresent(code || error, "code is required");
// FIX ME! What about having zod like schema in place here?
const collectionId = state as string;
assertUuid(collectionId, "collectionId must be an uuid");
if (error) { if (error) {
ctx.redirect(integrationSettingsPath(`slack?error=${error}`)); ctx.redirect(integrationSettingsPath(`slack?error=${error}`));
@@ -213,21 +211,24 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
// access authentication for subdomains. We must forward to the // access authentication for subdomains. We must forward to the
// appropriate subdomain to complete the oauth flow // appropriate subdomain to complete the oauth flow
if (!user) { if (!user) {
try { if (collectionId) {
const collection = await Collection.findOne({ try {
where: { const collection = await Collection.findByPk(collectionId, {
id: String(state), rejectOnEmpty: true,
}, });
rejectOnEmpty: true, const team = await Team.findByPk(collection.teamId, {
}); rejectOnEmpty: true,
const team = await Team.findByPk(collection.teamId, { });
rejectOnEmpty: true, return redirectOnClient(
}); ctx,
return redirectOnClient( `${team.url}/auth/slack.post?${ctx.request.querystring}`
ctx, );
`${team.url}/auth/slack.post?${ctx.request.querystring}` } catch (err) {
); return ctx.redirect(
} catch (err) { integrationSettingsPath(`slack?error=unauthenticated`)
);
}
} else {
return ctx.redirect( return ctx.redirect(
integrationSettingsPath(`slack?error=unauthenticated`) integrationSettingsPath(`slack?error=unauthenticated`)
); );
@@ -235,7 +236,8 @@ if (env.SLACK_CLIENT_ID && env.SLACK_CLIENT_SECRET) {
} }
const endpoint = `${env.URL}/auth/slack.post`; const endpoint = `${env.URL}/auth/slack.post`;
const data = await Slack.oauthAccess(code as string, endpoint); // validation middleware ensures that code is non-null at this point
const data = await Slack.oauthAccess(code!, endpoint);
const authentication = await IntegrationAuthentication.create({ const authentication = await IntegrationAuthentication.create({
service: IntegrationService.Slack, service: IntegrationService.Slack,
userId: user.id, userId: user.id,

View File

@@ -32,7 +32,7 @@ export default class AuthenticationHelper {
const rootDir = env.ENVIRONMENT === "test" ? "" : "build"; const rootDir = env.ENVIRONMENT === "test" ? "" : "build";
glob glob
.sync(path.join(rootDir, "plugins/*/server/auth/!(*.test).[jt]s")) .sync(path.join(rootDir, "plugins/*/server/auth/!(*.test|schema).[jt]s"))
.forEach((filePath: string) => { .forEach((filePath: string) => {
const { default: authProvider, name } = require(path.join( const { default: authProvider, name } = require(path.join(
process.cwd(), process.cwd(),