From 74d9409cc32c5923fe72e98eb961d1d56304519e Mon Sep 17 00:00:00 2001 From: Nan Yu Date: Thu, 4 Aug 2022 05:00:52 -0400 Subject: [PATCH] fix: refactor auth flow to explicitly pass in a host (#3909) * fix: refactor auth flow to explicitly pass in a host * add new error handler to all SSO providers * refactor passport error into middleware --- app/scenes/Login/AuthenticationProvider.tsx | 14 ++++++---- server/errors.ts | 10 ------- server/middlewares/passport.ts | 25 ++++++++++++++--- server/utils/passport.ts | 31 ++++++--------------- 4 files changed, 38 insertions(+), 42 deletions(-) diff --git a/app/scenes/Login/AuthenticationProvider.tsx b/app/scenes/Login/AuthenticationProvider.tsx index 3e58da287..b5bb854ad 100644 --- a/app/scenes/Login/AuthenticationProvider.tsx +++ b/app/scenes/Login/AuthenticationProvider.tsx @@ -89,11 +89,15 @@ function AuthenticationProvider(props: Props) { ); } - // If we're on a custom domain then the auth must point to the root - // app.getoutline.com for authentication so that the state cookie can be set - // and read. - const isCustomDomain = parseDomain(window.location.origin).custom; - const href = `${isCustomDomain ? env.URL : ""}${authUrl}`; + // If we're on a custom domain or a subdomain then the auth must point to the + // apex (env.URL) for authentication so that the state cookie can be set and read. + // We pass the host into the auth URL so that the server can redirect on error + // and keep the user on the same page. + const { custom, teamSubdomain, host } = parseDomain(window.location.origin); + const needsRedirect = custom || teamSubdomain; + const href = needsRedirect + ? `${env.URL}${authUrl}?host=${encodeURI(host)}` + : authUrl; return ( diff --git a/server/errors.ts b/server/errors.ts index f9f9b1194..5b7f89609 100644 --- a/server/errors.ts +++ b/server/errors.ts @@ -161,16 +161,6 @@ export function GmailAccountCreationError( }); } -export function AuthRedirectError( - message = "Redirect to the correct domain after authentication", - redirectUrl: string -) { - return httpErrors(400, message, { - id: "auth_redirect", - redirectUrl, - }); -} - export function OIDCMalformedUserInfoError( message = "User profile information malformed" ) { diff --git a/server/middlewares/passport.ts b/server/middlewares/passport.ts index 8879ea29f..6b17e2a66 100644 --- a/server/middlewares/passport.ts +++ b/server/middlewares/passport.ts @@ -3,6 +3,7 @@ import { Context } from "koa"; import env from "@server/env"; import Logger from "@server/logging/Logger"; import { signIn } from "@server/utils/authentication"; +import { parseState } from "@server/utils/passport"; import { AccountProvisionerResult } from "../commands/accountProvisioner"; export default function createMiddleware(providerName: string) { @@ -18,12 +19,28 @@ export default function createMiddleware(providerName: string) { if (err.id) { const notice = err.id.replace(/_/g, "-"); - const hasQueryString = err.redirectUrl?.includes("?"); + const redirectUrl = err.redirectUrl ?? "/"; + const hasQueryString = redirectUrl?.includes("?"); + + // Every authentication action is routed through the apex domain. + // But when there is an error, we want to redirect the user on the + // same domain or subdomain that they originated from (found in state). + + // get original host + const state = ctx.cookies.get("state"); + const host = state ? parseState(state).host : ctx.hostname; + + // form a URL object with the err.redirectUrl and replace the host + const reqProtocol = ctx.protocol; + const requestHost = ctx.get("host"); + const url = new URL( + `${reqProtocol}://${requestHost}${redirectUrl}` + ); + + url.host = host; return ctx.redirect( - `${err.redirectUrl || "/"}${ - hasQueryString ? "&" : "?" - }notice=${notice}` + `${url.toString()}${hasQueryString ? "&" : "?"}notice=${notice}` ); } diff --git a/server/utils/passport.ts b/server/utils/passport.ts index 95915911e..897e32081 100644 --- a/server/utils/passport.ts +++ b/server/utils/passport.ts @@ -9,18 +9,19 @@ import { import { getCookieDomain, parseDomain } from "@shared/utils/domains"; import env from "@server/env"; import { Team } from "@server/models"; -import { AuthRedirectError, OAuthStateMismatchError } from "../errors"; +import { OAuthStateMismatchError } from "../errors"; export class StateStore { key = "state"; store = (ctx: Context, callback: StateStoreStoreCallback) => { // token is a short lived one-time pad to prevent replay attacks - // appDomain is the domain the user originated from when attempting auth - // we expect it to be a team subdomain, custom domain, or apex domain const token = crypto.randomBytes(8).toString("hex"); - const appDomain = parseDomain(ctx.hostname); - const state = buildState(appDomain.host, token); + + // We expect host to be a team subdomain, custom domain, or apex domain + // that is passed via query param from the auth provider component. + const host = ctx.query.host?.toString() || parseDomain(ctx.hostname).host; + const state = buildState(host, token); ctx.cookies.set(this.key, state, { httpOnly: false, @@ -46,24 +47,7 @@ export class StateStore { ); } - const { host, token } = parseState(state); - - // Oauth callbacks are hard-coded to come to the apex domain, so we - // redirect to the original app domain before attempting authentication. - // If there is an error during auth, the user will end up on the same domain - // that they started from. - const appDomain = parseDomain(host); - if (appDomain.host !== parseDomain(ctx.hostname).host) { - const reqProtocol = ctx.protocol; - const requestHost = ctx.get("host"); - const requestPath = ctx.originalUrl; - const requestUrl = `${reqProtocol}://${requestHost}${requestPath}`; - const url = new URL(requestUrl); - - url.host = appDomain.host; - - return callback(AuthRedirectError(``, url.toString()), false, token); - } + const { token } = parseState(state); // Destroy the one-time pad token and ensure it matches ctx.cookies.set(this.key, "", { @@ -106,6 +90,7 @@ export async function getTeamFromContext(ctx: Context) { // we use it to infer the team they intend on signing into const state = ctx.cookies.get("state"); const host = state ? parseState(state).host : ctx.hostname; + const domain = parseDomain(host); let team;