diff --git a/server/middlewares/multipart.ts b/server/middlewares/multipart.ts new file mode 100644 index 000000000..5b0425e41 --- /dev/null +++ b/server/middlewares/multipart.ts @@ -0,0 +1,37 @@ +import { Next } from "koa"; +import { bytesToHumanReadable } from "@shared/utils/files"; +import { InvalidRequestError } from "@server/errors"; +import { APIContext } from "@server/types"; +import { getFileFromRequest } from "@server/utils/koa"; + +export default function multipart({ + maximumFileSize, +}: { + maximumFileSize: number; +}) { + return async function multipartMiddleware(ctx: APIContext, next: Next) { + if (!ctx.is("multipart/form-data")) { + ctx.throw( + InvalidRequestError("Request type must be multipart/form-data") + ); + } + + const file = getFileFromRequest(ctx.request); + if (!file) { + ctx.throw(InvalidRequestError("Request must include a file parameter")); + } + + if (file.size > maximumFileSize) { + ctx.throw( + InvalidRequestError( + `The selected file was larger than the ${bytesToHumanReadable( + maximumFileSize + )} maximum size` + ) + ); + } + + ctx.input = { ...(ctx.input ?? {}), file }; + return next(); + }; +} diff --git a/server/middlewares/validate.ts b/server/middlewares/validate.ts index 24f7e55ef..a36e4d354 100644 --- a/server/middlewares/validate.ts +++ b/server/middlewares/validate.ts @@ -6,7 +6,10 @@ import { APIContext, BaseReq } from "@server/types"; export default function validate>(schema: T) { return async function validateMiddleware(ctx: APIContext, next: Next) { try { - ctx.input = schema.parse(ctx.request); + ctx.input = { + ...(ctx.input ?? {}), + ...schema.parse(ctx.request), + }; } catch (err) { if (err instanceof ZodError) { const { path, message } = err.issues[0]; diff --git a/server/routes/api/BaseSchema.ts b/server/routes/api/BaseSchema.ts index 04d7dba43..f78137358 100644 --- a/server/routes/api/BaseSchema.ts +++ b/server/routes/api/BaseSchema.ts @@ -1,8 +1,10 @@ +import formidable from "formidable"; import { z } from "zod"; const BaseSchema = z.object({ body: z.unknown(), query: z.unknown(), + file: z.custom().optional(), }); export default BaseSchema; diff --git a/server/routes/api/documents/documents.ts b/server/routes/api/documents/documents.ts index eb77701e6..e2ce59799 100644 --- a/server/routes/api/documents/documents.ts +++ b/server/routes/api/documents/documents.ts @@ -8,7 +8,6 @@ import mime from "mime-types"; import { Op, ScopeOptions, WhereOptions } from "sequelize"; import { TeamPreference } from "@shared/types"; import { subtractDate } from "@shared/utils/date"; -import { bytesToHumanReadable } from "@shared/utils/files"; import slugify from "@shared/utils/slugify"; import documentCreator from "@server/commands/documentCreator"; import documentImporter from "@server/commands/documentImporter"; @@ -26,6 +25,7 @@ import { } from "@server/errors"; import Logger from "@server/logging/Logger"; import auth from "@server/middlewares/authentication"; +import multipart from "@server/middlewares/multipart"; import { rateLimiter } from "@server/middlewares/rateLimiter"; import { transaction } from "@server/middlewares/transaction"; import validate from "@server/middlewares/validate"; @@ -53,7 +53,6 @@ import { import { APIContext } from "@server/types"; import { RateLimiterStrategy } from "@server/utils/RateLimiter"; import ZipHelper from "@server/utils/ZipHelper"; -import { getFileFromRequest } from "@server/utils/koa"; import parseAttachmentIds from "@server/utils/parseAttachmentIds"; import { getTeamFromContext } from "@server/utils/passport"; import { assertPresent } from "@server/validation"; @@ -1217,26 +1216,11 @@ router.post( auth(), rateLimiter(RateLimiterStrategy.TwentyFivePerMinute), validate(T.DocumentsImportSchema), + multipart({ maximumFileSize: env.MAXIMUM_IMPORT_SIZE }), transaction(), async (ctx: APIContext) => { - if (!ctx.is("multipart/form-data")) { - throw InvalidRequestError("Request type must be multipart/form-data"); - } - const { collectionId, parentDocumentId, publish } = ctx.input.body; - - const file = getFileFromRequest(ctx.request); - if (!file) { - throw InvalidRequestError("Request must include a file parameter"); - } - - if (env.MAXIMUM_IMPORT_SIZE && file.size > env.MAXIMUM_IMPORT_SIZE) { - throw InvalidRequestError( - `The selected file was larger than the ${bytesToHumanReadable( - env.MAXIMUM_IMPORT_SIZE - )} maximum size` - ); - } + const file = ctx.input.file; const { transaction } = ctx.state; const { user } = ctx.state.auth; diff --git a/server/routes/api/documents/schema.ts b/server/routes/api/documents/schema.ts index 747321383..03f697040 100644 --- a/server/routes/api/documents/schema.ts +++ b/server/routes/api/documents/schema.ts @@ -1,4 +1,5 @@ import emojiRegex from "emoji-regex"; +import formidable from "formidable"; import isEmpty from "lodash/isEmpty"; import isUUID from "validator/lib/isUUID"; import { z } from "zod"; @@ -272,6 +273,7 @@ export const DocumentsImportSchema = BaseSchema.extend({ /** Import under this parent doc */ parentDocumentId: z.string().uuid().nullish(), }), + file: z.custom(), }); export type DocumentsImportReq = z.infer;