diff --git a/.env.sample b/.env.sample index 64d5228f8..8e61333bc 100644 --- a/.env.sample +++ b/.env.sample @@ -137,10 +137,6 @@ MAXIMUM_IMPORT_SIZE=5120000 # requests and this ends up being duplicative DEBUG=http -# Comma separated list of domains to be allowed to signin to the wiki. If not -# set, all domains are allowed by default when using Google OAuth to signin -ALLOWED_DOMAINS= - # For a complete Slack integration with search and posting to channels the # following configs are also needed, some more details # => https://wiki.generaloutline.com/share/be25efd1-b3ef-4450-b8e5-c4a4fc11e02a diff --git a/app.json b/app.json index 7ebbc83f8..1085b5c27 100644 --- a/app.json +++ b/app.json @@ -43,10 +43,6 @@ "value": "true", "required": true }, - "ALLOWED_DOMAINS": { - "description": "Comma separated list of domains to be allowed (optional). If not set, all domains are allowed by default when using Google OAuth to signin. Consider putting {your app name}.herokuapp.com and any domain you are binding on in this list.", - "required": false - }, "URL": { "description": "https://{your app name}.herokuapp.com, or the domain you are binding to", "required": true @@ -209,4 +205,4 @@ "required": false } } -} \ No newline at end of file +} diff --git a/app/models/Team.ts b/app/models/Team.ts index e9565fde4..d829d518f 100644 --- a/app/models/Team.ts +++ b/app/models/Team.ts @@ -55,6 +55,10 @@ class Team extends BaseModel { url: string; + @Field + @observable + allowedDomains: string[] | null | undefined; + @computed get signinMethods(): string { return "SSO"; diff --git a/app/scenes/Login/Notices.tsx b/app/scenes/Login/Notices.tsx index 639e6e13d..6d7de2b90 100644 --- a/app/scenes/Login/Notices.tsx +++ b/app/scenes/Login/Notices.tsx @@ -21,12 +21,6 @@ export default function Notices() { installation. Try another? )} - {notice === "hd-not-allowed" && ( - - Sorry, your Google apps domain is not allowed. Please try again with - an allowed team domain. - - )} {notice === "malformed_user_info" && ( We could not read the user info supplied by your identity provider. @@ -79,6 +73,12 @@ export default function Notices() { Please request an invite from your team admin and try again. )} + {notice === "domain-not-allowed" && ( + + Sorry, your domain is not allowed. Please try again with an allowed + team domain. + + )} ); } diff --git a/app/scenes/Settings/Security.tsx b/app/scenes/Settings/Security.tsx index 82e21d82d..82947e0e7 100644 --- a/app/scenes/Settings/Security.tsx +++ b/app/scenes/Settings/Security.tsx @@ -1,15 +1,21 @@ import { debounce } from "lodash"; import { observer } from "mobx-react"; -import { PadlockIcon } from "outline-icons"; +import { CloseIcon, PadlockIcon } from "outline-icons"; import { useState } from "react"; import * as React from "react"; import { useTranslation, Trans } from "react-i18next"; +import styled from "styled-components"; +import Button from "~/components/Button"; import ConfirmationDialog from "~/components/ConfirmationDialog"; +import Flex from "~/components/Flex"; import Heading from "~/components/Heading"; +import Input from "~/components/Input"; import InputSelect from "~/components/InputSelect"; +import NudeButton from "~/components/NudeButton"; import Scene from "~/components/Scene"; import Switch from "~/components/Switch"; import Text from "~/components/Text"; +import Tooltip from "~/components/Tooltip"; import env from "~/env"; import useCurrentTeam from "~/hooks/useCurrentTeam"; import useStores from "~/hooks/useStores"; @@ -29,6 +35,7 @@ function Security() { defaultUserRole: team.defaultUserRole, memberCollectionCreate: team.memberCollectionCreate, inviteRequired: team.inviteRequired, + allowedDomains: team.allowedDomains, }); const authenticationMethods = team.signinMethods; @@ -43,6 +50,8 @@ function Security() { [showToast, t] ); + const [domainsChanged, setDomainsChanged] = useState(false); + const saveData = React.useCallback( async (newData) => { try { @@ -53,6 +62,8 @@ function Security() { showToast(err.message, { type: "error", }); + } finally { + setDomainsChanged(false); } }, [auth, showSuccessMessage, showToast] @@ -110,6 +121,36 @@ function Security() { [data, saveData, t, dialogs, authenticationMethods] ); + const handleRemoveDomain = async (index: number) => { + const newData = { + ...data, + }; + newData.allowedDomains && newData.allowedDomains.splice(index, 1); + + setData(newData); + setDomainsChanged(true); + }; + + const handleAddDomain = () => { + const newData = { + ...data, + allowedDomains: [...(data.allowedDomains || []), ""], + }; + + setData(newData); + setDomainsChanged(true); + }; + + const createOnDomainChangedHandler = (index: number) => ( + ev: React.ChangeEvent + ) => { + const newData = { ...data }; + + newData.allowedDomains![index] = ev.currentTarget.value; + setData(newData); + setDomainsChanged(true); + }; + return ( }> {t("Security")} @@ -220,8 +261,62 @@ function Security() { short /> + + + {data.allowedDomains && + data.allowedDomains.map((domain, index) => ( + + + + + handleRemoveDomain(index)}> + + + + + + ))} + + + {!data.allowedDomains?.length || + data.allowedDomains[data.allowedDomains.length - 1] !== "" ? ( + + ) : ( + + )} + + {domainsChanged && ( + + )} + + ); } +const Remove = styled("div")` + margin-top: 6px; +`; + export default observer(Security); diff --git a/app/scenes/Settings/components/SettingRow.tsx b/app/scenes/Settings/components/SettingRow.tsx index 04df5bf51..e0e7ccc57 100644 --- a/app/scenes/Settings/components/SettingRow.tsx +++ b/app/scenes/Settings/components/SettingRow.tsx @@ -41,6 +41,10 @@ const Column = styled.div` min-width: 60%; } + &:last-child { + min-width: 0; + } + ${breakpoint("tablet")` p { margin-bottom: 0; diff --git a/server/commands/accountProvisioner.test.ts b/server/commands/accountProvisioner.test.ts index 64522fae6..770bb3080 100644 --- a/server/commands/accountProvisioner.test.ts +++ b/server/commands/accountProvisioner.test.ts @@ -1,8 +1,9 @@ import WelcomeEmail from "@server/emails/templates/WelcomeEmail"; +import { TeamDomain } from "@server/models"; import Collection from "@server/models/Collection"; import UserAuthentication from "@server/models/UserAuthentication"; import { buildUser, buildTeam } from "@server/test/factories"; -import { flushdb } from "@server/test/support"; +import { flushdb, seed } from "@server/test/support"; import accountProvisioner from "./accountProvisioner"; beforeEach(() => { @@ -148,6 +149,100 @@ describe("accountProvisioner", () => { expect(error).toBeTruthy(); }); + it("should throw an error when the domain is not allowed", async () => { + const { admin, team: existingTeam } = await seed(); + const providers = await existingTeam.$get("authenticationProviders"); + const authenticationProvider = providers[0]; + + await TeamDomain.create({ + teamId: existingTeam.id, + name: "other.com", + createdById: admin.id, + }); + + let error; + + try { + await accountProvisioner({ + ip, + user: { + name: "Jenny Tester", + email: "jenny@example.com", + avatarUrl: "https://example.com/avatar.png", + username: "jtester", + }, + team: { + name: existingTeam.name, + avatarUrl: existingTeam.avatarUrl, + subdomain: "example", + }, + authenticationProvider: { + name: authenticationProvider.name, + providerId: authenticationProvider.providerId, + }, + authentication: { + providerId: "123456789", + accessToken: "123", + scopes: ["read"], + }, + }); + } catch (err) { + error = err; + } + + expect(error).toBeTruthy(); + }); + + it("should create a new user in an existing team when the domain is allowed", async () => { + const spy = jest.spyOn(WelcomeEmail, "schedule"); + const { admin, team } = await seed(); + const authenticationProviders = await team.$get("authenticationProviders"); + const authenticationProvider = authenticationProviders[0]; + await TeamDomain.create({ + teamId: team.id, + name: "example.com", + createdById: admin.id, + }); + + const { user, isNewUser } = await accountProvisioner({ + ip, + user: { + name: "Jenny Tester", + email: "jenny@example.com", + avatarUrl: "https://example.com/avatar.png", + username: "jtester", + }, + team: { + name: team.name, + avatarUrl: team.avatarUrl, + subdomain: "example", + }, + authenticationProvider: { + name: authenticationProvider.name, + providerId: authenticationProvider.providerId, + }, + authentication: { + providerId: "123456789", + accessToken: "123", + scopes: ["read"], + }, + }); + const authentications = await user.$get("authentications"); + const auth = authentications[0]; + expect(auth.accessToken).toEqual("123"); + expect(auth.scopes.length).toEqual(1); + expect(auth.scopes[0]).toEqual("read"); + expect(user.email).toEqual("jenny@example.com"); + expect(user.username).toEqual("jtester"); + expect(isNewUser).toEqual(true); + expect(spy).toHaveBeenCalled(); + // should provision welcome collection + const collectionCount = await Collection.count(); + expect(collectionCount).toEqual(1); + + spy.mockRestore(); + }); + it("should create a new user in an existing team", async () => { const spy = jest.spyOn(WelcomeEmail, "schedule"); const team = await buildTeam(); diff --git a/server/commands/teamCreator.test.ts b/server/commands/teamCreator.test.ts index 41bffa31f..2a0824e17 100644 --- a/server/commands/teamCreator.test.ts +++ b/server/commands/teamCreator.test.ts @@ -1,4 +1,5 @@ -import { buildTeam } from "@server/test/factories"; +import TeamDomain from "@server/models/TeamDomain"; +import { buildTeam, buildUser } from "@server/test/factories"; import { flushdb } from "@server/test/support"; import teamCreator from "./teamCreator"; @@ -48,6 +49,15 @@ describe("teamCreator", () => { it("should return existing team when within allowed domains", async () => { delete process.env.DEPLOYMENT; const existing = await buildTeam(); + const user = await buildUser({ + teamId: existing.id, + }); + await TeamDomain.create({ + teamId: existing.id, + name: "allowed-domain.com", + createdById: user.id, + }); + const result = await teamCreator({ name: "Updated name", subdomain: "example", @@ -67,6 +77,34 @@ describe("teamCreator", () => { expect(providers.length).toEqual(2); }); + it("should error when NOT within allowed domains", async () => { + const user = await buildUser(); + delete process.env.DEPLOYMENT; + const existing = await buildTeam(); + await TeamDomain.create({ + teamId: existing.id, + name: "other-domain.com", + createdById: user.id, + }); + + let error; + try { + await teamCreator({ + name: "Updated name", + subdomain: "example", + domain: "allowed-domain.com", + authenticationProvider: { + name: "google", + providerId: "allowed-domain.com", + }, + }); + } catch (err) { + error = err; + } + + expect(error).toBeTruthy(); + }); + it("should return exising team", async () => { delete process.env.DEPLOYMENT; const authenticationProvider = { diff --git a/server/commands/teamCreator.ts b/server/commands/teamCreator.ts index b1aa85f74..bf8afe413 100644 --- a/server/commands/teamCreator.ts +++ b/server/commands/teamCreator.ts @@ -1,10 +1,9 @@ import invariant from "invariant"; +import { DomainNotAllowedError, MaximumTeamsError } from "@server/errors"; import Logger from "@server/logging/logger"; import { APM } from "@server/logging/tracing"; import { Team, AuthenticationProvider } from "@server/models"; -import { isDomainAllowed } from "@server/utils/authentication"; import { generateAvatarUrl } from "@server/utils/avatars"; -import { MaximumTeamsError } from "../errors"; type TeamCreatorResult = { team: Team; @@ -60,19 +59,23 @@ async function teamCreator({ // If the self-hosted installation has a single team and the domain for the // new team is allowed then assign the authentication provider to the // existing team - if (teamCount === 1 && domain && isDomainAllowed(domain)) { + if (teamCount === 1 && domain) { const team = await Team.findOne(); invariant(team, "Team should exist"); - authP = await team.$create( - "authenticationProvider", - authenticationProvider - ); - return { - authenticationProvider: authP, - team, - isNewTeam: false, - }; + if (await team.isDomainAllowed(domain)) { + authP = await team.$create( + "authenticationProvider", + authenticationProvider + ); + return { + authenticationProvider: authP, + team, + isNewTeam: false, + }; + } else { + throw DomainNotAllowedError(); + } } if (teamCount >= 1) { diff --git a/server/commands/teamUpdater.ts b/server/commands/teamUpdater.ts index 306c82dfa..cb9063b13 100644 --- a/server/commands/teamUpdater.ts +++ b/server/commands/teamUpdater.ts @@ -1,9 +1,9 @@ import { Transaction } from "sequelize"; import { sequelize } from "@server/database/sequelize"; -import { Event, Team, User } from "@server/models"; +import { Event, Team, TeamDomain, User } from "@server/models"; type TeamUpdaterProps = { - params: Partial; + params: Partial> & { allowedDomains?: string[] }; ip?: string; user: User; team: Team; @@ -22,8 +22,11 @@ const teamUpdater = async ({ params, user, team, ip }: TeamUpdaterProps) => { defaultCollectionId, defaultUserRole, inviteRequired, + allowedDomains, } = params; + const transaction: Transaction = await sequelize.transaction(); + if (subdomain !== undefined && process.env.SUBDOMAINS_ENABLED === "true") { team.subdomain = subdomain === "" ? null : subdomain; } @@ -58,11 +61,50 @@ const teamUpdater = async ({ params, user, team, ip }: TeamUpdaterProps) => { if (inviteRequired !== undefined) { team.inviteRequired = inviteRequired; } + if (allowedDomains !== undefined) { + const existingAllowedDomains = await TeamDomain.findAll({ + where: { teamId: team.id }, + transaction, + }); + + // Only keep existing domains if they are still in the list of allowed domains + const newAllowedDomains = team.allowedDomains.filter((existingTeamDomain) => + allowedDomains.includes(existingTeamDomain.name) + ); + + // Add new domains + const existingDomains = team.allowedDomains.map((x) => x.name); + const newDomains = allowedDomains.filter( + (newDomain) => newDomain !== "" && !existingDomains.includes(newDomain) + ); + await Promise.all( + newDomains.map(async (newDomain) => { + newAllowedDomains.push( + await TeamDomain.create( + { + name: newDomain, + teamId: team.id, + createdById: user.id, + }, + { transaction } + ) + ); + }) + ); + + // Destroy the existing TeamDomains that were removed + const deletedDomains = existingAllowedDomains.filter( + (x) => !allowedDomains.includes(x.name) + ); + for (const deletedDomain of deletedDomains) { + deletedDomain.destroy({ transaction }); + } + + team.allowedDomains = newAllowedDomains; + } const changes = team.changed(); - const transaction: Transaction = await sequelize.transaction(); - try { const savedTeam = await team.save({ transaction, diff --git a/server/commands/userCreator.test.ts b/server/commands/userCreator.test.ts index 979ea7efc..5555ae7e6 100644 --- a/server/commands/userCreator.test.ts +++ b/server/commands/userCreator.test.ts @@ -1,5 +1,6 @@ +import { TeamDomain } from "@server/models"; import { buildUser, buildTeam, buildInvite } from "@server/test/factories"; -import { flushdb } from "@server/test/support"; +import { flushdb, seed } from "@server/test/support"; import userCreator from "./userCreator"; beforeEach(() => flushdb()); @@ -239,4 +240,68 @@ describe("userCreator", () => { "You need an invite to join this team" ); }); + + it("should create a user from allowed Domain", async () => { + const { admin, team } = await seed(); + await TeamDomain.create({ + teamId: team.id, + name: "example.com", + createdById: admin.id, + }); + + const authenticationProviders = await team.$get("authenticationProviders"); + const authenticationProvider = authenticationProviders[0]; + const result = await userCreator({ + name: "Test Name", + email: "user@example.com", + teamId: team.id, + ip, + authentication: { + authenticationProviderId: authenticationProvider.id, + providerId: "fake-service-id", + accessToken: "123", + scopes: ["read"], + }, + }); + const { user, authentication, isNewUser } = result; + expect(authentication.accessToken).toEqual("123"); + expect(authentication.scopes.length).toEqual(1); + expect(authentication.scopes[0]).toEqual("read"); + expect(user.email).toEqual("user@example.com"); + expect(isNewUser).toEqual(true); + }); + + it("should reject an user when the domain is not allowed", async () => { + const { admin, team } = await seed(); + await TeamDomain.create({ + teamId: team.id, + name: "other.com", + createdById: admin.id, + }); + + const authenticationProviders = await team.$get("authenticationProviders"); + const authenticationProvider = authenticationProviders[0]; + let error; + + try { + await userCreator({ + name: "Bad Domain User", + email: "user@example.com", + teamId: team.id, + ip, + authentication: { + authenticationProviderId: authenticationProvider.id, + providerId: "fake-service-id", + accessToken: "123", + scopes: ["read"], + }, + }); + } catch (err) { + error = err; + } + + expect(error && error.toString()).toContain( + "The domain is not allowed for this team" + ); + }); }); diff --git a/server/commands/userCreator.ts b/server/commands/userCreator.ts index e55d8fad4..b5d47628a 100644 --- a/server/commands/userCreator.ts +++ b/server/commands/userCreator.ts @@ -1,5 +1,5 @@ import { Op } from "sequelize"; -import { InviteRequiredError } from "@server/errors"; +import { DomainNotAllowedError, InviteRequiredError } from "@server/errors"; import { Event, Team, User, UserAuthentication } from "@server/models"; type UserCreatorResult = { @@ -145,7 +145,7 @@ export default async function userCreator({ try { const team = await Team.findByPk(teamId, { - attributes: ["defaultUserRole", "inviteRequired"], + attributes: ["defaultUserRole", "inviteRequired", "id"], transaction, }); @@ -155,6 +155,13 @@ export default async function userCreator({ throw InviteRequiredError(); } + // If the team settings do not allow this domain, + // throw an error and fail user creation. + const domain = email.split("@")[1]; + if (team && !(await team.isDomainAllowed(domain))) { + throw DomainNotAllowedError(); + } + const defaultUserRole = team?.defaultUserRole; const user = await User.create( diff --git a/server/errors.ts b/server/errors.ts index 474c56465..e8437fc40 100644 --- a/server/errors.ts +++ b/server/errors.ts @@ -28,6 +28,14 @@ export function InviteRequiredError( }); } +export function DomainNotAllowedError( + message = "The domain is not allowed for this team" +) { + return httpErrors(403, message, { + id: "domain_not_allowed", + }); +} + export function AdminRequiredError( message = "An admin role is required to access this resource" ) { @@ -130,14 +138,6 @@ export function GoogleWorkspaceRequiredError( }); } -export function GoogleWorkspaceInvalidError( - message = "Google Workspace is invalid" -) { - return httpErrors(400, message, { - id: "hd_not_allowed", - }); -} - export function OIDCMalformedUserInfoError( message = "User profile information malformed" ) { diff --git a/server/migrations/20220419052832-create-team-domains.js b/server/migrations/20220419052832-create-team-domains.js new file mode 100644 index 000000000..3944de739 --- /dev/null +++ b/server/migrations/20220419052832-create-team-domains.js @@ -0,0 +1,87 @@ +"use strict"; + +const { v4 } = require("uuid"); + +module.exports = { + up: async (queryInterface, Sequelize) => { + await queryInterface.sequelize.transaction(async (transaction) => { + await queryInterface.createTable("team_domains", { + id: { + type: Sequelize.UUID, + allowNull: false, + primaryKey: true, + }, + teamId: { + type: Sequelize.UUID, + allowNull: false, + onDelete: "cascade", + references: { + model: "teams", + }, + }, + createdById: { + type: Sequelize.UUID, + allowNull: false, + references: { + model: "users", + }, + }, + name: { + type: Sequelize.STRING, + allowNull: false, + }, + createdAt: { + type: Sequelize.DATE, + allowNull: false, + }, + updatedAt: { + type: Sequelize.DATE, + allowNull: false, + }, + }, { + transaction + }); + + await queryInterface.addIndex("team_domains", ["teamId", "name"], { + transaction, + unique: true, + }); + + const currentAllowedDomainsEnv = process.env.ALLOWED_DOMAINS || process.env.GOOGLE_ALLOWED_DOMAINS; + const currentAllowedDomains = currentAllowedDomainsEnv ? currentAllowedDomainsEnv.split(",") : []; + + if (currentAllowedDomains.length > 0) { + const [adminUserIDs] = await queryInterface.sequelize.query('select id from users where "isAdmin" = true limit 1', { transaction }) + const adminUserID = adminUserIDs[0]?.id + + if (adminUserID) { + const [teams] = await queryInterface.sequelize.query('select id from teams', { transaction }) + const now = new Date(); + + for (const team of teams) { + for (const domain of currentAllowedDomains) { + await queryInterface.sequelize.query(` + INSERT INTO team_domains ("id", "teamId", "createdById", "name", "createdAt", "updatedAt") + VALUES (:id, :teamId, :createdById, :name, :createdAt, :updatedAt) + `, { + replacements: { + id: v4(), + teamId: team.id, + createdById: adminUserID, + name: domain, + createdAt: now, + updatedAt: now, + }, + transaction, + } + ); + } + } + } + } + }); + }, + down: async (queryInterface) => { + return queryInterface.dropTable("team_domains"); + }, +}; diff --git a/server/models/Team.ts b/server/models/Team.ts index b37ed921e..b9192c5ac 100644 --- a/server/models/Team.ts +++ b/server/models/Team.ts @@ -27,6 +27,7 @@ import { publicS3Endpoint, uploadToS3FromUrl } from "@server/utils/s3"; import AuthenticationProvider from "./AuthenticationProvider"; import Collection from "./Collection"; import Document from "./Document"; +import TeamDomain from "./TeamDomain"; import User from "./User"; import ParanoidModel from "./base/ParanoidModel"; import Fix from "./decorators/Fix"; @@ -238,6 +239,15 @@ class Team extends ParanoidModel { return models.map((c) => c.id); }; + isDomainAllowed = async function (domain: string) { + const allowedDomains = (await this.$get("allowedDomains")) || []; + + return ( + allowedDomains.length === 0 || + allowedDomains.map((d: TeamDomain) => d.name).includes(domain) + ); + }; + // associations @HasMany(() => Collection) @@ -252,8 +262,10 @@ class Team extends ParanoidModel { @HasMany(() => AuthenticationProvider) authenticationProviders: AuthenticationProvider[]; - // hooks + @HasMany(() => TeamDomain) + allowedDomains: TeamDomain[]; + // hooks @BeforeSave static uploadAvatar = async (model: Team) => { const endpoint = publicS3Endpoint(); diff --git a/server/models/TeamDomain.ts b/server/models/TeamDomain.ts new file mode 100644 index 000000000..24668d2ec --- /dev/null +++ b/server/models/TeamDomain.ts @@ -0,0 +1,37 @@ +import { + Column, + Table, + BelongsTo, + ForeignKey, + NotEmpty, +} from "sequelize-typescript"; +import Team from "./Team"; +import User from "./User"; +import BaseModel from "./base/BaseModel"; +import Fix from "./decorators/Fix"; + +@Table({ tableName: "team_domains", modelName: "team_domain" }) +@Fix +class TeamDomain extends BaseModel { + @NotEmpty + @Column + name: string; + + // associations + + @BelongsTo(() => Team, "teamId") + team: Team; + + @ForeignKey(() => Team) + @Column + teamId: string; + + @BelongsTo(() => User, "createdById") + createdBy: User; + + @ForeignKey(() => User) + @Column + createdById: string; +} + +export default TeamDomain; diff --git a/server/models/index.ts b/server/models/index.ts index 1d8c8e4e6..16558caf7 100644 --- a/server/models/index.ts +++ b/server/models/index.ts @@ -42,6 +42,8 @@ export { default as Star } from "./Star"; export { default as Team } from "./Team"; +export { default as TeamDomain } from "./TeamDomain"; + export { default as User } from "./User"; export { default as UserAuthentication } from "./UserAuthentication"; diff --git a/server/presenters/team.ts b/server/presenters/team.ts index d7e86c8eb..400941cd1 100644 --- a/server/presenters/team.ts +++ b/server/presenters/team.ts @@ -16,5 +16,6 @@ export default function present(team: Team) { url: team.url, defaultUserRole: team.defaultUserRole, inviteRequired: team.inviteRequired, + allowedDomains: team.allowedDomains.map((d) => d.name), }; } diff --git a/server/routes/api/auth.test.ts b/server/routes/api/auth.test.ts index 75dff53b0..1f025f6ef 100644 --- a/server/routes/api/auth.test.ts +++ b/server/routes/api/auth.test.ts @@ -23,6 +23,7 @@ describe("#auth.info", () => { expect(res.status).toEqual(200); expect(body.data.user.name).toBe(user.name); expect(body.data.team.name).toBe(team.name); + expect(body.data.team.allowedDomains).toEqual([]); }); it("should require the team to not be deleted", async () => { diff --git a/server/routes/api/auth.ts b/server/routes/api/auth.ts index 3072a2f83..5bb478639 100644 --- a/server/routes/api/auth.ts +++ b/server/routes/api/auth.ts @@ -3,7 +3,7 @@ import Router from "koa-router"; import { find } from "lodash"; import { parseDomain, isCustomSubdomain } from "@shared/utils/domains"; import auth from "@server/middlewares/authentication"; -import { Team } from "@server/models"; +import { Team, TeamDomain } from "@server/models"; import { presentUser, presentTeam, presentPolicies } from "@server/presenters"; import { isCustomDomain } from "@server/utils/domains"; import providers from "../auth/providers"; @@ -111,7 +111,9 @@ router.post("auth.config", async (ctx) => { router.post("auth.info", auth(), async (ctx) => { const { user } = ctx.state; - const team = await Team.findByPk(user.teamId); + const team = await Team.findByPk(user.teamId, { + include: [{ model: TeamDomain }], + }); invariant(team, "Team not found"); ctx.body = { diff --git a/server/routes/api/team.test.ts b/server/routes/api/team.test.ts index f89304ef1..f34b58c07 100644 --- a/server/routes/api/team.test.ts +++ b/server/routes/api/team.test.ts @@ -1,4 +1,5 @@ import TestServer from "fetch-test-server"; +import { TeamDomain } from "@server/models"; import webService from "@server/services/web"; import { buildAdmin, buildCollection, buildTeam } from "@server/test/factories"; import { flushdb, seed } from "@server/test/support"; @@ -22,6 +23,81 @@ describe("#team.update", () => { expect(body.data.name).toEqual("New name"); }); + it("should add new allowed Domains, removing empty string values", async () => { + const { admin, team } = await seed(); + const res = await server.post("/api/team.update", { + body: { + token: admin.getJwtToken(), + allowedDomains: ["example.com", "", "example.org", "", ""], + }, + }); + const body = await res.json(); + expect(res.status).toEqual(200); + expect(body.data.allowedDomains).toEqual(["example.com", "example.org"]); + + const teamDomains: TeamDomain[] = await TeamDomain.findAll({ + where: { teamId: team.id }, + }); + expect(teamDomains.map((d) => d.name)).toEqual([ + "example.com", + "example.org", + ]); + }); + + it("should remove old allowed Domains", async () => { + const { admin, team } = await seed(); + const existingTeamDomain = await TeamDomain.create({ + teamId: team.id, + name: "example.com", + createdById: admin.id, + }); + + const res = await server.post("/api/team.update", { + body: { + token: admin.getJwtToken(), + allowedDomains: [], + }, + }); + const body = await res.json(); + expect(res.status).toEqual(200); + expect(body.data.allowedDomains).toEqual([]); + + const teamDomains: TeamDomain[] = await TeamDomain.findAll({ + where: { teamId: team.id }, + }); + expect(teamDomains.map((d) => d.name)).toEqual([]); + + expect(await TeamDomain.findByPk(existingTeamDomain.id)).toBeNull(); + }); + + it("should add new allowed domains and remove old ones", async () => { + const { admin, team } = await seed(); + const existingTeamDomain = await TeamDomain.create({ + teamId: team.id, + name: "example.com", + createdById: admin.id, + }); + + const res = await server.post("/api/team.update", { + body: { + token: admin.getJwtToken(), + allowedDomains: ["example.org", "example.net"], + }, + }); + const body = await res.json(); + expect(res.status).toEqual(200); + expect(body.data.allowedDomains).toEqual(["example.org", "example.net"]); + + const teamDomains: TeamDomain[] = await TeamDomain.findAll({ + where: { teamId: team.id }, + }); + expect(teamDomains.map((d) => d.name).sort()).toEqual( + ["example.org", "example.net"].sort() + ); + + expect(await TeamDomain.findByPk(existingTeamDomain.id)).toBeNull(); + }); + it("should only allow member,viewer or admin as default role", async () => { const { admin } = await seed(); const res = await server.post("/api/team.update", { diff --git a/server/routes/api/team.ts b/server/routes/api/team.ts index a838bb339..871ae3a78 100644 --- a/server/routes/api/team.ts +++ b/server/routes/api/team.ts @@ -1,7 +1,7 @@ import Router from "koa-router"; import teamUpdater from "@server/commands/teamUpdater"; import auth from "@server/middlewares/authentication"; -import { Team } from "@server/models"; +import { Team, TeamDomain } from "@server/models"; import { authorize } from "@server/policies"; import { presentTeam, presentPolicies } from "@server/presenters"; import { assertUuid } from "@server/validation"; @@ -21,10 +21,13 @@ router.post("team.update", auth(), async (ctx) => { defaultCollectionId, defaultUserRole, inviteRequired, + allowedDomains, } = ctx.body; const { user } = ctx.state; - const team = await Team.findByPk(user.teamId); + const team = await Team.findByPk(user.teamId, { + include: [{ model: TeamDomain }], + }); authorize(user, "update", team); if (defaultCollectionId !== undefined && defaultCollectionId !== null) { @@ -44,6 +47,7 @@ router.post("team.update", auth(), async (ctx) => { defaultCollectionId, defaultUserRole, inviteRequired, + allowedDomains, }, user, team, diff --git a/server/routes/auth/providers/google.ts b/server/routes/auth/providers/google.ts index 95e827935..fe5840290 100644 --- a/server/routes/auth/providers/google.ts +++ b/server/routes/auth/providers/google.ts @@ -8,13 +8,9 @@ import accountProvisioner, { AccountProvisionerResult, } from "@server/commands/accountProvisioner"; import env from "@server/env"; -import { - GoogleWorkspaceRequiredError, - GoogleWorkspaceInvalidError, -} from "@server/errors"; +import { GoogleWorkspaceRequiredError } from "@server/errors"; import passportMiddleware from "@server/middlewares/passport"; import { User } from "@server/models"; -import { isDomainAllowed } from "@server/utils/authentication"; import { StateStore } from "@server/utils/passport"; const router = new Router(); @@ -69,10 +65,6 @@ if (GOOGLE_CLIENT_ID && GOOGLE_CLIENT_SECRET) { throw GoogleWorkspaceRequiredError(); } - if (!isDomainAllowed(domain)) { - throw GoogleWorkspaceInvalidError(); - } - const subdomain = domain.split(".")[0]; const teamName = capitalize(subdomain); const result = await accountProvisioner({ diff --git a/server/routes/auth/providers/oidc.ts b/server/routes/auth/providers/oidc.ts index ec9088ac5..3cdc32d49 100644 --- a/server/routes/auth/providers/oidc.ts +++ b/server/routes/auth/providers/oidc.ts @@ -9,7 +9,6 @@ import { AuthenticationError, } from "@server/errors"; import passportMiddleware from "@server/middlewares/passport"; -import { isDomainAllowed } from "@server/utils/authentication"; import { StateStore, request } from "@server/utils/passport"; const router = new Router(); @@ -83,12 +82,6 @@ if (OIDC_CLIENT_ID) { throw OIDCMalformedUserInfoError(); } - if (!isDomainAllowed(domain)) { - throw AuthenticationError( - `Domain ${domain} is not on the whitelist` - ); - } - const subdomain = domain.split(".")[0]; const result = await accountProvisioner({ ip: req.ip, diff --git a/server/test/setup.ts b/server/test/setup.ts index ddcbe6d99..f57020f9a 100644 --- a/server/test/setup.ts +++ b/server/test/setup.ts @@ -7,7 +7,6 @@ process.env.NODE_ENV = "test"; process.env.GOOGLE_CLIENT_ID = "123"; process.env.SLACK_KEY = "123"; process.env.DEPLOYMENT = ""; -process.env.ALLOWED_DOMAINS = "allowed-domain.com"; // NOTE: this require must come after the ENV var override above // so that sequelize uses the test config variables diff --git a/server/utils/authentication.ts b/server/utils/authentication.ts index 1847e589e..76ca0588b 100644 --- a/server/utils/authentication.ts +++ b/server/utils/authentication.ts @@ -6,17 +6,6 @@ import Logger from "@server/logging/logger"; import { User, Event, Team, Collection, View } from "@server/models"; import { getCookieDomain } from "@server/utils/domains"; -export function getAllowedDomains(): string[] { - // GOOGLE_ALLOWED_DOMAINS included here for backwards compatability - const env = process.env.ALLOWED_DOMAINS || process.env.GOOGLE_ALLOWED_DOMAINS; - return env ? env.split(",") : []; -} - -export function isDomainAllowed(domain: string): boolean { - const allowedDomains = getAllowedDomains(); - return allowedDomains.includes(domain) || allowedDomains.length === 0; -} - export async function signIn( ctx: Context, user: User, diff --git a/shared/i18n/locales/en_US/translation.json b/shared/i18n/locales/en_US/translation.json index af0fc460c..2bb10a561 100644 --- a/shared/i18n/locales/en_US/translation.json +++ b/shared/i18n/locales/en_US/translation.json @@ -658,6 +658,10 @@ "Allow authorized {{ authenticationMethods }} users to create new accounts without first receiving an invite": "Allow authorized {{ authenticationMethods }} users to create new accounts without first receiving an invite", "Default role": "Default role", "The default user role for new accounts. Changing this setting does not affect existing user accounts.": "The default user role for new accounts. Changing this setting does not affect existing user accounts.", + "Allowed Domains": "Allowed Domains", + "The domains which should be allowed to create accounts. This applies to both SSO and Email logins. Changing this setting does not affect existing user accounts.": "The domains which should be allowed to create accounts. This applies to both SSO and Email logins. Changing this setting does not affect existing user accounts.", + "Remove domain": "Remove domain", + "Save changes": "Save changes", "Sharing is currently disabled.": "Sharing is currently disabled.", "You can globally enable and disable public document sharing in the security settings.": "You can globally enable and disable public document sharing in the security settings.", "Documents that have been shared are listed below. Anyone that has the public link can access a read-only version of the document until the link has been revoked.": "Documents that have been shared are listed below. Anyone that has the public link can access a read-only version of the document until the link has been revoked.",