multipart middleware (#5809)

* fix: multipart middleware

* fix: reviews
This commit is contained in:
Apoorv Mishra
2023-09-12 10:21:58 +05:30
committed by GitHub
parent 99e3a305d3
commit 401d1ba871
5 changed files with 48 additions and 20 deletions

View File

@@ -0,0 +1,37 @@
import { Next } from "koa";
import { bytesToHumanReadable } from "@shared/utils/files";
import { InvalidRequestError } from "@server/errors";
import { APIContext } from "@server/types";
import { getFileFromRequest } from "@server/utils/koa";
export default function multipart({
maximumFileSize,
}: {
maximumFileSize: number;
}) {
return async function multipartMiddleware(ctx: APIContext, next: Next) {
if (!ctx.is("multipart/form-data")) {
ctx.throw(
InvalidRequestError("Request type must be multipart/form-data")
);
}
const file = getFileFromRequest(ctx.request);
if (!file) {
ctx.throw(InvalidRequestError("Request must include a file parameter"));
}
if (file.size > maximumFileSize) {
ctx.throw(
InvalidRequestError(
`The selected file was larger than the ${bytesToHumanReadable(
maximumFileSize
)} maximum size`
)
);
}
ctx.input = { ...(ctx.input ?? {}), file };
return next();
};
}

View File

@@ -6,7 +6,10 @@ import { APIContext, BaseReq } from "@server/types";
export default function validate<T extends z.ZodType<BaseReq>>(schema: T) {
return async function validateMiddleware(ctx: APIContext, next: Next) {
try {
ctx.input = schema.parse(ctx.request);
ctx.input = {
...(ctx.input ?? {}),
...schema.parse(ctx.request),
};
} catch (err) {
if (err instanceof ZodError) {
const { path, message } = err.issues[0];

View File

@@ -1,8 +1,10 @@
import formidable from "formidable";
import { z } from "zod";
const BaseSchema = z.object({
body: z.unknown(),
query: z.unknown(),
file: z.custom<formidable.File>().optional(),
});
export default BaseSchema;

View File

@@ -8,7 +8,6 @@ import mime from "mime-types";
import { Op, ScopeOptions, WhereOptions } from "sequelize";
import { TeamPreference } from "@shared/types";
import { subtractDate } from "@shared/utils/date";
import { bytesToHumanReadable } from "@shared/utils/files";
import slugify from "@shared/utils/slugify";
import documentCreator from "@server/commands/documentCreator";
import documentImporter from "@server/commands/documentImporter";
@@ -26,6 +25,7 @@ import {
} from "@server/errors";
import Logger from "@server/logging/Logger";
import auth from "@server/middlewares/authentication";
import multipart from "@server/middlewares/multipart";
import { rateLimiter } from "@server/middlewares/rateLimiter";
import { transaction } from "@server/middlewares/transaction";
import validate from "@server/middlewares/validate";
@@ -53,7 +53,6 @@ import {
import { APIContext } from "@server/types";
import { RateLimiterStrategy } from "@server/utils/RateLimiter";
import ZipHelper from "@server/utils/ZipHelper";
import { getFileFromRequest } from "@server/utils/koa";
import parseAttachmentIds from "@server/utils/parseAttachmentIds";
import { getTeamFromContext } from "@server/utils/passport";
import { assertPresent } from "@server/validation";
@@ -1217,26 +1216,11 @@ router.post(
auth(),
rateLimiter(RateLimiterStrategy.TwentyFivePerMinute),
validate(T.DocumentsImportSchema),
multipart({ maximumFileSize: env.MAXIMUM_IMPORT_SIZE }),
transaction(),
async (ctx: APIContext<T.DocumentsImportReq>) => {
if (!ctx.is("multipart/form-data")) {
throw InvalidRequestError("Request type must be multipart/form-data");
}
const { collectionId, parentDocumentId, publish } = ctx.input.body;
const file = getFileFromRequest(ctx.request);
if (!file) {
throw InvalidRequestError("Request must include a file parameter");
}
if (env.MAXIMUM_IMPORT_SIZE && file.size > env.MAXIMUM_IMPORT_SIZE) {
throw InvalidRequestError(
`The selected file was larger than the ${bytesToHumanReadable(
env.MAXIMUM_IMPORT_SIZE
)} maximum size`
);
}
const file = ctx.input.file;
const { transaction } = ctx.state;
const { user } = ctx.state.auth;

View File

@@ -1,4 +1,5 @@
import emojiRegex from "emoji-regex";
import formidable from "formidable";
import isEmpty from "lodash/isEmpty";
import isUUID from "validator/lib/isUUID";
import { z } from "zod";
@@ -272,6 +273,7 @@ export const DocumentsImportSchema = BaseSchema.extend({
/** Import under this parent doc */
parentDocumentId: z.string().uuid().nullish(),
}),
file: z.custom<formidable.File>(),
});
export type DocumentsImportReq = z.infer<typeof DocumentsImportSchema>;