diff --git a/backend/src/auth.ts b/backend/src/auth.ts index eac7169..a2ef16e 100644 --- a/backend/src/auth.ts +++ b/backend/src/auth.ts @@ -2,14 +2,25 @@ import express, { Request, Response } from "express"; import crypto from "crypto"; import jwt, { SignOptions } from "jsonwebtoken"; import ms, { type StringValue } from "ms"; -import { PrismaClient, Prisma } from "./generated/client"; +import { Prisma, PrismaClient } from "./generated/client"; import { config } from "./config"; -import { requireAuth, optionalAuth } from "./middleware/auth"; +import { + requireAuth as defaultRequireAuth, + optionalAuth as defaultOptionalAuth, + authModeService as defaultAuthModeService, +} from "./middleware/auth"; import { getCsrfTokenHeader, sanitizeText, validateCsrfToken } from "./security"; import rateLimit, { MemoryStore } from "express-rate-limit"; import { registerAccountRoutes } from "./auth/accountRoutes"; import { registerAdminRoutes } from "./auth/adminRoutes"; import { registerCoreRoutes } from "./auth/coreRoutes"; +import { prisma as defaultPrisma } from "./db/prisma"; +import { + BOOTSTRAP_USER_ID, + DEFAULT_SYSTEM_CONFIG_ID, + type AuthModeService, +} from "./auth/authMode"; +import { getCsrfValidationClientIds } from "./security/csrfClient"; interface JwtPayload { userId: string; @@ -30,379 +41,351 @@ const isJwtPayload = (decoded: unknown): decoded is JwtPayload => { ); }; -const router = express.Router(); -const prisma = new PrismaClient(); - -const BOOTSTRAP_USER_ID = "bootstrap-admin"; -const DEFAULT_SYSTEM_CONFIG_ID = "default"; - -const ensureSystemConfig = async () => { - return prisma.systemConfig.upsert({ - where: { id: DEFAULT_SYSTEM_CONFIG_ID }, - update: {}, - create: { - id: DEFAULT_SYSTEM_CONFIG_ID, - authEnabled: false, - registrationEnabled: false, - authLoginRateLimitEnabled: true, - authLoginRateLimitWindowMs: 15 * 60 * 1000, - authLoginRateLimitMax: 20, - }, - }); +type CreateAuthRouterDeps = { + prisma: PrismaClient; + requireAuth: express.RequestHandler; + optionalAuth: express.RequestHandler; + authModeService: AuthModeService; }; -const ensureAuthEnabled = async (res: Response): Promise => { - const systemConfig = await ensureSystemConfig(); - if (!systemConfig.authEnabled) { - res.status(404).json({ - error: "Not found", - message: "Authentication is disabled", +export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router => { + const { prisma, requireAuth, optionalAuth, authModeService } = deps; + const router = express.Router(); + + const ensureSystemConfig = authModeService.ensureSystemConfig; + + const ensureAuthEnabled = async (res: Response): Promise => { + const systemConfig = await ensureSystemConfig(); + if (!systemConfig.authEnabled) { + res.status(404).json({ + error: "Not found", + message: "Authentication is disabled", + }); + return false; + } + return true; + }; + + type LoginRateLimitConfig = { + enabled: boolean; + windowMs: number; + max: number; + }; + + const DEFAULT_LOGIN_RATE_LIMIT: LoginRateLimitConfig = { + enabled: true, + windowMs: 15 * 60 * 1000, + max: 20, + }; + + let loginRateLimitConfig: LoginRateLimitConfig = { ...DEFAULT_LOGIN_RATE_LIMIT }; + let loginAttemptLimiter: ReturnType | null = null; + let loginLimiterInitPromise: Promise | null = null; + + const parseLoginRateLimitConfig = ( + systemConfig: Awaited> + ): LoginRateLimitConfig => { + const enabled = + typeof systemConfig.authLoginRateLimitEnabled === "boolean" + ? systemConfig.authLoginRateLimitEnabled + : DEFAULT_LOGIN_RATE_LIMIT.enabled; + const windowMs = + Number.isFinite(Number(systemConfig.authLoginRateLimitWindowMs)) && + Number(systemConfig.authLoginRateLimitWindowMs) > 0 + ? Number(systemConfig.authLoginRateLimitWindowMs) + : DEFAULT_LOGIN_RATE_LIMIT.windowMs; + const max = + Number.isFinite(Number(systemConfig.authLoginRateLimitMax)) && + Number(systemConfig.authLoginRateLimitMax) > 0 + ? Number(systemConfig.authLoginRateLimitMax) + : DEFAULT_LOGIN_RATE_LIMIT.max; + return { enabled, windowMs, max }; + }; + + const resolveAuthIdentifier = (req: Request): string | null => { + const body = (req.body || {}) as Record; + const raw = + (typeof body.email === "string" && body.email) || + (typeof body.username === "string" && body.username) || + (typeof body.identifier === "string" && body.identifier) || + null; + if (!raw) return null; + const trimmed = raw.trim().toLowerCase(); + return trimmed.length > 0 ? trimmed.slice(0, 255) : null; + }; + + const buildLoginAttemptLimiter = (cfg: LoginRateLimitConfig) => { + const store = new MemoryStore(); + const limiter = rateLimit({ + windowMs: cfg.windowMs, + max: cfg.max, + message: { + error: "Too many requests", + message: "Too many login attempts, please try again later", + }, + standardHeaders: true, + legacyHeaders: false, + validate: { + trustProxy: false, + }, + store, + keyGenerator: (req) => { + const identifier = resolveAuthIdentifier(req as Request); + if (identifier) return `login:${identifier}`; + const ip = (req as Request).ip || "unknown"; + return `login-ip:${ip}`; + }, }); - return false; - } - return true; -}; -type LoginRateLimitConfig = { - enabled: boolean; - windowMs: number; - max: number; -}; + loginAttemptLimiter = limiter; + }; -const DEFAULT_LOGIN_RATE_LIMIT: LoginRateLimitConfig = { - enabled: true, - windowMs: 15 * 60 * 1000, - max: 20, -}; + const initLoginAttemptLimiter = async () => { + const systemConfig = await ensureSystemConfig(); + loginRateLimitConfig = parseLoginRateLimitConfig(systemConfig); + buildLoginAttemptLimiter(loginRateLimitConfig); + }; -let loginRateLimitConfig: LoginRateLimitConfig = { ...DEFAULT_LOGIN_RATE_LIMIT }; -let loginAttemptLimiter: ReturnType | null = null; -let loginLimiterInitPromise: Promise | null = null; + const ensureLoginAttemptLimiter = async () => { + if (loginAttemptLimiter) return; + if (!loginLimiterInitPromise) { + loginLimiterInitPromise = initLoginAttemptLimiter().finally(() => { + loginLimiterInitPromise = null; + }); + } + await loginLimiterInitPromise; + }; -const parseLoginRateLimitConfig = (systemConfig: Awaited>): LoginRateLimitConfig => { - const enabled = typeof systemConfig.authLoginRateLimitEnabled === "boolean" ? systemConfig.authLoginRateLimitEnabled : DEFAULT_LOGIN_RATE_LIMIT.enabled; - const windowMs = - Number.isFinite(Number(systemConfig.authLoginRateLimitWindowMs)) && Number(systemConfig.authLoginRateLimitWindowMs) > 0 - ? Number(systemConfig.authLoginRateLimitWindowMs) - : DEFAULT_LOGIN_RATE_LIMIT.windowMs; - const max = - Number.isFinite(Number(systemConfig.authLoginRateLimitMax)) && Number(systemConfig.authLoginRateLimitMax) > 0 - ? Number(systemConfig.authLoginRateLimitMax) - : DEFAULT_LOGIN_RATE_LIMIT.max; - return { enabled, windowMs, max }; -}; + const applyLoginRateLimitConfig = ( + systemConfig: Pick< + Awaited>, + "authLoginRateLimitEnabled" | "authLoginRateLimitWindowMs" | "authLoginRateLimitMax" + > + ): LoginRateLimitConfig => { + loginRateLimitConfig = parseLoginRateLimitConfig( + systemConfig as Awaited> + ); + buildLoginAttemptLimiter(loginRateLimitConfig); + return loginRateLimitConfig; + }; -const resolveAuthIdentifier = (req: Request): string | null => { - const body = (req.body || {}) as Record; - const raw = - (typeof body.email === "string" && body.email) || - (typeof body.username === "string" && body.username) || - (typeof body.identifier === "string" && body.identifier) || - null; - if (!raw) return null; - const trimmed = raw.trim().toLowerCase(); - return trimmed.length > 0 ? trimmed.slice(0, 255) : null; -}; + const resetLoginAttemptKey = async (identifier: string): Promise => { + await ensureLoginAttemptLimiter(); + const key = `login:${identifier}`; + try { + await loginAttemptLimiter?.resetKey(key); + } catch (error) { + if (process.env.NODE_ENV === "development") { + console.debug("Rate limit reset skipped:", error); + } + } + }; -const buildLoginAttemptLimiter = (cfg: LoginRateLimitConfig) => { - const store = new MemoryStore(); - const limiter = rateLimit({ - windowMs: cfg.windowMs, - max: cfg.max, + const loginAttemptRateLimiter = async ( + req: Request, + res: Response, + next: express.NextFunction + ) => { + await ensureLoginAttemptLimiter(); + if (!loginRateLimitConfig.enabled) return next(); + return (loginAttemptLimiter as ReturnType)(req, res, next); + }; + + const accountActionRateLimiter = rateLimit({ + windowMs: 5 * 60 * 1000, + max: 60, message: { error: "Too many requests", - message: "Too many login attempts, please try again later", + message: "Too many requests, please try again later", }, standardHeaders: true, legacyHeaders: false, validate: { trustProxy: false, }, - store, - keyGenerator: (req) => { - const identifier = resolveAuthIdentifier(req as Request); - if (identifier) return `login:${identifier}`; - const ip = (req as Request).ip || "unknown"; - return `login-ip:${ip}`; - }, }); - loginAttemptLimiter = limiter; -}; - -const initLoginAttemptLimiter = async () => { - const systemConfig = await ensureSystemConfig(); - loginRateLimitConfig = parseLoginRateLimitConfig(systemConfig); - buildLoginAttemptLimiter(loginRateLimitConfig); -}; - -const ensureLoginAttemptLimiter = async () => { - if (loginAttemptLimiter) return; - if (!loginLimiterInitPromise) { - loginLimiterInitPromise = initLoginAttemptLimiter().finally(() => { - loginLimiterInitPromise = null; - }); - } - await loginLimiterInitPromise; -}; - -const applyLoginRateLimitConfig = ( - systemConfig: Pick>, "authLoginRateLimitEnabled" | "authLoginRateLimitWindowMs" | "authLoginRateLimitMax"> -): LoginRateLimitConfig => { - loginRateLimitConfig = parseLoginRateLimitConfig(systemConfig as Awaited>); - buildLoginAttemptLimiter(loginRateLimitConfig); - return loginRateLimitConfig; -}; - -const resetLoginAttemptKey = async (identifier: string): Promise => { - await ensureLoginAttemptLimiter(); - const key = `login:${identifier}`; - try { - await loginAttemptLimiter?.resetKey(key); - } catch (error) { - if (process.env.NODE_ENV === "development") { - console.debug("Rate limit reset skipped:", error); - } - } -}; - -const loginAttemptRateLimiter = async (req: Request, res: Response, next: express.NextFunction) => { - await ensureLoginAttemptLimiter(); - if (!loginRateLimitConfig.enabled) return next(); - return (loginAttemptLimiter as ReturnType)(req, res, next); -}; - -const accountActionRateLimiter = rateLimit({ - windowMs: 5 * 60 * 1000, - max: 60, - message: { - error: "Too many requests", - message: "Too many requests, please try again later", - }, - standardHeaders: true, - legacyHeaders: false, - validate: { - trustProxy: false, - }, -}); - -const generateTempPassword = (): string => { - const buf = crypto.randomBytes(18); - return buf.toString("base64").replace(/[+/=]/g, "").slice(0, 24); -}; - -const findUserByIdentifier = async (identifier: string) => { - const trimmed = identifier.trim(); - if (trimmed.length === 0) return null; - - const looksLikeEmail = trimmed.includes("@"); - if (looksLikeEmail) { - return prisma.user.findUnique({ - where: { email: trimmed.toLowerCase() }, - }); - } - - return prisma.user.findFirst({ - where: { - OR: [{ username: trimmed }, { email: trimmed.toLowerCase() }], - }, - }); -}; - -const requireAdmin = ( - req: Request, - res: Response -): req is Request & { user: NonNullable } => { - if (!req.user) { - res.status(401).json({ error: "Unauthorized", message: "User not authenticated" }); - return false; - } - if (req.user.role !== "ADMIN") { - res.status(403).json({ error: "Forbidden", message: "Admin access required" }); - return false; - } - return true; -}; - -const CSRF_CLIENT_COOKIE_NAME = "excalidash-csrf-client"; - -const parseCookies = (cookieHeader: string | undefined): Record => { - if (!cookieHeader) return {}; - const cookies: Record = {}; - for (const part of cookieHeader.split(";")) { - const [rawKey, ...rawValueParts] = part.split("="); - const key = rawKey?.trim(); - if (!key) continue; - const rawValue = rawValueParts.join("=").trim(); - try { - cookies[key] = decodeURIComponent(rawValue); - } catch { - cookies[key] = rawValue; - } - } - return cookies; -}; - -const getCsrfClientCookieValue = (req: Request): string | null => { - const cookies = parseCookies(req.headers.cookie); - const value = cookies[CSRF_CLIENT_COOKIE_NAME]; - if (!value) return null; - if (!/^[A-Za-z0-9_-]{16,128}$/.test(value)) return null; - return value; -}; - -const getLegacyClientId = (req: Request): string => { - const ip = req.ip || req.connection.remoteAddress || "unknown"; - const userAgent = req.headers["user-agent"] || "unknown"; - return `${ip}:${userAgent}`.slice(0, 256); -}; - -const getCsrfValidationClientIds = (req: Request): string[] => { - const candidates: string[] = []; - const cookieValue = getCsrfClientCookieValue(req); - if (cookieValue) { - candidates.push(`cookie:${cookieValue}`); - } - const legacyClientId = getLegacyClientId(req); - if (!candidates.includes(legacyClientId)) { - candidates.push(legacyClientId); - } - return candidates; -}; - -const requireCsrf = (req: Request, res: Response): boolean => { - const headerName = getCsrfTokenHeader(); - const tokenHeader = req.headers[headerName]; - const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader; - - if (!token) { - res.status(403).json({ - error: "CSRF token missing", - message: `Missing ${headerName} header`, - }); - return false; - } - - const clientIds = getCsrfValidationClientIds(req); - const isValidToken = clientIds.some((clientId) => validateCsrfToken(clientId, token)); - if (!isValidToken) { - res.status(403).json({ - error: "CSRF token invalid", - message: "Invalid or expired CSRF token. Please refresh and try again.", - }); - return false; - } - - return true; -}; - -const countActiveAdmins = async () => { - return prisma.user.count({ - where: { role: "ADMIN", isActive: true }, - }); -}; - -const generateTokens = ( - userId: string, - email: string, - options?: { impersonatorId?: string } -) => { - const signOptions: SignOptions = { - expiresIn: config.jwtAccessExpiresIn as StringValue, + const generateTempPassword = (): string => { + const buf = crypto.randomBytes(18); + return buf.toString("base64").replace(/[+/=]/g, "").slice(0, 24); }; - const accessToken = jwt.sign( - { userId, email, type: "access", impersonatorId: options?.impersonatorId }, - config.jwtSecret, - signOptions - ); - const refreshSignOptions: SignOptions = { - expiresIn: config.jwtRefreshExpiresIn as StringValue, - }; - const refreshToken = jwt.sign( - { userId, email, type: "refresh", impersonatorId: options?.impersonatorId }, - config.jwtSecret, - refreshSignOptions - ); + const findUserByIdentifier = async (identifier: string) => { + const trimmed = identifier.trim(); + if (trimmed.length === 0) return null; - return { accessToken, refreshToken }; -}; - -const resolveExpiresAt = (expiresIn: string, fallbackMs: number): Date => { - const parsed = ms(expiresIn as StringValue); - const ttlMs = typeof parsed === "number" && parsed > 0 ? parsed : fallbackMs; - return new Date(Date.now() + ttlMs); -}; - -const isMissingRefreshTokenTableError = (error: unknown): boolean => { - if (error instanceof Prisma.PrismaClientKnownRequestError) { - if (error.code === "P2021") { - return true; + const looksLikeEmail = trimmed.includes("@"); + if (looksLikeEmail) { + return prisma.user.findUnique({ + where: { email: trimmed.toLowerCase() }, + }); } - } - const message = - typeof error === "object" && error && "message" in error - ? String((error as any).message) - : ""; - return /no such table:\s*RefreshToken/i.test(message); + return prisma.user.findFirst({ + where: { + OR: [{ username: trimmed }, { email: trimmed.toLowerCase() }], + }, + }); + }; + + const requireAdmin = ( + req: Request, + res: Response + ): req is Request & { user: NonNullable } => { + if (!req.user) { + res.status(401).json({ error: "Unauthorized", message: "User not authenticated" }); + return false; + } + if (req.user.role !== "ADMIN") { + res.status(403).json({ error: "Forbidden", message: "Admin access required" }); + return false; + } + return true; + }; + + const requireCsrf = (req: Request, res: Response): boolean => { + const headerName = getCsrfTokenHeader(); + const tokenHeader = req.headers[headerName]; + const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader; + + if (!token) { + res.status(403).json({ + error: "CSRF token missing", + message: `Missing ${headerName} header`, + }); + return false; + } + + const clientIds = getCsrfValidationClientIds(req); + const isValidToken = clientIds.some((clientId) => validateCsrfToken(clientId, token)); + if (!isValidToken) { + res.status(403).json({ + error: "CSRF token invalid", + message: "Invalid or expired CSRF token. Please refresh and try again.", + }); + return false; + } + + return true; + }; + + const countActiveAdmins = async () => { + return prisma.user.count({ + where: { role: "ADMIN", isActive: true }, + }); + }; + + const generateTokens = ( + userId: string, + email: string, + options?: { impersonatorId?: string } + ) => { + const signOptions: SignOptions = { + expiresIn: config.jwtAccessExpiresIn as StringValue, + }; + const accessToken = jwt.sign( + { userId, email, type: "access", impersonatorId: options?.impersonatorId }, + config.jwtSecret, + signOptions + ); + + const refreshSignOptions: SignOptions = { + expiresIn: config.jwtRefreshExpiresIn as StringValue, + }; + const refreshToken = jwt.sign( + { userId, email, type: "refresh", impersonatorId: options?.impersonatorId }, + config.jwtSecret, + refreshSignOptions + ); + + return { accessToken, refreshToken }; + }; + + const resolveExpiresAt = (expiresIn: string, fallbackMs: number): Date => { + const parsed = ms(expiresIn as StringValue); + const ttlMs = typeof parsed === "number" && parsed > 0 ? parsed : fallbackMs; + return new Date(Date.now() + ttlMs); + }; + + const isMissingRefreshTokenTableError = (error: unknown): boolean => { + if (error instanceof Prisma.PrismaClientKnownRequestError) { + if (error.code === "P2021") { + return true; + } + } + + const message = + typeof error === "object" && error && "message" in error + ? String((error as any).message) + : ""; + return /no such table:\s*RefreshToken/i.test(message); + }; + + const getRefreshTokenExpiresAt = (): Date => + resolveExpiresAt(config.jwtRefreshExpiresIn, 7 * 24 * 60 * 60 * 1000); + + registerCoreRoutes({ + router, + prisma, + requireAuth, + optionalAuth, + loginAttemptRateLimiter, + ensureAuthEnabled, + ensureSystemConfig, + findUserByIdentifier, + sanitizeText, + requireCsrf, + isJwtPayload, + config, + generateTokens, + getRefreshTokenExpiresAt, + isMissingRefreshTokenTableError, + bootstrapUserId: BOOTSTRAP_USER_ID, + defaultSystemConfigId: DEFAULT_SYSTEM_CONFIG_ID, + }); + + registerAdminRoutes({ + router, + prisma, + requireAuth, + accountActionRateLimiter, + ensureAuthEnabled, + ensureSystemConfig, + parseLoginRateLimitConfig, + applyLoginRateLimitConfig, + resetLoginAttemptKey, + requireAdmin, + findUserByIdentifier, + countActiveAdmins, + sanitizeText, + generateTempPassword, + generateTokens, + getRefreshTokenExpiresAt, + config, + defaultSystemConfigId: DEFAULT_SYSTEM_CONFIG_ID, + }); + + registerAccountRoutes({ + router, + prisma, + requireAuth, + loginAttemptRateLimiter, + accountActionRateLimiter, + ensureAuthEnabled, + sanitizeText, + config, + generateTokens, + getRefreshTokenExpiresAt, + }); + + return router; }; -const getRefreshTokenExpiresAt = (): Date => - resolveExpiresAt(config.jwtRefreshExpiresIn, 7 * 24 * 60 * 60 * 1000); - -registerCoreRoutes({ - router, - prisma, - requireAuth, - optionalAuth, - loginAttemptRateLimiter, - ensureAuthEnabled, - ensureSystemConfig, - findUserByIdentifier, - sanitizeText, - requireCsrf, - isJwtPayload, - config, - generateTokens, - getRefreshTokenExpiresAt, - isMissingRefreshTokenTableError, - bootstrapUserId: BOOTSTRAP_USER_ID, - defaultSystemConfigId: DEFAULT_SYSTEM_CONFIG_ID, +const authRouter = createAuthRouter({ + prisma: defaultPrisma, + requireAuth: defaultRequireAuth, + optionalAuth: defaultOptionalAuth, + authModeService: defaultAuthModeService, }); -registerAdminRoutes({ - router, - prisma, - requireAuth, - accountActionRateLimiter, - ensureAuthEnabled, - ensureSystemConfig, - parseLoginRateLimitConfig, - applyLoginRateLimitConfig, - resetLoginAttemptKey, - requireAdmin, - findUserByIdentifier, - countActiveAdmins, - sanitizeText, - generateTempPassword, - generateTokens, - getRefreshTokenExpiresAt, - config, - defaultSystemConfigId: DEFAULT_SYSTEM_CONFIG_ID, -}); - -registerAccountRoutes({ - router, - prisma, - requireAuth, - loginAttemptRateLimiter, - accountActionRateLimiter, - ensureAuthEnabled, - sanitizeText, - config, - generateTokens, - getRefreshTokenExpiresAt, -}); - -export default router; +export default authRouter; diff --git a/backend/src/auth/authMode.ts b/backend/src/auth/authMode.ts new file mode 100644 index 0000000..504e51c --- /dev/null +++ b/backend/src/auth/authMode.ts @@ -0,0 +1,82 @@ +import { PrismaClient } from "../generated/client"; + +export const BOOTSTRAP_USER_ID = "bootstrap-admin"; +export const DEFAULT_SYSTEM_CONFIG_ID = "default"; + +type AuthEnabledCache = { + value: boolean; + fetchedAt: number; +}; + +export type AuthModeService = ReturnType; + +export const createAuthModeService = ( + prisma: PrismaClient, + options?: { authEnabledTtlMs?: number } +) => { + const authEnabledTtlMs = options?.authEnabledTtlMs ?? 5000; + let authEnabledCache: AuthEnabledCache | null = null; + + const ensureSystemConfig = async () => { + return prisma.systemConfig.upsert({ + where: { id: DEFAULT_SYSTEM_CONFIG_ID }, + update: {}, + create: { + id: DEFAULT_SYSTEM_CONFIG_ID, + authEnabled: false, + registrationEnabled: false, + authLoginRateLimitEnabled: true, + authLoginRateLimitWindowMs: 15 * 60 * 1000, + authLoginRateLimitMax: 20, + }, + }); + }; + + const getAuthEnabled = async (): Promise => { + const now = Date.now(); + if (authEnabledCache && now - authEnabledCache.fetchedAt < authEnabledTtlMs) { + return authEnabledCache.value; + } + + const systemConfig = await ensureSystemConfig(); + authEnabledCache = { value: systemConfig.authEnabled, fetchedAt: now }; + return systemConfig.authEnabled; + }; + + const clearAuthEnabledCache = () => { + authEnabledCache = null; + }; + + const getBootstrapActingUser = async () => { + return prisma.user.upsert({ + where: { id: BOOTSTRAP_USER_ID }, + update: {}, + create: { + id: BOOTSTRAP_USER_ID, + email: "bootstrap@excalidash.local", + username: null, + passwordHash: "", + name: "Bootstrap Admin", + role: "ADMIN", + mustResetPassword: true, + isActive: false, + }, + select: { + id: true, + username: true, + email: true, + name: true, + role: true, + mustResetPassword: true, + isActive: true, + }, + }); + }; + + return { + ensureSystemConfig, + getAuthEnabled, + clearAuthEnabledCache, + getBootstrapActingUser, + }; +}; diff --git a/backend/src/db/prisma.ts b/backend/src/db/prisma.ts new file mode 100644 index 0000000..809095e --- /dev/null +++ b/backend/src/db/prisma.ts @@ -0,0 +1,14 @@ +import { PrismaClient } from "../generated/client"; + +declare global { + // eslint-disable-next-line no-var + var __excalidashPrisma: PrismaClient | undefined; +} + +const prismaClient = globalThis.__excalidashPrisma ?? new PrismaClient(); + +if (process.env.NODE_ENV !== "production") { + globalThis.__excalidashPrisma = prismaClient; +} + +export { prismaClient as prisma }; diff --git a/backend/src/index.ts b/backend/src/index.ts index 1279ba9..6787d57 100644 --- a/backend/src/index.ts +++ b/backend/src/index.ts @@ -19,19 +19,18 @@ import { sanitizeSvg, elementSchema, appStateSchema, - createCsrfToken, - validateCsrfToken, - getCsrfTokenHeader, - getOriginFromReferer, } from "./security"; -import jwt from "jsonwebtoken"; import { config } from "./config"; -import { requireAuth } from "./middleware/auth"; +import { authModeService, requireAuth } from "./middleware/auth"; import { errorHandler, asyncHandler } from "./middleware/errorHandler"; import authRouter from "./auth"; import { logAuditEvent } from "./utils/audit"; import { registerDashboardRoutes } from "./routes/dashboard"; import { registerImportExportRoutes } from "./routes/importExport"; +import { prisma } from "./db/prisma"; +import { createDrawingsCacheStore } from "./server/drawingsCache"; +import { registerCsrfProtection } from "./server/csrf"; +import { registerSocketHandlers } from "./server/socket"; const backendRoot = path.resolve(__dirname, "../"); console.log("Resolved DATABASE_URL:", process.env.DATABASE_URL); @@ -135,7 +134,6 @@ const io = new Server(httpServer, { }, maxHttpBufferSize: 1e8, }); -const prisma = new PrismaClient(); const parseJsonField = ( rawValue: string | null | undefined, fallback: T @@ -159,48 +157,12 @@ const DRAWINGS_CACHE_TTL_MS = (() => { } return parsed; })(); -type DrawingsCacheEntry = { body: Buffer; expiresAt: number }; -const drawingsCache = new Map(); - -const buildDrawingsCacheKey = (keyParts: { - userId: string; - searchTerm: string; - collectionFilter: string; - includeData: boolean; - sortField: "name" | "createdAt" | "updatedAt"; - sortDirection: "asc" | "desc"; -}) => - JSON.stringify([ - keyParts.userId, - keyParts.searchTerm, - keyParts.collectionFilter, - keyParts.includeData ? "full" : "summary", - keyParts.sortField, - keyParts.sortDirection, - ]); - -const getCachedDrawingsBody = (key: string): Buffer | null => { - const entry = drawingsCache.get(key); - if (!entry) return null; - if (Date.now() > entry.expiresAt) { - drawingsCache.delete(key); - return null; - } - return entry.body; -}; - -const cacheDrawingsResponse = (key: string, payload: unknown): Buffer => { - const body = Buffer.from(JSON.stringify(payload)); - drawingsCache.set(key, { - body, - expiresAt: Date.now() + DRAWINGS_CACHE_TTL_MS, - }); - return body; -}; - -const invalidateDrawingsCache = () => { - drawingsCache.clear(); -}; +const { + buildDrawingsCacheKey, + getCachedDrawingsBody, + cacheDrawingsResponse, + invalidateDrawingsCache, +} = createDrawingsCacheStore(DRAWINGS_CACHE_TTL_MS); const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`; @@ -230,15 +192,6 @@ const ensureTrashCollection = async ( }); }; -setInterval(() => { - const now = Date.now(); - for (const [key, entry] of drawingsCache.entries()) { - if (now > entry.expiresAt) { - drawingsCache.delete(key); - } - } -}, 60_000).unref(); - const PORT = config.port; const upload = multer({ @@ -350,18 +303,8 @@ app.use((req, res, next) => { next(); }); -const requestCounts = new Map(); const RATE_LIMIT_WINDOW = 15 * 60 * 1000; -setInterval(() => { - const now = Date.now(); - for (const [ip, data] of requestCounts.entries()) { - if (now > data.resetTime) { - requestCounts.delete(ip); - } - } -}, 5 * 60 * 1000).unref(); - // General rate limiting with express-rate-limit const generalRateLimiter = rateLimit({ windowMs: RATE_LIMIT_WINDOW, @@ -382,252 +325,11 @@ const generalRateLimiter = rateLimit({ app.use(generalRateLimiter); -// CSRF Protection Middleware -// Generates a unique client ID based on IP and User-Agent for token association -const CSRF_CLIENT_COOKIE_NAME = "excalidash-csrf-client"; -const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days - -const parseCookies = (cookieHeader: string | undefined): Record => { - if (!cookieHeader) return {}; - const cookies: Record = {}; - for (const part of cookieHeader.split(";")) { - const [rawKey, ...rawValueParts] = part.split("="); - const key = rawKey?.trim(); - if (!key) continue; - const rawValue = rawValueParts.join("=").trim(); - try { - cookies[key] = decodeURIComponent(rawValue); - } catch { - cookies[key] = rawValue; - } - } - return cookies; -}; - -const getCsrfClientCookieValue = (req: express.Request): string | null => { - const cookies = parseCookies(req.headers.cookie); - const value = cookies[CSRF_CLIENT_COOKIE_NAME]; - if (!value) return null; - if (!/^[A-Za-z0-9_-]{16,128}$/.test(value)) return null; - return value; -}; - -const requestUsesHttps = (req: express.Request): boolean => { - if (req.secure) return true; - const forwardedProto = req.headers["x-forwarded-proto"]; - const raw = Array.isArray(forwardedProto) ? forwardedProto[0] : forwardedProto; - const firstHop = String(raw || "") - .split(",")[0] - .trim() - .toLowerCase(); - return firstHop === "https"; -}; - -const setCsrfClientCookie = (req: express.Request, res: express.Response, value: string): void => { - const secure = requestUsesHttps(req) ? "; Secure" : ""; - res.append( - "Set-Cookie", - `${CSRF_CLIENT_COOKIE_NAME}=${encodeURIComponent( - value - )}; Path=/; HttpOnly; SameSite=Lax; Max-Age=${CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS}${secure}` - ); -}; - -const getLegacyClientId = (req: express.Request): string => { - const ip = req.ip || req.connection.remoteAddress || "unknown"; - const userAgent = req.headers["user-agent"] || "unknown"; - return `${ip}:${userAgent}`.slice(0, 256); -}; - -const getClientIdForTokenIssue = ( - req: express.Request, - res: express.Response -): { clientId: string; strategy: "cookie" | "legacy-bootstrap" } => { - const existingCookieValue = getCsrfClientCookieValue(req); - if (existingCookieValue) { - return { - clientId: `cookie:${existingCookieValue}`, - strategy: "cookie", - }; - } - - // No cookie presented by client yet: - // - issue a token bound to legacy identity for compatibility with non-cookie clients - // - still set a cookie so subsequent browser requests can transition to cookie-bound tokens - const generatedCookieValue = uuidv4().replace(/-/g, ""); - setCsrfClientCookie(req, res, generatedCookieValue); - return { - clientId: getLegacyClientId(req), - strategy: "legacy-bootstrap", - }; -}; - -const getClientIdCandidatesForValidation = (req: express.Request): string[] => { - const candidates: string[] = []; - const cookieValue = getCsrfClientCookieValue(req); - if (cookieValue) { - candidates.push(`cookie:${cookieValue}`); - } - - const legacyClientId = getLegacyClientId(req); - if (!candidates.includes(legacyClientId)) { - candidates.push(legacyClientId); - } - - return candidates; -}; - -const getClientIdForTokenIssueDebug = ( - req: express.Request, - res: express.Response -): string => { - const { clientId, strategy } = getClientIdForTokenIssue(req, res); - - // Debug logging for CSRF troubleshooting (issue #38) - if (process.env.DEBUG_CSRF === "true") { - const validationCandidates = getClientIdCandidatesForValidation(req); - const ip = req.ip || req.connection.remoteAddress || "unknown"; - console.log("[CSRF DEBUG] getClientId", { - method: req.method, - path: req.path, - ip, - remoteAddress: req.connection.remoteAddress, - "x-forwarded-for": req.headers["x-forwarded-for"], - "x-real-ip": req.headers["x-real-ip"], - hasCsrfCookie: Boolean(getCsrfClientCookieValue(req)), - clientIdPreview: clientId.slice(0, 60) + "...", - trustProxySetting: req.app.get("trust proxy"), - strategy, - validationCandidatesPreview: validationCandidates.map((candidate) => - `${candidate.slice(0, 60)}...` - ), - }); - } - - return clientId; -}; - -// Rate limiter specifically for CSRF token generation to prevent store exhaustion -const csrfRateLimit = new Map(); -const CSRF_RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute -let csrfCleanupCounter = 0; -const CSRF_MAX_REQUESTS = (() => { - const parsed = Number(process.env.CSRF_MAX_REQUESTS); - if (!Number.isFinite(parsed) || parsed <= 0) { - return 60; // 1 per second average - } - return parsed; -})(); - -// CSRF token endpoint - clients should call this to get a token -app.get("/csrf-token", (req, res) => { - const ip = req.ip || req.connection.remoteAddress || "unknown"; - const now = Date.now(); - const clientLimit = csrfRateLimit.get(ip); - - if (clientLimit && now < clientLimit.resetTime) { - if (clientLimit.count >= CSRF_MAX_REQUESTS) { - return res.status(429).json({ - error: "Rate limit exceeded", - message: "Too many CSRF token requests", - }); - } - clientLimit.count++; - } else { - csrfRateLimit.set(ip, { count: 1, resetTime: now + CSRF_RATE_LIMIT_WINDOW }); - } - - // Cleanup every 100 requests. - csrfCleanupCounter += 1; - if (csrfCleanupCounter % 100 === 0) { - for (const [key, data] of csrfRateLimit.entries()) { - if (now > data.resetTime) csrfRateLimit.delete(key); - } - } - - const clientId = getClientIdForTokenIssueDebug(req, res); - const token = createCsrfToken(clientId); - - res.json({ - token, - header: getCsrfTokenHeader() - }); -}); - -// CSRF validation middleware for state-changing requests -const csrfProtectionMiddleware = ( - req: express.Request, - res: express.Response, - next: express.NextFunction -) => { - // Skip CSRF validation for safe methods (GET, HEAD, OPTIONS) - // Note: /csrf-token is a GET endpoint, so it's automatically exempt - const safeMethods = ["GET", "HEAD", "OPTIONS"]; - if (safeMethods.includes(req.method)) { - return next(); - } - - // Origin/Referer check for defense in depth - const origin = req.headers["origin"]; - const referer = req.headers["referer"]; - - // If Origin is present, it must match allowed origins - const originValue = Array.isArray(origin) ? origin[0] : origin; - const refererValue = Array.isArray(referer) ? referer[0] : referer; - - if (originValue) { - if (!isAllowedOrigin(originValue)) { - return res.status(403).json({ - error: "CSRF origin mismatch", - message: "Origin not allowed", - }); - } - } else if (refererValue) { - // If no Origin but Referer exists, validate its *origin* (avoid prefix bypass) - const refererOrigin = getOriginFromReferer(refererValue); - if (!refererOrigin || !isAllowedOrigin(refererOrigin)) { - return res.status(403).json({ - error: "CSRF referer mismatch", - message: "Referer not allowed", - }); - } - } - // Note: If neither Origin nor Referer is present, we proceed to token check. - // Some legitimate clients/proxies might strip these, so we don't block strictly on their absence, - // but relying on the token is the primary defense. - - const clientIdCandidates = getClientIdCandidatesForValidation(req); - const headerName = getCsrfTokenHeader(); - const tokenHeader = req.headers[headerName]; - const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader; - - if (!token) { - return res.status(403).json({ - error: "CSRF token missing", - message: `Missing ${headerName} header`, - }); - } - - const isValidToken = clientIdCandidates.some((clientId) => - validateCsrfToken(clientId, token) - ); - if (!isValidToken) { - return res.status(403).json({ - error: "CSRF token invalid", - message: "Invalid or expired CSRF token. Please refresh and try again.", - }); - } - - next(); -}; - -// Apply CSRF protection to all routes (except auth endpoints) -app.use((req, res, next) => { - // Skip CSRF for auth endpoints - if (req.path.startsWith("/auth/")) { - return next(); - } - csrfProtectionMiddleware(req, res, next); +registerCsrfProtection({ + app, + isAllowedOrigin, + maxRequestsPerWindow: config.csrfMaxRequests, + enableDebugLogging: process.env.DEBUG_CSRF === "true", }); // Authentication routes (no CSRF required, uses JWT) @@ -843,223 +545,11 @@ const removeFileIfExists = async (filePath?: string) => { } }; -interface User { - id: string; - name: string; - initials: string; - color: string; - socketId: string; - isActive: boolean; -} - -const roomUsers = new Map(); - -// Track which authenticated user owns each socket for authorization checks -const socketUserMap = new Map(); - -const toPresenceName = (value: unknown): string => { - if (typeof value !== "string") return "User"; - const trimmed = value.trim().slice(0, 120); - return trimmed.length > 0 ? trimmed : "User"; -}; - -const toPresenceInitials = (name: string): string => { - const words = name - .split(/\s+/) - .map((part) => part.trim()) - .filter((part) => part.length > 0); - if (words.length === 0) return "U"; - const first = words[0]?.[0] ?? ""; - const second = words.length > 1 ? words[1]?.[0] ?? "" : ""; - const initials = `${first}${second}`.toUpperCase().slice(0, 2); - return initials.length > 0 ? initials : "U"; -}; - -const toPresenceColor = (value: unknown): string => { - if (typeof value !== "string") return "#4f46e5"; - const trimmed = value.trim(); - if (/^#[0-9a-fA-F]{3,8}$/.test(trimmed)) { - return trimmed; - } - return "#4f46e5"; -}; - -/** - * Verify JWT from Socket.io auth and check if auth is required. - * When auth is disabled (single-user mode), all connections are allowed. - */ -const getSocketAuthUserId = async (token?: string): Promise => { - // Check if auth is enabled - const systemConfig = await prisma.systemConfig.findUnique({ - where: { id: "default" }, - select: { authEnabled: true }, - }); - - if (!systemConfig || !systemConfig.authEnabled) { - // Auth disabled: allow all connections (single-user / bootstrap mode) - return "bootstrap-admin"; - } - - // Auth enabled: require valid JWT - if (!token) return null; - - try { - const decoded = jwt.verify(token, config.jwtSecret) as Record; - if ( - typeof decoded.userId !== "string" || - typeof decoded.email !== "string" || - decoded.type !== "access" - ) { - return null; - } - - // Verify user is still active - const user = await prisma.user.findUnique({ - where: { id: decoded.userId }, - select: { id: true, isActive: true }, - }); - - if (!user || !user.isActive) return null; - return user.id; - } catch { - return null; - } -}; - -io.use(async (socket, next) => { - try { - const token = socket.handshake.auth?.token as string | undefined; - const userId = await getSocketAuthUserId(token); - - if (!userId) { - return next(new Error("Authentication required")); - } - - socketUserMap.set(socket.id, userId); - next(); - } catch { - next(new Error("Authentication failed")); - } -}); - -io.on("connection", (socket) => { - const authenticatedUserId = socketUserMap.get(socket.id); - const authorizedDrawingIds = new Set(); - - socket.on( - "join-room", - async ({ - drawingId, - user, - }: { - drawingId: string; - user: Omit; - }) => { - try { - // Verify the authenticated user owns this drawing - if (authenticatedUserId) { - const drawing = await prisma.drawing.findFirst({ - where: { id: drawingId, userId: authenticatedUserId }, - select: { id: true }, - }); - - if (!drawing) { - socket.emit("error", { message: "You do not have access to this drawing" }); - return; - } - } - - const roomId = `drawing_${drawingId}`; - socket.join(roomId); - authorizedDrawingIds.add(drawingId); - - let trustedUserId = - typeof user?.id === "string" && user.id.trim().length > 0 - ? user.id.trim().slice(0, 200) - : socket.id; - let trustedName = toPresenceName(user?.name); - - // In auth-enabled mode, identity should come from the authenticated account. - if (authenticatedUserId && authenticatedUserId !== "bootstrap-admin") { - const account = await prisma.user.findUnique({ - where: { id: authenticatedUserId }, - select: { id: true, name: true }, - }); - if (account) { - trustedUserId = account.id; - trustedName = toPresenceName(account.name); - } - } - - const newUser: User = { - id: trustedUserId, - name: trustedName, - initials: toPresenceInitials(trustedName), - color: toPresenceColor(user?.color), - socketId: socket.id, - isActive: true, - }; - - const currentUsers = roomUsers.get(roomId) || []; - const filteredUsers = currentUsers.filter((u) => u.id !== newUser.id); - filteredUsers.push(newUser); - roomUsers.set(roomId, filteredUsers); - - io.to(roomId).emit("presence-update", filteredUsers); - } catch (err) { - console.error("Error in join-room handler:", err); - socket.emit("error", { message: "Failed to join room" }); - } - } - ); - - socket.on("cursor-move", (data) => { - const drawingId = typeof data?.drawingId === "string" ? data.drawingId : null; - if (!drawingId || !authorizedDrawingIds.has(drawingId)) { - return; - } - const roomId = `drawing_${drawingId}`; - socket.volatile.to(roomId).emit("cursor-move", data); - }); - - socket.on("element-update", (data) => { - const drawingId = typeof data?.drawingId === "string" ? data.drawingId : null; - if (!drawingId || !authorizedDrawingIds.has(drawingId)) { - return; - } - const roomId = `drawing_${drawingId}`; - socket.to(roomId).emit("element-update", data); - }); - - socket.on( - "user-activity", - ({ drawingId, isActive }: { drawingId: string; isActive: boolean }) => { - if (!authorizedDrawingIds.has(drawingId)) { - return; - } - const roomId = `drawing_${drawingId}`; - const users = roomUsers.get(roomId); - if (users) { - const user = users.find((u) => u.socketId === socket.id); - if (user) { - user.isActive = isActive; - io.to(roomId).emit("presence-update", users); - } - } - } - ); - - socket.on("disconnect", () => { - socketUserMap.delete(socket.id); - roomUsers.forEach((users, roomId) => { - const index = users.findIndex((u) => u.socketId === socket.id); - if (index !== -1) { - users.splice(index, 1); - roomUsers.set(roomId, users); - io.to(roomId).emit("presence-update", users); - } - }); - }); +registerSocketHandlers({ + io, + prisma, + authModeService, + jwtSecret: config.jwtSecret, }); app.get("/health", (req, res) => { diff --git a/backend/src/middleware/auth.ts b/backend/src/middleware/auth.ts index a51d82b..40554ae 100644 --- a/backend/src/middleware/auth.ts +++ b/backend/src/middleware/auth.ts @@ -2,81 +2,8 @@ import { Request, Response, NextFunction } from "express"; import jwt from "jsonwebtoken"; import { config } from "../config"; import { PrismaClient } from "../generated/client"; - -const prisma = new PrismaClient(); -const DEFAULT_SYSTEM_CONFIG_ID = "default"; -const BOOTSTRAP_USER_ID = "bootstrap-admin"; - -type AuthEnabledCache = { - value: boolean; - fetchedAt: number; -}; - -let authEnabledCache: AuthEnabledCache | null = null; -const AUTH_ENABLED_TTL_MS = 5000; - -const getAuthEnabled = async (): Promise => { - const now = Date.now(); - if (authEnabledCache && now - authEnabledCache.fetchedAt < AUTH_ENABLED_TTL_MS) { - return authEnabledCache.value; - } - - let systemConfig = await prisma.systemConfig.findUnique({ - where: { id: DEFAULT_SYSTEM_CONFIG_ID }, - select: { authEnabled: true }, - }); - - if (!systemConfig) { - try { - systemConfig = await prisma.systemConfig.create({ - data: { - id: DEFAULT_SYSTEM_CONFIG_ID, - authEnabled: false, - registrationEnabled: false, - }, - select: { authEnabled: true }, - }); - } catch { - // Handle race from concurrent initialization. - systemConfig = await prisma.systemConfig.findUnique({ - where: { id: DEFAULT_SYSTEM_CONFIG_ID }, - select: { authEnabled: true }, - }); - if (!systemConfig) { - throw new Error("Failed to initialize system config"); - } - } - } - - authEnabledCache = { value: systemConfig.authEnabled, fetchedAt: now }; - return systemConfig.authEnabled; -}; - -const getBootstrapActingUser = async () => { - return prisma.user.upsert({ - where: { id: BOOTSTRAP_USER_ID }, - update: {}, - create: { - id: BOOTSTRAP_USER_ID, - email: "bootstrap@excalidash.local", - username: null, - passwordHash: "", - name: "Bootstrap Admin", - role: "ADMIN", - mustResetPassword: true, - isActive: false, - }, - select: { - id: true, - username: true, - email: true, - name: true, - role: true, - mustResetPassword: true, - isActive: true, - }, - }); -}; +import { prisma as defaultPrisma } from "../db/prisma"; +import { createAuthModeService, type AuthModeService } from "../auth/authMode"; // Extend Express Request type to include user declare global { @@ -161,150 +88,97 @@ const isAllowedWhileMustResetPassword = (req: Request): boolean => { return false; }; -export const requireAuth = async ( - req: Request, - res: Response, - next: NextFunction -): Promise => { - // Single-user mode: authentication disabled -> treat all requests as the bootstrap user. - try { - const authEnabled = await getAuthEnabled(); - if (!authEnabled) { - const user = await getBootstrapActingUser(); - req.user = { - id: user.id, - username: user.username, - email: user.email, - name: user.name, - role: user.role, - mustResetPassword: user.mustResetPassword, - }; - return next(); - } - } catch (error) { - console.error("Error reading auth mode:", error); - res.status(500).json({ - error: "Internal server error", - message: "Failed to read authentication mode", - }); - return; - } - - const token = extractToken(req); - - if (!token) { - res.status(401).json({ - error: "Unauthorized", - message: "Authentication token required", - }); - return; - } - - const payload = verifyToken(token); - - if (!payload) { - res.status(401).json({ - error: "Unauthorized", - message: "Invalid or expired token", - }); - return; - } - - // Verify user still exists and is active - try { - const user = await prisma.user.findUnique({ - where: { id: payload.userId }, - select: { - id: true, - username: true, - email: true, - name: true, - role: true, - mustResetPassword: true, - isActive: true, - }, - }); - - if (!user || !user.isActive) { - res.status(401).json({ - error: "Unauthorized", - message: "User account not found or inactive", - }); - return; - } - - if (user.mustResetPassword && !isAllowedWhileMustResetPassword(req)) { - res.status(403).json({ - error: "Forbidden", - code: "MUST_RESET_PASSWORD", - message: "You must reset your password before using the app", - }); - return; - } - - // Attach user to request - req.user = { - id: user.id, - username: user.username, - email: user.email, - name: user.name, - role: user.role, - mustResetPassword: user.mustResetPassword, - impersonatorId: payload.impersonatorId, - }; - - next(); - } catch (error) { - console.error("Error verifying user:", error); - res.status(500).json({ - error: "Internal server error", - message: "Failed to verify user", - }); - } +export type AuthMiddlewareDeps = { + prisma: PrismaClient; + authModeService: AuthModeService; }; -export const optionalAuth = async ( - req: Request, - res: Response, - next: NextFunction -): Promise => { - try { - const authEnabled = await getAuthEnabled(); - if (!authEnabled) { - return next(); +export const createAuthMiddleware = ({ + prisma, + authModeService, +}: AuthMiddlewareDeps) => { + const requireAuth = async ( + req: Request, + res: Response, + next: NextFunction + ): Promise => { + // Single-user mode: authentication disabled -> treat all requests as the bootstrap user. + try { + const authEnabled = await authModeService.getAuthEnabled(); + if (!authEnabled) { + const user = await authModeService.getBootstrapActingUser(); + req.user = { + id: user.id, + username: user.username, + email: user.email, + name: user.name, + role: user.role, + mustResetPassword: user.mustResetPassword, + }; + return next(); + } + } catch (error) { + console.error("Error reading auth mode:", error); + res.status(500).json({ + error: "Internal server error", + message: "Failed to read authentication mode", + }); + return; } - } catch (error) { - console.error("Error reading auth mode:", error); - return next(); - } - const token = extractToken(req); + const token = extractToken(req); - if (!token) { - return next(); - } + if (!token) { + res.status(401).json({ + error: "Unauthorized", + message: "Authentication token required", + }); + return; + } - const payload = verifyToken(token); + const payload = verifyToken(token); - if (!payload) { - return next(); - } + if (!payload) { + res.status(401).json({ + error: "Unauthorized", + message: "Invalid or expired token", + }); + return; + } - try { - const user = await prisma.user.findUnique({ - where: { id: payload.userId }, - select: { - id: true, - username: true, - email: true, - name: true, - role: true, - mustResetPassword: true, - isActive: true, - }, - }); + // Verify user still exists and is active + try { + const user = await prisma.user.findUnique({ + where: { id: payload.userId }, + select: { + id: true, + username: true, + email: true, + name: true, + role: true, + mustResetPassword: true, + isActive: true, + }, + }); - if (user && user.isActive) { + if (!user || !user.isActive) { + res.status(401).json({ + error: "Unauthorized", + message: "User account not found or inactive", + }); + return; + } + + if (user.mustResetPassword && !isAllowedWhileMustResetPassword(req)) { + res.status(403).json({ + error: "Forbidden", + code: "MUST_RESET_PASSWORD", + message: "You must reset your password before using the app", + }); + return; + } + + // Attach user to request req.user = { id: user.id, username: user.username, @@ -314,11 +188,89 @@ export const optionalAuth = async ( mustResetPassword: user.mustResetPassword, impersonatorId: payload.impersonatorId, }; - } - } catch (error) { - // Silently fail for optional auth - console.error("Error in optional auth:", error); - } - next(); + next(); + } catch (error) { + console.error("Error verifying user:", error); + res.status(500).json({ + error: "Internal server error", + message: "Failed to verify user", + }); + } + }; + + const optionalAuth = async ( + req: Request, + res: Response, + next: NextFunction + ): Promise => { + try { + const authEnabled = await authModeService.getAuthEnabled(); + if (!authEnabled) { + return next(); + } + } catch (error) { + console.error("Error reading auth mode:", error); + return next(); + } + + const token = extractToken(req); + + if (!token) { + return next(); + } + + const payload = verifyToken(token); + + if (!payload) { + return next(); + } + + try { + const user = await prisma.user.findUnique({ + where: { id: payload.userId }, + select: { + id: true, + username: true, + email: true, + name: true, + role: true, + mustResetPassword: true, + isActive: true, + }, + }); + + if (user && user.isActive) { + req.user = { + id: user.id, + username: user.username, + email: user.email, + name: user.name, + role: user.role, + mustResetPassword: user.mustResetPassword, + impersonatorId: payload.impersonatorId, + }; + } + } catch (error) { + // Silently fail for optional auth + console.error("Error in optional auth:", error); + } + + next(); + }; + + return { + requireAuth, + optionalAuth, + }; }; + +const defaultAuthModeService = createAuthModeService(defaultPrisma); +const defaultAuthMiddleware = createAuthMiddleware({ + prisma: defaultPrisma, + authModeService: defaultAuthModeService, +}); + +export const authModeService = defaultAuthModeService; +export const requireAuth = defaultAuthMiddleware.requireAuth; +export const optionalAuth = defaultAuthMiddleware.optionalAuth; diff --git a/backend/src/routes/dashboard.ts b/backend/src/routes/dashboard.ts index 148eadb..82d4c7e 100644 --- a/backend/src/routes/dashboard.ts +++ b/backend/src/routes/dashboard.ts @@ -1,625 +1,2 @@ -import express from "express"; -import { z } from "zod"; -import { Prisma, PrismaClient } from "../generated/client"; - -type SortField = "name" | "createdAt" | "updatedAt"; -type SortDirection = "asc" | "desc"; - -type BuildDrawingsCacheKey = (keyParts: { - userId: string; - searchTerm: string; - collectionFilter: string; - includeData: boolean; - sortField: SortField; - sortDirection: SortDirection; -}) => string; - -type EnsureTrashCollection = ( - db: Prisma.TransactionClient | PrismaClient, - userId: string -) => Promise; - -type LogAuditEvent = (params: { - userId: string; - action: string; - resource?: string; - ipAddress?: string; - userAgent?: string; - details?: Record; -}) => Promise; - -type DashboardRouteDeps = { - prisma: PrismaClient; - requireAuth: express.RequestHandler; - asyncHandler: ( - fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise - ) => express.RequestHandler; - parseJsonField: (rawValue: string | null | undefined, fallback: T) => T; - sanitizeText: (input: unknown, maxLength?: number) => string; - validateImportedDrawing: (data: unknown) => boolean; - drawingCreateSchema: z.ZodTypeAny; - drawingUpdateSchema: z.ZodTypeAny; - respondWithValidationErrors: (res: express.Response, issues: z.ZodIssue[]) => void; - collectionNameSchema: z.ZodTypeAny; - ensureTrashCollection: EnsureTrashCollection; - invalidateDrawingsCache: () => void; - buildDrawingsCacheKey: BuildDrawingsCacheKey; - getCachedDrawingsBody: (key: string) => Buffer | null; - cacheDrawingsResponse: (key: string, payload: unknown) => Buffer; - MAX_PAGE_SIZE: number; - config: { - nodeEnv: string; - enableAuditLogging: boolean; - }; - logAuditEvent: LogAuditEvent; -}; - -export const registerDashboardRoutes = ( - app: express.Express, - deps: DashboardRouteDeps -) => { - const { - prisma, - requireAuth, - asyncHandler, - parseJsonField, - sanitizeText, - validateImportedDrawing, - drawingCreateSchema, - drawingUpdateSchema, - respondWithValidationErrors, - collectionNameSchema, - ensureTrashCollection, - invalidateDrawingsCache, - buildDrawingsCacheKey, - getCachedDrawingsBody, - cacheDrawingsResponse, - MAX_PAGE_SIZE, - config, - logAuditEvent, - } = deps; - - const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`; - const isTrashCollectionId = ( - collectionId: string | null | undefined, - userId: string - ): boolean => - Boolean(collectionId) && - (collectionId === "trash" || collectionId === getUserTrashCollectionId(userId)); - const toInternalTrashCollectionId = ( - collectionId: string | null | undefined, - userId: string - ): string | null | undefined => - collectionId === "trash" ? getUserTrashCollectionId(userId) : collectionId; - const toPublicTrashCollectionId = ( - collectionId: string | null | undefined, - userId: string - ): string | null | undefined => - isTrashCollectionId(collectionId, userId) ? "trash" : collectionId; - - app.get("/drawings", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) { - return res.status(401).json({ error: "Unauthorized" }); - } - - const trashCollectionId = getUserTrashCollectionId(req.user.id); - const { search, collectionId, includeData, limit, offset, sortField, sortDirection } = req.query; - const where: Prisma.DrawingWhereInput = { userId: req.user.id }; - const searchTerm = - typeof search === "string" && search.trim().length > 0 ? search.trim() : undefined; - - if (searchTerm) { - where.name = { contains: searchTerm }; - } - - let collectionFilterKey = "default"; - if (collectionId === "null") { - where.collectionId = null; - collectionFilterKey = "null"; - } else if (collectionId) { - const normalizedCollectionId = String(collectionId); - if (normalizedCollectionId === "trash") { - where.collectionId = { in: [trashCollectionId, "trash"] }; - collectionFilterKey = "trash"; - } else { - const collection = await prisma.collection.findFirst({ - where: { id: normalizedCollectionId, userId: req.user.id }, - }); - if (!collection) { - return res.status(404).json({ error: "Collection not found" }); - } - where.collectionId = normalizedCollectionId; - collectionFilterKey = `id:${normalizedCollectionId}`; - } - } else { - where.OR = [ - { collectionId: { notIn: [trashCollectionId, "trash"] } }, - { collectionId: null }, - ]; - } - - const shouldIncludeData = - typeof includeData === "string" - ? includeData.toLowerCase() === "true" || includeData === "1" - : false; - const parsedSortField: SortField = - sortField === "name" || sortField === "createdAt" || sortField === "updatedAt" - ? sortField - : "updatedAt"; - const parsedSortDirection: SortDirection = - sortDirection === "asc" || sortDirection === "desc" - ? sortDirection - : parsedSortField === "name" - ? "asc" - : "desc"; - - const rawLimit = limit ? Number.parseInt(limit as string, 10) : undefined; - const rawOffset = offset ? Number.parseInt(offset as string, 10) : undefined; - const parsedLimit = - rawLimit !== undefined && Number.isFinite(rawLimit) - ? Math.min(Math.max(rawLimit, 1), MAX_PAGE_SIZE) - : undefined; - const parsedOffset = - rawOffset !== undefined && Number.isFinite(rawOffset) ? Math.max(rawOffset, 0) : undefined; - - const cacheKey = - buildDrawingsCacheKey({ - userId: req.user.id, - searchTerm: searchTerm ?? "", - collectionFilter: collectionFilterKey, - includeData: shouldIncludeData, - sortField: parsedSortField, - sortDirection: parsedSortDirection, - }) + `:${parsedLimit}:${parsedOffset}`; - - const cachedBody = getCachedDrawingsBody(cacheKey); - if (cachedBody) { - res.setHeader("X-Cache", "HIT"); - res.setHeader("Content-Type", "application/json"); - return res.send(cachedBody); - } - - const summarySelect: Prisma.DrawingSelect = { - id: true, - name: true, - collectionId: true, - preview: true, - version: true, - createdAt: true, - updatedAt: true, - }; - - const orderBy: Prisma.DrawingOrderByWithRelationInput = - parsedSortField === "name" - ? { name: parsedSortDirection } - : parsedSortField === "createdAt" - ? { createdAt: parsedSortDirection } - : { updatedAt: parsedSortDirection }; - - const queryOptions: Prisma.DrawingFindManyArgs = { where, orderBy }; - if (parsedLimit !== undefined) queryOptions.take = parsedLimit; - if (parsedOffset !== undefined) queryOptions.skip = parsedOffset; - if (!shouldIncludeData) queryOptions.select = summarySelect; - - const [drawings, totalCount] = await Promise.all([ - prisma.drawing.findMany(queryOptions), - prisma.drawing.count({ where }), - ]); - - let responsePayload: any[] = drawings as any[]; - if (shouldIncludeData) { - responsePayload = (drawings as any[]).map((d: any) => ({ - ...d, - collectionId: toPublicTrashCollectionId(d.collectionId, req.user!.id), - elements: parseJsonField(d.elements, []), - appState: parseJsonField(d.appState, {}), - files: parseJsonField(d.files, {}), - })); - } else { - responsePayload = (drawings as any[]).map((d: any) => ({ - ...d, - collectionId: toPublicTrashCollectionId(d.collectionId, req.user!.id), - })); - } - - const finalResponse = { - drawings: responsePayload, - totalCount, - limit: parsedLimit, - offset: parsedOffset, - }; - - const body = cacheDrawingsResponse(cacheKey, finalResponse); - res.setHeader("X-Cache", "MISS"); - res.setHeader("Content-Type", "application/json"); - return res.send(body); - })); - - app.get("/drawings/:id", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const { id } = req.params; - const drawing = await prisma.drawing.findFirst({ - where: { - id, - userId: req.user.id, - }, - }); - if (!drawing) { - return res.status(404).json({ error: "Drawing not found", message: "Drawing does not exist" }); - } - - return res.json({ - ...drawing, - collectionId: toPublicTrashCollectionId(drawing.collectionId, req.user.id), - elements: parseJsonField(drawing.elements, []), - appState: parseJsonField(drawing.appState, {}), - files: parseJsonField(drawing.files, {}), - }); - })); - - app.post("/drawings", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const isImportedDrawing = req.headers["x-imported-file"] === "true"; - if (isImportedDrawing && !validateImportedDrawing(req.body)) { - return res.status(400).json({ - error: "Invalid imported drawing file", - message: "The imported file contains potentially malicious content or invalid structure", - }); - } - - const parsed = drawingCreateSchema.safeParse(req.body); - if (!parsed.success) { - return respondWithValidationErrors(res, parsed.error.issues); - } - - const payload = parsed.data as { - name?: string; - collectionId?: string | null; - elements: unknown[]; - appState: Record; - preview?: string | null; - files?: Record; - }; - const drawingName = payload.name ?? "Untitled Drawing"; - const targetCollectionIdRaw = payload.collectionId === undefined ? null : payload.collectionId; - const targetCollectionId = - toInternalTrashCollectionId(targetCollectionIdRaw, req.user.id) ?? null; - - if (targetCollectionId && !isTrashCollectionId(targetCollectionId, req.user.id)) { - const collection = await prisma.collection.findFirst({ - where: { id: targetCollectionId, userId: req.user.id }, - }); - if (!collection) return res.status(404).json({ error: "Collection not found" }); - } else if (targetCollectionIdRaw === "trash") { - await ensureTrashCollection(prisma, req.user.id); - } - - const newDrawing = await prisma.drawing.create({ - data: { - name: drawingName, - elements: JSON.stringify(payload.elements), - appState: JSON.stringify(payload.appState), - userId: req.user.id, - collectionId: targetCollectionId, - preview: payload.preview ?? null, - files: JSON.stringify(payload.files ?? {}), - }, - }); - invalidateDrawingsCache(); - - return res.json({ - ...newDrawing, - collectionId: toPublicTrashCollectionId(newDrawing.collectionId, req.user.id), - elements: parseJsonField(newDrawing.elements, []), - appState: parseJsonField(newDrawing.appState, {}), - files: parseJsonField(newDrawing.files, {}), - }); - })); - - app.put("/drawings/:id", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const { id } = req.params; - const existingDrawing = await prisma.drawing.findFirst({ - where: { id, userId: req.user.id }, - }); - if (!existingDrawing) return res.status(404).json({ error: "Drawing not found" }); - - const parsed = drawingUpdateSchema.safeParse(req.body); - if (!parsed.success) { - if (config.nodeEnv === "development") { - console.error("[API] Validation failed", { id, errors: parsed.error.issues }); - } - return respondWithValidationErrors(res, parsed.error.issues); - } - - const payload = parsed.data as { - name?: string; - collectionId?: string | null; - elements?: unknown[]; - appState?: Record; - preview?: string | null; - files?: Record; - version?: number; - }; - const trashCollectionId = getUserTrashCollectionId(req.user.id); - const isSceneUpdate = - payload.elements !== undefined || - payload.appState !== undefined || - payload.files !== undefined; - const data: Prisma.DrawingUpdateInput = isSceneUpdate - ? { version: { increment: 1 } } - : {}; - - if (payload.name !== undefined) data.name = payload.name; - if (payload.elements !== undefined) data.elements = JSON.stringify(payload.elements); - if (payload.appState !== undefined) data.appState = JSON.stringify(payload.appState); - if (payload.files !== undefined) data.files = JSON.stringify(payload.files); - if (payload.preview !== undefined) data.preview = payload.preview; - - if (payload.collectionId !== undefined) { - if (payload.collectionId === "trash") { - await ensureTrashCollection(prisma, req.user.id); - (data as Prisma.DrawingUncheckedUpdateInput).collectionId = trashCollectionId; - } else if (payload.collectionId) { - const collection = await prisma.collection.findFirst({ - where: { id: payload.collectionId, userId: req.user.id }, - }); - if (!collection) return res.status(404).json({ error: "Collection not found" }); - (data as Prisma.DrawingUncheckedUpdateInput).collectionId = payload.collectionId; - } else { - (data as Prisma.DrawingUncheckedUpdateInput).collectionId = null; - } - } - - const updateWhere: Prisma.DrawingWhereInput = { id, userId: req.user.id }; - if (isSceneUpdate && payload.version !== undefined) { - updateWhere.version = payload.version; - } - - const updateResult = await prisma.drawing.updateMany({ - where: updateWhere, - data, - }); - if (updateResult.count === 0) { - if (isSceneUpdate && payload.version !== undefined) { - const latestDrawing = await prisma.drawing.findFirst({ - where: { id, userId: req.user.id }, - select: { version: true }, - }); - return res.status(409).json({ - error: "Conflict", - code: "VERSION_CONFLICT", - message: "Drawing has changed since this editor state was loaded.", - currentVersion: latestDrawing?.version ?? null, - }); - } - return res.status(404).json({ error: "Drawing not found" }); - } - - const updatedDrawing = await prisma.drawing.findFirst({ - where: { id, userId: req.user.id }, - }); - if (!updatedDrawing) { - return res.status(404).json({ error: "Drawing not found" }); - } - invalidateDrawingsCache(); - - return res.json({ - ...updatedDrawing, - collectionId: toPublicTrashCollectionId(updatedDrawing.collectionId, req.user.id), - elements: parseJsonField(updatedDrawing.elements, []), - appState: parseJsonField(updatedDrawing.appState, {}), - files: parseJsonField(updatedDrawing.files, {}), - }); - })); - - app.delete("/drawings/:id", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - const { id } = req.params; - - const drawing = await prisma.drawing.findFirst({ where: { id, userId: req.user.id } }); - if (!drawing) return res.status(404).json({ error: "Drawing not found" }); - - const deleteResult = await prisma.drawing.deleteMany({ - where: { id, userId: req.user.id }, - }); - if (deleteResult.count === 0) { - return res.status(404).json({ error: "Drawing not found" }); - } - invalidateDrawingsCache(); - - if (config.enableAuditLogging) { - await logAuditEvent({ - userId: req.user.id, - action: "drawing_deleted", - resource: `drawing:${id}`, - ipAddress: req.ip || req.connection.remoteAddress || undefined, - userAgent: req.headers["user-agent"] || undefined, - details: { drawingId: id, drawingName: drawing.name }, - }); - } - - return res.json({ success: true }); - })); - - app.post("/drawings/:id/duplicate", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const { id } = req.params; - const original = await prisma.drawing.findFirst({ where: { id, userId: req.user.id } }); - if (!original) return res.status(404).json({ error: "Original drawing not found" }); - let duplicatedCollectionId = original.collectionId; - if (isTrashCollectionId(original.collectionId, req.user.id)) { - await ensureTrashCollection(prisma, req.user.id); - duplicatedCollectionId = getUserTrashCollectionId(req.user.id); - } - - const newDrawing = await prisma.drawing.create({ - data: { - name: `${original.name} (Copy)`, - elements: original.elements, - appState: original.appState, - files: original.files, - userId: req.user.id, - collectionId: duplicatedCollectionId, - version: 1, - }, - }); - invalidateDrawingsCache(); - - return res.json({ - ...newDrawing, - collectionId: toPublicTrashCollectionId(newDrawing.collectionId, req.user.id), - elements: parseJsonField(newDrawing.elements, []), - appState: parseJsonField(newDrawing.appState, {}), - files: parseJsonField(newDrawing.files, {}), - }); - })); - - app.get("/collections", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - const trashCollectionId = getUserTrashCollectionId(req.user.id); - await ensureTrashCollection(prisma, req.user.id); - - const rawCollections = await prisma.collection.findMany({ - where: { userId: req.user.id }, - orderBy: { createdAt: "desc" }, - }); - const hasInternalTrash = rawCollections.some((collection) => collection.id === trashCollectionId); - const collections = rawCollections - .filter((collection) => !(hasInternalTrash && collection.id === "trash")) - .map((collection) => - collection.id === trashCollectionId - ? { ...collection, id: "trash", name: "Trash" } - : collection - ); - return res.json(collections); - })); - - app.post("/collections", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const parsed = collectionNameSchema.safeParse(req.body.name); - if (!parsed.success) { - return res.status(400).json({ - error: "Validation error", - message: "Collection name must be between 1 and 100 characters", - }); - } - - const sanitizedName = sanitizeText(parsed.data, 100); - const newCollection = await prisma.collection.create({ - data: { name: sanitizedName, userId: req.user.id }, - }); - return res.json(newCollection); - })); - - app.put("/collections/:id", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const { id } = req.params; - if (isTrashCollectionId(id, req.user.id)) { - return res.status(400).json({ - error: "Validation error", - message: "Trash collection cannot be renamed", - }); - } - const existingCollection = await prisma.collection.findFirst({ - where: { id, userId: req.user.id }, - }); - if (!existingCollection) return res.status(404).json({ error: "Collection not found" }); - - const parsed = collectionNameSchema.safeParse(req.body.name); - if (!parsed.success) { - return res.status(400).json({ - error: "Validation error", - message: "Collection name must be between 1 and 100 characters", - }); - } - - const sanitizedName = sanitizeText(parsed.data, 100); - const updateResult = await prisma.collection.updateMany({ - where: { id, userId: req.user.id }, - data: { name: sanitizedName }, - }); - if (updateResult.count === 0) { - return res.status(404).json({ error: "Collection not found" }); - } - const updatedCollection = await prisma.collection.findFirst({ - where: { id, userId: req.user.id }, - }); - if (!updatedCollection) { - return res.status(404).json({ error: "Collection not found" }); - } - return res.json(updatedCollection); - })); - - app.delete("/collections/:id", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const { id } = req.params; - if (isTrashCollectionId(id, req.user.id)) { - return res.status(400).json({ - error: "Validation error", - message: "Trash collection cannot be deleted", - }); - } - const collection = await prisma.collection.findFirst({ - where: { id, userId: req.user.id }, - }); - if (!collection) return res.status(404).json({ error: "Collection not found" }); - - await prisma.$transaction([ - prisma.drawing.updateMany({ - where: { collectionId: id, userId: req.user.id }, - data: { collectionId: null }, - }), - prisma.collection.deleteMany({ where: { id, userId: req.user.id } }), - ]); - invalidateDrawingsCache(); - - if (config.enableAuditLogging) { - await logAuditEvent({ - userId: req.user.id, - action: "collection_deleted", - resource: `collection:${id}`, - ipAddress: req.ip || req.connection.remoteAddress || undefined, - userAgent: req.headers["user-agent"] || undefined, - details: { collectionId: id, collectionName: collection.name }, - }); - } - - return res.json({ success: true }); - })); - - app.get("/library", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const libraryId = `user_${req.user.id}`; - const library = await prisma.library.findUnique({ where: { id: libraryId } }); - if (!library) return res.json({ items: [] }); - - return res.json({ items: parseJsonField(library.items, []) }); - })); - - app.put("/library", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - - const { items } = req.body; - if (!Array.isArray(items)) { - return res.status(400).json({ error: "Items must be an array" }); - } - - const libraryId = `user_${req.user.id}`; - const library = await prisma.library.upsert({ - where: { id: libraryId }, - update: { items: JSON.stringify(items) }, - create: { id: libraryId, items: JSON.stringify(items) }, - }); - - return res.json({ items: parseJsonField(library.items, []) }); - })); -}; +export { registerDashboardRoutes } from "./dashboard/index"; +export type { DashboardRouteDeps } from "./dashboard/index"; diff --git a/backend/src/routes/dashboard/collections.ts b/backend/src/routes/dashboard/collections.ts new file mode 100644 index 0000000..b812aa0 --- /dev/null +++ b/backend/src/routes/dashboard/collections.ts @@ -0,0 +1,136 @@ +import express from "express"; +import { DashboardRouteDeps } from "./types"; +import { getUserTrashCollectionId, isTrashCollectionId } from "./trash"; + +export const registerCollectionRoutes = ( + app: express.Express, + deps: DashboardRouteDeps +) => { + const { + prisma, + requireAuth, + asyncHandler, + collectionNameSchema, + sanitizeText, + ensureTrashCollection, + invalidateDrawingsCache, + config, + logAuditEvent, + } = deps; + + app.get("/collections", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + const trashCollectionId = getUserTrashCollectionId(req.user.id); + await ensureTrashCollection(prisma, req.user.id); + + const rawCollections = await prisma.collection.findMany({ + where: { userId: req.user.id }, + orderBy: { createdAt: "desc" }, + }); + const hasInternalTrash = rawCollections.some((collection) => collection.id === trashCollectionId); + const collections = rawCollections + .filter((collection) => !(hasInternalTrash && collection.id === "trash")) + .map((collection) => + collection.id === trashCollectionId + ? { ...collection, id: "trash", name: "Trash" } + : collection + ); + return res.json(collections); + })); + + app.post("/collections", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const parsed = collectionNameSchema.safeParse(req.body.name); + if (!parsed.success) { + return res.status(400).json({ + error: "Validation error", + message: "Collection name must be between 1 and 100 characters", + }); + } + + const sanitizedName = sanitizeText(parsed.data, 100); + const newCollection = await prisma.collection.create({ + data: { name: sanitizedName, userId: req.user.id }, + }); + return res.json(newCollection); + })); + + app.put("/collections/:id", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const { id } = req.params; + if (isTrashCollectionId(id, req.user.id)) { + return res.status(400).json({ + error: "Validation error", + message: "Trash collection cannot be renamed", + }); + } + const existingCollection = await prisma.collection.findFirst({ + where: { id, userId: req.user.id }, + }); + if (!existingCollection) return res.status(404).json({ error: "Collection not found" }); + + const parsed = collectionNameSchema.safeParse(req.body.name); + if (!parsed.success) { + return res.status(400).json({ + error: "Validation error", + message: "Collection name must be between 1 and 100 characters", + }); + } + + const sanitizedName = sanitizeText(parsed.data, 100); + const updateResult = await prisma.collection.updateMany({ + where: { id, userId: req.user.id }, + data: { name: sanitizedName }, + }); + if (updateResult.count === 0) { + return res.status(404).json({ error: "Collection not found" }); + } + const updatedCollection = await prisma.collection.findFirst({ + where: { id, userId: req.user.id }, + }); + if (!updatedCollection) { + return res.status(404).json({ error: "Collection not found" }); + } + return res.json(updatedCollection); + })); + + app.delete("/collections/:id", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const { id } = req.params; + if (isTrashCollectionId(id, req.user.id)) { + return res.status(400).json({ + error: "Validation error", + message: "Trash collection cannot be deleted", + }); + } + const collection = await prisma.collection.findFirst({ + where: { id, userId: req.user.id }, + }); + if (!collection) return res.status(404).json({ error: "Collection not found" }); + + await prisma.$transaction([ + prisma.drawing.updateMany({ + where: { collectionId: id, userId: req.user.id }, + data: { collectionId: null }, + }), + prisma.collection.deleteMany({ where: { id, userId: req.user.id } }), + ]); + invalidateDrawingsCache(); + + if (config.enableAuditLogging) { + await logAuditEvent({ + userId: req.user.id, + action: "collection_deleted", + resource: `collection:${id}`, + ipAddress: req.ip || req.connection.remoteAddress || undefined, + userAgent: req.headers["user-agent"] || undefined, + details: { collectionId: id, collectionName: collection.name }, + }); + } + + return res.json({ success: true }); + })); +}; diff --git a/backend/src/routes/dashboard/drawings.ts b/backend/src/routes/dashboard/drawings.ts new file mode 100644 index 0000000..600611e --- /dev/null +++ b/backend/src/routes/dashboard/drawings.ts @@ -0,0 +1,415 @@ +import express from "express"; +import { Prisma } from "../../generated/client"; +import { DashboardRouteDeps, SortDirection, SortField } from "./types"; +import { + getUserTrashCollectionId, + isTrashCollectionId, + toInternalTrashCollectionId, + toPublicTrashCollectionId, +} from "./trash"; + +export const registerDrawingRoutes = ( + app: express.Express, + deps: DashboardRouteDeps +) => { + const { + prisma, + requireAuth, + asyncHandler, + parseJsonField, + validateImportedDrawing, + drawingCreateSchema, + drawingUpdateSchema, + respondWithValidationErrors, + ensureTrashCollection, + invalidateDrawingsCache, + buildDrawingsCacheKey, + getCachedDrawingsBody, + cacheDrawingsResponse, + MAX_PAGE_SIZE, + config, + logAuditEvent, + } = deps; + + app.get("/drawings", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) { + return res.status(401).json({ error: "Unauthorized" }); + } + + const trashCollectionId = getUserTrashCollectionId(req.user.id); + const { search, collectionId, includeData, limit, offset, sortField, sortDirection } = req.query; + const where: Prisma.DrawingWhereInput = { userId: req.user.id }; + const searchTerm = + typeof search === "string" && search.trim().length > 0 ? search.trim() : undefined; + + if (searchTerm) { + where.name = { contains: searchTerm }; + } + + let collectionFilterKey = "default"; + if (collectionId === "null") { + where.collectionId = null; + collectionFilterKey = "null"; + } else if (collectionId) { + const normalizedCollectionId = String(collectionId); + if (normalizedCollectionId === "trash") { + where.collectionId = { in: [trashCollectionId, "trash"] }; + collectionFilterKey = "trash"; + } else { + const collection = await prisma.collection.findFirst({ + where: { id: normalizedCollectionId, userId: req.user.id }, + }); + if (!collection) { + return res.status(404).json({ error: "Collection not found" }); + } + where.collectionId = normalizedCollectionId; + collectionFilterKey = `id:${normalizedCollectionId}`; + } + } else { + where.OR = [ + { collectionId: { notIn: [trashCollectionId, "trash"] } }, + { collectionId: null }, + ]; + } + + const shouldIncludeData = + typeof includeData === "string" + ? includeData.toLowerCase() === "true" || includeData === "1" + : false; + const parsedSortField: SortField = + sortField === "name" || sortField === "createdAt" || sortField === "updatedAt" + ? sortField + : "updatedAt"; + const parsedSortDirection: SortDirection = + sortDirection === "asc" || sortDirection === "desc" + ? sortDirection + : parsedSortField === "name" + ? "asc" + : "desc"; + + const rawLimit = limit ? Number.parseInt(limit as string, 10) : undefined; + const rawOffset = offset ? Number.parseInt(offset as string, 10) : undefined; + const parsedLimit = + rawLimit !== undefined && Number.isFinite(rawLimit) + ? Math.min(Math.max(rawLimit, 1), MAX_PAGE_SIZE) + : undefined; + const parsedOffset = + rawOffset !== undefined && Number.isFinite(rawOffset) ? Math.max(rawOffset, 0) : undefined; + + const cacheKey = + buildDrawingsCacheKey({ + userId: req.user.id, + searchTerm: searchTerm ?? "", + collectionFilter: collectionFilterKey, + includeData: shouldIncludeData, + sortField: parsedSortField, + sortDirection: parsedSortDirection, + }) + `:${parsedLimit}:${parsedOffset}`; + + const cachedBody = getCachedDrawingsBody(cacheKey); + if (cachedBody) { + res.setHeader("X-Cache", "HIT"); + res.setHeader("Content-Type", "application/json"); + return res.send(cachedBody); + } + + const summarySelect: Prisma.DrawingSelect = { + id: true, + name: true, + collectionId: true, + preview: true, + version: true, + createdAt: true, + updatedAt: true, + }; + + const orderBy: Prisma.DrawingOrderByWithRelationInput = + parsedSortField === "name" + ? { name: parsedSortDirection } + : parsedSortField === "createdAt" + ? { createdAt: parsedSortDirection } + : { updatedAt: parsedSortDirection }; + + const queryOptions: Prisma.DrawingFindManyArgs = { where, orderBy }; + if (parsedLimit !== undefined) queryOptions.take = parsedLimit; + if (parsedOffset !== undefined) queryOptions.skip = parsedOffset; + if (!shouldIncludeData) queryOptions.select = summarySelect; + + const [drawings, totalCount] = await Promise.all([ + prisma.drawing.findMany(queryOptions), + prisma.drawing.count({ where }), + ]); + + let responsePayload: any[] = drawings as any[]; + if (shouldIncludeData) { + responsePayload = (drawings as any[]).map((d: any) => ({ + ...d, + collectionId: toPublicTrashCollectionId(d.collectionId, req.user!.id), + elements: parseJsonField(d.elements, []), + appState: parseJsonField(d.appState, {}), + files: parseJsonField(d.files, {}), + })); + } else { + responsePayload = (drawings as any[]).map((d: any) => ({ + ...d, + collectionId: toPublicTrashCollectionId(d.collectionId, req.user!.id), + })); + } + + const finalResponse = { + drawings: responsePayload, + totalCount, + limit: parsedLimit, + offset: parsedOffset, + }; + + const body = cacheDrawingsResponse(cacheKey, finalResponse); + res.setHeader("X-Cache", "MISS"); + res.setHeader("Content-Type", "application/json"); + return res.send(body); + })); + + app.get("/drawings/:id", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const { id } = req.params; + const drawing = await prisma.drawing.findFirst({ + where: { + id, + userId: req.user.id, + }, + }); + if (!drawing) { + return res.status(404).json({ error: "Drawing not found", message: "Drawing does not exist" }); + } + + return res.json({ + ...drawing, + collectionId: toPublicTrashCollectionId(drawing.collectionId, req.user.id), + elements: parseJsonField(drawing.elements, []), + appState: parseJsonField(drawing.appState, {}), + files: parseJsonField(drawing.files, {}), + }); + })); + + app.post("/drawings", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const isImportedDrawing = req.headers["x-imported-file"] === "true"; + if (isImportedDrawing && !validateImportedDrawing(req.body)) { + return res.status(400).json({ + error: "Invalid imported drawing file", + message: "The imported file contains potentially malicious content or invalid structure", + }); + } + + const parsed = drawingCreateSchema.safeParse(req.body); + if (!parsed.success) { + return respondWithValidationErrors(res, parsed.error.issues); + } + + const payload = parsed.data as { + name?: string; + collectionId?: string | null; + elements: unknown[]; + appState: Record; + preview?: string | null; + files?: Record; + }; + const drawingName = payload.name ?? "Untitled Drawing"; + const targetCollectionIdRaw = payload.collectionId === undefined ? null : payload.collectionId; + const targetCollectionId = + toInternalTrashCollectionId(targetCollectionIdRaw, req.user.id) ?? null; + + if (targetCollectionId && !isTrashCollectionId(targetCollectionId, req.user.id)) { + const collection = await prisma.collection.findFirst({ + where: { id: targetCollectionId, userId: req.user.id }, + }); + if (!collection) return res.status(404).json({ error: "Collection not found" }); + } else if (targetCollectionIdRaw === "trash") { + await ensureTrashCollection(prisma, req.user.id); + } + + const newDrawing = await prisma.drawing.create({ + data: { + name: drawingName, + elements: JSON.stringify(payload.elements), + appState: JSON.stringify(payload.appState), + userId: req.user.id, + collectionId: targetCollectionId, + preview: payload.preview ?? null, + files: JSON.stringify(payload.files ?? {}), + }, + }); + invalidateDrawingsCache(); + + return res.json({ + ...newDrawing, + collectionId: toPublicTrashCollectionId(newDrawing.collectionId, req.user.id), + elements: parseJsonField(newDrawing.elements, []), + appState: parseJsonField(newDrawing.appState, {}), + files: parseJsonField(newDrawing.files, {}), + }); + })); + + app.put("/drawings/:id", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const { id } = req.params; + const existingDrawing = await prisma.drawing.findFirst({ + where: { id, userId: req.user.id }, + }); + if (!existingDrawing) return res.status(404).json({ error: "Drawing not found" }); + + const parsed = drawingUpdateSchema.safeParse(req.body); + if (!parsed.success) { + if (config.nodeEnv === "development") { + console.error("[API] Validation failed", { id, errors: parsed.error.issues }); + } + return respondWithValidationErrors(res, parsed.error.issues); + } + + const payload = parsed.data as { + name?: string; + collectionId?: string | null; + elements?: unknown[]; + appState?: Record; + preview?: string | null; + files?: Record; + version?: number; + }; + const trashCollectionId = getUserTrashCollectionId(req.user.id); + const isSceneUpdate = + payload.elements !== undefined || + payload.appState !== undefined || + payload.files !== undefined; + const data: Prisma.DrawingUpdateInput = isSceneUpdate + ? { version: { increment: 1 } } + : {}; + + if (payload.name !== undefined) data.name = payload.name; + if (payload.elements !== undefined) data.elements = JSON.stringify(payload.elements); + if (payload.appState !== undefined) data.appState = JSON.stringify(payload.appState); + if (payload.files !== undefined) data.files = JSON.stringify(payload.files); + if (payload.preview !== undefined) data.preview = payload.preview; + + if (payload.collectionId !== undefined) { + if (payload.collectionId === "trash") { + await ensureTrashCollection(prisma, req.user.id); + (data as Prisma.DrawingUncheckedUpdateInput).collectionId = trashCollectionId; + } else if (payload.collectionId) { + const collection = await prisma.collection.findFirst({ + where: { id: payload.collectionId, userId: req.user.id }, + }); + if (!collection) return res.status(404).json({ error: "Collection not found" }); + (data as Prisma.DrawingUncheckedUpdateInput).collectionId = payload.collectionId; + } else { + (data as Prisma.DrawingUncheckedUpdateInput).collectionId = null; + } + } + + const updateWhere: Prisma.DrawingWhereInput = { id, userId: req.user.id }; + if (isSceneUpdate && payload.version !== undefined) { + updateWhere.version = payload.version; + } + + const updateResult = await prisma.drawing.updateMany({ + where: updateWhere, + data, + }); + if (updateResult.count === 0) { + if (isSceneUpdate && payload.version !== undefined) { + const latestDrawing = await prisma.drawing.findFirst({ + where: { id, userId: req.user.id }, + select: { version: true }, + }); + return res.status(409).json({ + error: "Conflict", + code: "VERSION_CONFLICT", + message: "Drawing has changed since this editor state was loaded.", + currentVersion: latestDrawing?.version ?? null, + }); + } + return res.status(404).json({ error: "Drawing not found" }); + } + + const updatedDrawing = await prisma.drawing.findFirst({ + where: { id, userId: req.user.id }, + }); + if (!updatedDrawing) { + return res.status(404).json({ error: "Drawing not found" }); + } + invalidateDrawingsCache(); + + return res.json({ + ...updatedDrawing, + collectionId: toPublicTrashCollectionId(updatedDrawing.collectionId, req.user.id), + elements: parseJsonField(updatedDrawing.elements, []), + appState: parseJsonField(updatedDrawing.appState, {}), + files: parseJsonField(updatedDrawing.files, {}), + }); + })); + + app.delete("/drawings/:id", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + const { id } = req.params; + + const drawing = await prisma.drawing.findFirst({ where: { id, userId: req.user.id } }); + if (!drawing) return res.status(404).json({ error: "Drawing not found" }); + + const deleteResult = await prisma.drawing.deleteMany({ + where: { id, userId: req.user.id }, + }); + if (deleteResult.count === 0) { + return res.status(404).json({ error: "Drawing not found" }); + } + invalidateDrawingsCache(); + + if (config.enableAuditLogging) { + await logAuditEvent({ + userId: req.user.id, + action: "drawing_deleted", + resource: `drawing:${id}`, + ipAddress: req.ip || req.connection.remoteAddress || undefined, + userAgent: req.headers["user-agent"] || undefined, + details: { drawingId: id, drawingName: drawing.name }, + }); + } + + return res.json({ success: true }); + })); + + app.post("/drawings/:id/duplicate", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const { id } = req.params; + const original = await prisma.drawing.findFirst({ where: { id, userId: req.user.id } }); + if (!original) return res.status(404).json({ error: "Original drawing not found" }); + let duplicatedCollectionId = original.collectionId; + if (isTrashCollectionId(original.collectionId, req.user.id)) { + await ensureTrashCollection(prisma, req.user.id); + duplicatedCollectionId = getUserTrashCollectionId(req.user.id); + } + + const newDrawing = await prisma.drawing.create({ + data: { + name: `${original.name} (Copy)`, + elements: original.elements, + appState: original.appState, + files: original.files, + userId: req.user.id, + collectionId: duplicatedCollectionId, + version: 1, + }, + }); + invalidateDrawingsCache(); + + return res.json({ + ...newDrawing, + collectionId: toPublicTrashCollectionId(newDrawing.collectionId, req.user.id), + elements: parseJsonField(newDrawing.elements, []), + appState: parseJsonField(newDrawing.appState, {}), + files: parseJsonField(newDrawing.files, {}), + }); + })); +}; diff --git a/backend/src/routes/dashboard/index.ts b/backend/src/routes/dashboard/index.ts new file mode 100644 index 0000000..82888cf --- /dev/null +++ b/backend/src/routes/dashboard/index.ts @@ -0,0 +1,16 @@ +import express from "express"; +import { registerCollectionRoutes } from "./collections"; +import { registerDrawingRoutes } from "./drawings"; +import { registerLibraryRoutes } from "./library"; +import { DashboardRouteDeps } from "./types"; + +export const registerDashboardRoutes = ( + app: express.Express, + deps: DashboardRouteDeps +) => { + registerDrawingRoutes(app, deps); + registerCollectionRoutes(app, deps); + registerLibraryRoutes(app, deps); +}; + +export type { DashboardRouteDeps } from "./types"; diff --git a/backend/src/routes/dashboard/library.ts b/backend/src/routes/dashboard/library.ts new file mode 100644 index 0000000..d159c55 --- /dev/null +++ b/backend/src/routes/dashboard/library.ts @@ -0,0 +1,37 @@ +import express from "express"; +import { DashboardRouteDeps } from "./types"; + +export const registerLibraryRoutes = ( + app: express.Express, + deps: DashboardRouteDeps +) => { + const { prisma, requireAuth, asyncHandler, parseJsonField } = deps; + + app.get("/library", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const libraryId = `user_${req.user.id}`; + const library = await prisma.library.findUnique({ where: { id: libraryId } }); + if (!library) return res.json({ items: [] }); + + return res.json({ items: parseJsonField(library.items, []) }); + })); + + app.put("/library", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + + const { items } = req.body; + if (!Array.isArray(items)) { + return res.status(400).json({ error: "Items must be an array" }); + } + + const libraryId = `user_${req.user.id}`; + const library = await prisma.library.upsert({ + where: { id: libraryId }, + update: { items: JSON.stringify(items) }, + create: { id: libraryId, items: JSON.stringify(items) }, + }); + + return res.json({ items: parseJsonField(library.items, []) }); + })); +}; diff --git a/backend/src/routes/dashboard/trash.ts b/backend/src/routes/dashboard/trash.ts new file mode 100644 index 0000000..1a00814 --- /dev/null +++ b/backend/src/routes/dashboard/trash.ts @@ -0,0 +1,20 @@ +export const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`; + +export const isTrashCollectionId = ( + collectionId: string | null | undefined, + userId: string +): boolean => + Boolean(collectionId) && + (collectionId === "trash" || collectionId === getUserTrashCollectionId(userId)); + +export const toInternalTrashCollectionId = ( + collectionId: string | null | undefined, + userId: string +): string | null | undefined => + collectionId === "trash" ? getUserTrashCollectionId(userId) : collectionId; + +export const toPublicTrashCollectionId = ( + collectionId: string | null | undefined, + userId: string +): string | null | undefined => + isTrashCollectionId(collectionId, userId) ? "trash" : collectionId; diff --git a/backend/src/routes/dashboard/types.ts b/backend/src/routes/dashboard/types.ts new file mode 100644 index 0000000..920af0f --- /dev/null +++ b/backend/src/routes/dashboard/types.ts @@ -0,0 +1,55 @@ +import express from "express"; +import { z } from "zod"; +import { Prisma, PrismaClient } from "../../generated/client"; + +export type SortField = "name" | "createdAt" | "updatedAt"; +export type SortDirection = "asc" | "desc"; + +type BuildDrawingsCacheKey = (keyParts: { + userId: string; + searchTerm: string; + collectionFilter: string; + includeData: boolean; + sortField: SortField; + sortDirection: SortDirection; +}) => string; + +type EnsureTrashCollection = ( + db: Prisma.TransactionClient | PrismaClient, + userId: string +) => Promise; + +type LogAuditEvent = (params: { + userId: string; + action: string; + resource?: string; + ipAddress?: string; + userAgent?: string; + details?: Record; +}) => Promise; + +export type DashboardRouteDeps = { + prisma: PrismaClient; + requireAuth: express.RequestHandler; + asyncHandler: ( + fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise + ) => express.RequestHandler; + parseJsonField: (rawValue: string | null | undefined, fallback: T) => T; + sanitizeText: (input: unknown, maxLength?: number) => string; + validateImportedDrawing: (data: unknown) => boolean; + drawingCreateSchema: z.ZodTypeAny; + drawingUpdateSchema: z.ZodTypeAny; + respondWithValidationErrors: (res: express.Response, issues: z.ZodIssue[]) => void; + collectionNameSchema: z.ZodTypeAny; + ensureTrashCollection: EnsureTrashCollection; + invalidateDrawingsCache: () => void; + buildDrawingsCacheKey: BuildDrawingsCacheKey; + getCachedDrawingsBody: (key: string) => Buffer | null; + cacheDrawingsResponse: (key: string, payload: unknown) => Buffer; + MAX_PAGE_SIZE: number; + config: { + nodeEnv: string; + enableAuditLogging: boolean; + }; + logAuditEvent: LogAuditEvent; +}; diff --git a/backend/src/routes/importExport.ts b/backend/src/routes/importExport.ts index 4206703..6831f8f 100644 --- a/backend/src/routes/importExport.ts +++ b/backend/src/routes/importExport.ts @@ -1,1196 +1,2 @@ -import express from "express"; -import path from "path"; -import { promises as fsPromises } from "fs"; -import archiver from "archiver"; -import JSZip from "jszip"; -import { z } from "zod"; -import { v4 as uuidv4 } from "uuid"; -import { Prisma, PrismaClient } from "../generated/client"; -import { sanitizeDrawingData } from "../security"; - -class ImportValidationError extends Error { - status: number; - - constructor(message: string, status = 400) { - super(message); - this.name = "ImportValidationError"; - this.status = status; - } -} - -const excalidashManifestSchemaV1 = z.object({ - format: z.literal("excalidash"), - formatVersion: z.literal(1), - exportedAt: z.string().min(1), - excalidashBackendVersion: z.string().optional(), - userId: z.string().optional(), - unorganizedFolder: z.string().min(1), - collections: z.array( - z.object({ - id: z.string().min(1), - name: z.string(), - folder: z.string().min(1), - createdAt: z.string().optional(), - updatedAt: z.string().optional(), - }) - ), - drawings: z.array( - z.object({ - id: z.string().min(1), - name: z.string(), - filePath: z.string().min(1), - collectionId: z.string().nullable(), - version: z.number().int().optional(), - createdAt: z.string().optional(), - updatedAt: z.string().optional(), - }) - ), -}); - -type RegisterImportExportDeps = { - app: express.Express; - prisma: PrismaClient; - requireAuth: express.RequestHandler; - asyncHandler: ( - fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise - ) => express.RequestHandler; - upload: any; - uploadDir: string; - backendRoot: string; - getBackendVersion: () => string; - parseJsonField: (rawValue: string | null | undefined, fallback: T) => T; - sanitizeText: (input: unknown, maxLength?: number) => string; - validateImportedDrawing: (data: unknown) => boolean; - ensureTrashCollection: ( - db: Prisma.TransactionClient | PrismaClient, - userId: string - ) => Promise; - invalidateDrawingsCache: () => void; - removeFileIfExists: (filePath?: string) => Promise; - verifyDatabaseIntegrityAsync: (filePath: string) => Promise; - MAX_IMPORT_ARCHIVE_ENTRIES: number; - MAX_IMPORT_COLLECTIONS: number; - MAX_IMPORT_DRAWINGS: number; - MAX_IMPORT_MANIFEST_BYTES: number; - MAX_IMPORT_DRAWING_BYTES: number; - MAX_IMPORT_TOTAL_EXTRACTED_BYTES: number; -}; - -const getZipEntries = (zip: JSZip) => Object.values(zip.files).filter((entry) => !entry.dir); - -const normalizeArchivePath = (filePath: string): string => - path.posix.normalize(filePath.replace(/\\/g, "/")); - -const assertSafeArchivePath = (filePath: string) => { - const normalized = normalizeArchivePath(filePath); - if ( - normalized.length === 0 || - path.posix.isAbsolute(normalized) || - normalized === ".." || - normalized.startsWith("../") || - normalized.includes("\0") - ) { - throw new ImportValidationError(`Unsafe archive path: ${filePath}`); - } -}; - -const assertSafeZipArchive = (zip: JSZip, maxEntries: number) => { - const entries = getZipEntries(zip); - if (entries.length > maxEntries) { - throw new ImportValidationError("Archive contains too many files"); - } - for (const entry of entries) { - assertSafeArchivePath(entry.name); - } -}; - -const getSafeZipEntry = (zip: JSZip, filePath: string) => { - const normalizedPath = normalizeArchivePath(filePath); - assertSafeArchivePath(normalizedPath); - return zip.file(normalizedPath); -}; - -const sanitizePathSegment = (input: string, fallback: string): string => { - const value = typeof input === "string" ? input.trim() : ""; - const cleaned = value - .replace(/[<>:"/\\|?*\x00-\x1F]/g, "_") - .replace(/\s+/g, " ") - .slice(0, 120) - .trim(); - return cleaned.length > 0 ? cleaned : fallback; -}; - -const makeUniqueName = (base: string, used: Set): string => { - let candidate = base; - let n = 2; - while (used.has(candidate)) { - candidate = `${base}__${n}`; - n += 1; - } - used.add(candidate); - return candidate; -}; - -const findFirstDuplicate = (values: string[]): string | null => { - const seen = new Set(); - for (const value of values) { - if (seen.has(value)) return value; - seen.add(value); - } - return null; -}; - -const normalizeNonEmptyId = (value: unknown): string | null => { - if (typeof value !== "string") return null; - const trimmed = value.trim(); - return trimmed.length > 0 ? trimmed : null; -}; - -const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`; - -const isTrashCollectionId = ( - collectionId: string | null | undefined, - userId: string -): boolean => - Boolean(collectionId) && - (collectionId === "trash" || collectionId === getUserTrashCollectionId(userId)); - -const toPublicTrashCollectionId = ( - collectionId: string | null | undefined, - userId: string -): string | null => - isTrashCollectionId(collectionId, userId) ? "trash" : collectionId ?? null; - -const findSqliteTable = (tables: string[], candidates: string[]): string | null => { - const byLower = new Map(tables.map((t) => [t.toLowerCase(), t])); - for (const candidate of candidates) { - const found = byLower.get(candidate.toLowerCase()); - if (found) return found; - } - return null; -}; - -const parseOptionalJson = (raw: unknown, fallback: T): T => { - if (typeof raw === "string") { - try { - return JSON.parse(raw) as T; - } catch { - return fallback; - } - } - if (typeof raw === "object" && raw !== null) { - return raw as T; - } - return fallback; -}; - -const isPathInsideDirectory = (candidatePath: string, rootDir: string): boolean => { - const relativePath = path.relative(rootDir, candidatePath); - return ( - relativePath === "" || - (!relativePath.startsWith("..") && !path.isAbsolute(relativePath)) - ); -}; - -const isSafeMulterTempFilename = (value: string): boolean => - /^[a-f0-9]{32}$/.test(value); - -const resolveSafeUploadedFilePath = async ( - fileMeta: { filename?: unknown }, - uploadRoot: string -): Promise => { - const absoluteUploadRoot = path.resolve(uploadRoot); - let canonicalUploadRoot = absoluteUploadRoot; - - try { - canonicalUploadRoot = await fsPromises.realpath(absoluteUploadRoot); - } catch { - throw new ImportValidationError("Invalid upload path"); - } - - const filename = typeof fileMeta.filename === "string" ? fileMeta.filename : ""; - if (!isSafeMulterTempFilename(filename)) { - throw new ImportValidationError("Invalid upload path"); - } - - const joinedPath = path.resolve(canonicalUploadRoot, filename); - if (!isPathInsideDirectory(joinedPath, canonicalUploadRoot)) { - throw new ImportValidationError("Invalid upload path"); - } - - return joinedPath; -}; - -const openReadonlySqliteDb = (filePath: string): any => { - try { - // eslint-disable-next-line @typescript-eslint/no-var-requires - const { DatabaseSync } = require("node:sqlite") as any; - return new DatabaseSync(filePath, { - readOnly: true, - enableForeignKeyConstraints: false, - }); - } catch { - // eslint-disable-next-line @typescript-eslint/no-var-requires - const Database = require("better-sqlite3") as any; - return new Database(filePath, { readonly: true, fileMustExist: true }); - } -}; - -const getCurrentLatestPrismaMigrationName = async (backendRoot: string): Promise => { - try { - const migrationsDir = path.resolve(backendRoot, "prisma/migrations"); - const entries = await fsPromises.readdir(migrationsDir, { withFileTypes: true }); - const dirs = entries - .filter((e) => e.isDirectory()) - .map((e) => e.name) - .filter((name) => !name.startsWith(".")); - if (dirs.length === 0) return null; - dirs.sort(); - return dirs[dirs.length - 1] || null; - } catch { - return null; - } -}; - -export const registerImportExportRoutes = (deps: RegisterImportExportDeps) => { - const { - app, - prisma, - requireAuth, - asyncHandler, - upload, - uploadDir, - backendRoot, - getBackendVersion, - parseJsonField, - sanitizeText, - validateImportedDrawing, - ensureTrashCollection, - invalidateDrawingsCache, - removeFileIfExists, - verifyDatabaseIntegrityAsync, - MAX_IMPORT_ARCHIVE_ENTRIES, - MAX_IMPORT_COLLECTIONS, - MAX_IMPORT_DRAWINGS, - MAX_IMPORT_MANIFEST_BYTES, - MAX_IMPORT_DRAWING_BYTES, - MAX_IMPORT_TOTAL_EXTRACTED_BYTES, - } = deps; - - app.get("/export/excalidash", requireAuth, asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - const trashCollectionId = getUserTrashCollectionId(req.user.id); - - const extParam = typeof req.query.ext === "string" ? req.query.ext.toLowerCase() : ""; - const zipSuffix = extParam === "zip"; - const date = new Date().toISOString().split("T")[0]; - const filename = zipSuffix - ? `excalidash-backup-${date}.excalidash.zip` - : `excalidash-backup-${date}.excalidash`; - - const exportedAt = new Date().toISOString(); - const drawings = await prisma.drawing.findMany({ - where: { userId: req.user.id }, - include: { collection: true }, - }); - const userCollections = await prisma.collection.findMany({ - where: { userId: req.user.id }, - }); - - const hasInternalTrashCollection = userCollections.some((collection) => collection.id === trashCollectionId); - const normalizedUserCollections = userCollections.filter( - (collection) => !(hasInternalTrashCollection && collection.id === "trash") - ); - const hasTrashDrawings = drawings.some((drawing) => - isTrashCollectionId(drawing.collectionId, req.user!.id) - ); - const collectionsToExport = [...normalizedUserCollections]; - if ( - hasTrashDrawings && - !collectionsToExport.some((collection) => - isTrashCollectionId(collection.id, req.user!.id) - ) - ) { - const trash = await prisma.collection.findFirst({ - where: { userId: req.user.id, id: { in: [trashCollectionId, "trash"] } }, - }); - if (trash) collectionsToExport.push(trash); - } - - const exportSource = `${req.protocol}://${req.get("host")}`; - const usedFolderNames = new Set(); - const unorganizedFolder = makeUniqueName("Unorganized", usedFolderNames); - const folderByCollectionId = new Map(); - for (const collection of collectionsToExport) { - const base = sanitizePathSegment(collection.name, "Collection"); - const folder = makeUniqueName(base, usedFolderNames); - folderByCollectionId.set(collection.id, folder); - } - - type DrawingWithCollection = Prisma.DrawingGetPayload<{ include: { collection: true } }>; - const drawingsManifest = drawings.map((drawing: DrawingWithCollection) => { - const folder = drawing.collectionId - ? folderByCollectionId.get(drawing.collectionId) || unorganizedFolder - : unorganizedFolder; - const fileNameBase = sanitizePathSegment(drawing.name, "Untitled"); - const fileName = `${fileNameBase}__${drawing.id.slice(0, 8)}.excalidraw`; - return { - id: drawing.id, - name: drawing.name, - filePath: `${folder}/${fileName}`, - collectionId: toPublicTrashCollectionId(drawing.collectionId, req.user!.id), - version: drawing.version, - createdAt: drawing.createdAt.toISOString(), - updatedAt: drawing.updatedAt.toISOString(), - }; - }); - - const manifestCollections = collectionsToExport - .map((collection) => ({ - id: toPublicTrashCollectionId(collection.id, req.user!.id) || collection.id, - name: isTrashCollectionId(collection.id, req.user!.id) ? "Trash" : collection.name, - folder: folderByCollectionId.get(collection.id) || sanitizePathSegment(collection.name, "Collection"), - createdAt: collection.createdAt.toISOString(), - updatedAt: collection.updatedAt.toISOString(), - })) - .filter((collection, index, all) => all.findIndex((c) => c.id === collection.id) === index); - - const manifest = { - format: "excalidash" as const, - formatVersion: 1 as const, - exportedAt, - excalidashBackendVersion: getBackendVersion(), - userId: req.user.id, - unorganizedFolder, - collections: manifestCollections, - drawings: drawingsManifest, - }; - - res.setHeader("Content-Type", "application/zip"); - res.setHeader("Content-Disposition", `attachment; filename="${filename}"`); - - const archive = archiver("zip", { zlib: { level: 9 } }); - archive.on("error", (err) => { - console.error("Archive error:", err); - res.status(500).json({ error: "Failed to create archive" }); - }); - archive.pipe(res); - - archive.append(JSON.stringify(manifest, null, 2), { name: "excalidash.manifest.json" }); - - const drawingsManifestById = new Map(drawingsManifest.map((d) => [d.id, d])); - for (const drawing of drawings) { - const meta = drawingsManifestById.get(drawing.id); - if (!meta) continue; - const drawingData = { - type: "excalidraw" as const, - version: 2 as const, - source: exportSource, - elements: parseJsonField(drawing.elements, [] as unknown[]), - appState: parseJsonField(drawing.appState, {} as Record), - files: parseJsonField(drawing.files, {} as Record), - excalidash: { - drawingId: drawing.id, - collectionId: drawing.collectionId ?? null, - exportedAt, - }, - }; - assertSafeArchivePath(meta.filePath); - archive.append(JSON.stringify(drawingData, null, 2), { name: meta.filePath }); - } - - const readme = `ExcaliDash Backup (.excalidash) - -This file is a zip archive containing a versioned ExcaliDash manifest and your drawings, -organized into folders by collection. - -Files: -- excalidash.manifest.json (required) -- /*.excalidraw - -ExportedAt: ${exportedAt} -FormatVersion: 1 -BackendVersion: ${getBackendVersion()} -Collections: ${collectionsToExport.length} -Drawings: ${drawings.length} -`; - archive.append(readme, { name: "README.txt" }); - await archive.finalize(); - })); - - app.post("/import/excalidash/verify", requireAuth, upload.single("archive"), asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - if (!req.file) return res.status(400).json({ error: "No file uploaded" }); - - let stagedPath: string; - try { - stagedPath = await resolveSafeUploadedFilePath( - { filename: req.file.filename }, - uploadDir - ); - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid upload", message: error.message }); - } - throw error; - } - try { - const buffer = await fsPromises.readFile(stagedPath); - const zip = await JSZip.loadAsync(buffer); - try { - assertSafeZipArchive(zip, MAX_IMPORT_ARCHIVE_ENTRIES); - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid backup", message: error.message }); - } - throw error; - } - - const manifestFile = getSafeZipEntry(zip, "excalidash.manifest.json"); - if (!manifestFile) { - return res.status(400).json({ error: "Invalid backup", message: "Missing excalidash.manifest.json" }); - } - const rawManifest = await manifestFile.async("string"); - if (Buffer.byteLength(rawManifest, "utf8") > MAX_IMPORT_MANIFEST_BYTES) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: "excalidash.manifest.json is too large", - }); - } - - let manifestJson: unknown; - try { - manifestJson = JSON.parse(rawManifest); - } catch { - return res.status(400).json({ - error: "Invalid backup manifest", - message: "excalidash.manifest.json is not valid JSON", - }); - } - const parsed = excalidashManifestSchemaV1.safeParse(manifestJson); - if (!parsed.success) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: "Malformed excalidash.manifest.json", - }); - } - const manifest = parsed.data; - if (manifest.collections.length > MAX_IMPORT_COLLECTIONS) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, - }); - } - if (manifest.drawings.length > MAX_IMPORT_DRAWINGS) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, - }); - } - - const duplicateCollectionId = findFirstDuplicate(manifest.collections.map((c) => c.id)); - if (duplicateCollectionId) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Duplicate collection id in manifest: ${duplicateCollectionId}`, - }); - } - const duplicateDrawingId = findFirstDuplicate(manifest.drawings.map((d) => d.id)); - if (duplicateDrawingId) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Duplicate drawing id in manifest: ${duplicateDrawingId}`, - }); - } - const duplicateDrawingPath = findFirstDuplicate(manifest.drawings.map((d) => d.filePath)); - if (duplicateDrawingPath) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Duplicate drawing file path in manifest: ${duplicateDrawingPath}`, - }); - } - for (const drawing of manifest.drawings) { - if (!getSafeZipEntry(zip, drawing.filePath)) { - return res.status(400).json({ - error: "Invalid backup", - message: `Missing drawing file: ${drawing.filePath}`, - }); - } - } - - return res.json({ - valid: true, - formatVersion: manifest.formatVersion, - exportedAt: manifest.exportedAt, - excalidashBackendVersion: manifest.excalidashBackendVersion || null, - collections: manifest.collections.length, - drawings: manifest.drawings.length, - }); - } finally { - await removeFileIfExists(stagedPath); - } - })); - - app.post("/import/excalidash", requireAuth, upload.single("archive"), asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - if (!req.file) return res.status(400).json({ error: "No file uploaded" }); - - let stagedPath: string; - try { - stagedPath = await resolveSafeUploadedFilePath( - { filename: req.file.filename }, - uploadDir - ); - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid upload", message: error.message }); - } - throw error; - } - try { - const buffer = await fsPromises.readFile(stagedPath); - const zip = await JSZip.loadAsync(buffer); - try { - assertSafeZipArchive(zip, MAX_IMPORT_ARCHIVE_ENTRIES); - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid backup", message: error.message }); - } - throw error; - } - - const manifestFile = getSafeZipEntry(zip, "excalidash.manifest.json"); - if (!manifestFile) { - return res.status(400).json({ error: "Invalid backup", message: "Missing excalidash.manifest.json" }); - } - const rawManifest = await manifestFile.async("string"); - if (Buffer.byteLength(rawManifest, "utf8") > MAX_IMPORT_MANIFEST_BYTES) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: "excalidash.manifest.json is too large", - }); - } - - let manifestJson: unknown; - try { - manifestJson = JSON.parse(rawManifest); - } catch { - return res.status(400).json({ - error: "Invalid backup manifest", - message: "excalidash.manifest.json is not valid JSON", - }); - } - const parsed = excalidashManifestSchemaV1.safeParse(manifestJson); - if (!parsed.success) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: "Malformed excalidash.manifest.json", - }); - } - const manifest = parsed.data; - - if (manifest.collections.length > MAX_IMPORT_COLLECTIONS) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, - }); - } - if (manifest.drawings.length > MAX_IMPORT_DRAWINGS) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, - }); - } - - const duplicateCollectionId = findFirstDuplicate(manifest.collections.map((c) => c.id)); - if (duplicateCollectionId) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Duplicate collection id in manifest: ${duplicateCollectionId}`, - }); - } - const duplicateDrawingId = findFirstDuplicate(manifest.drawings.map((d) => d.id)); - if (duplicateDrawingId) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Duplicate drawing id in manifest: ${duplicateDrawingId}`, - }); - } - const duplicateDrawingPath = findFirstDuplicate(manifest.drawings.map((d) => d.filePath)); - if (duplicateDrawingPath) { - return res.status(400).json({ - error: "Invalid backup manifest", - message: `Duplicate drawing file path in manifest: ${duplicateDrawingPath}`, - }); - } - - type PreparedImportDrawing = { - id: string; - name: string; - version: number | undefined; - collectionId: string | null; - sanitized: ReturnType; - }; - const preparedDrawings: PreparedImportDrawing[] = []; - let extractedBytes = Buffer.byteLength(rawManifest, "utf8"); - try { - for (const d of manifest.drawings) { - const entry = getSafeZipEntry(zip, d.filePath); - if (!entry) throw new ImportValidationError(`Missing drawing file: ${d.filePath}`); - - const raw = await entry.async("string"); - const rawSize = Buffer.byteLength(raw, "utf8"); - if (rawSize > MAX_IMPORT_DRAWING_BYTES) { - throw new ImportValidationError(`Drawing is too large: ${d.filePath}`); - } - extractedBytes += rawSize; - if (extractedBytes > MAX_IMPORT_TOTAL_EXTRACTED_BYTES) { - throw new ImportValidationError("Backup contents exceed maximum import size"); - } - - let parsedJson: any; - try { - parsedJson = JSON.parse(raw) as any; - } catch { - throw new ImportValidationError(`Drawing JSON is invalid: ${d.filePath}`); - } - - const imported = { - name: d.name, - elements: Array.isArray(parsedJson?.elements) ? parsedJson.elements : [], - appState: - typeof parsedJson?.appState === "object" && parsedJson.appState !== null - ? parsedJson.appState - : {}, - files: - typeof parsedJson?.files === "object" && parsedJson.files !== null - ? parsedJson.files - : {}, - preview: null as string | null, - collectionId: d.collectionId, - }; - - if (!validateImportedDrawing(imported)) { - throw new ImportValidationError(`Drawing failed validation: ${d.filePath}`); - } - - preparedDrawings.push({ - id: d.id, - name: sanitizeText(imported.name, 255) || "Untitled Drawing", - version: typeof d.version === "number" ? d.version : undefined, - collectionId: d.collectionId, - sanitized: sanitizeDrawingData(imported), - }); - } - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid backup", message: error.message }); - } - throw error; - } - - const result = await prisma.$transaction(async (tx) => { - const trashCollectionId = getUserTrashCollectionId(req.user!.id); - const collectionIdMap = new Map(); - let collectionsCreated = 0; - let collectionsUpdated = 0; - let collectionIdConflicts = 0; - let drawingsCreated = 0; - let drawingsUpdated = 0; - let drawingIdConflicts = 0; - - const needsTrash = - manifest.collections.some((c) => c.id === "trash") || - preparedDrawings.some((d) => d.collectionId === "trash"); - if (needsTrash) await ensureTrashCollection(tx, req.user!.id); - - for (const c of manifest.collections) { - if (c.id === "trash") { - collectionIdMap.set("trash", trashCollectionId); - continue; - } - - const existing = await tx.collection.findUnique({ where: { id: c.id } }); - if (!existing) { - await tx.collection.create({ - data: { id: c.id, name: sanitizeText(c.name, 100) || "Collection", userId: req.user!.id }, - }); - collectionIdMap.set(c.id, c.id); - collectionsCreated += 1; - continue; - } - - if (existing.userId === req.user!.id) { - await tx.collection.update({ - where: { id: c.id }, - data: { name: sanitizeText(c.name, 100) || "Collection" }, - }); - collectionIdMap.set(c.id, c.id); - collectionsUpdated += 1; - continue; - } - - const newId = uuidv4(); - await tx.collection.create({ - data: { id: newId, name: sanitizeText(c.name, 100) || "Collection", userId: req.user!.id }, - }); - collectionIdMap.set(c.id, newId); - collectionsCreated += 1; - collectionIdConflicts += 1; - } - - const resolveCollectionId = (collectionId: string | null): string | null => { - if (!collectionId) return null; - if (collectionId === "trash") return trashCollectionId; - return collectionIdMap.get(collectionId) || null; - }; - - for (const prepared of preparedDrawings) { - const targetCollectionId = resolveCollectionId(prepared.collectionId); - const existing = await tx.drawing.findUnique({ where: { id: prepared.id } }); - if (!existing) { - await tx.drawing.create({ - data: { - id: prepared.id, - name: prepared.name, - elements: JSON.stringify(prepared.sanitized.elements), - appState: JSON.stringify(prepared.sanitized.appState), - files: JSON.stringify(prepared.sanitized.files || {}), - preview: prepared.sanitized.preview ?? null, - version: prepared.version ?? 1, - userId: req.user!.id, - collectionId: targetCollectionId, - }, - }); - drawingsCreated += 1; - continue; - } - - if (existing.userId === req.user!.id) { - await tx.drawing.update({ - where: { id: prepared.id }, - data: { - name: prepared.name, - elements: JSON.stringify(prepared.sanitized.elements), - appState: JSON.stringify(prepared.sanitized.appState), - files: JSON.stringify(prepared.sanitized.files || {}), - preview: prepared.sanitized.preview ?? null, - version: prepared.version ?? existing.version, - collectionId: targetCollectionId, - }, - }); - drawingsUpdated += 1; - continue; - } - - const newId = uuidv4(); - await tx.drawing.create({ - data: { - id: newId, - name: prepared.name, - elements: JSON.stringify(prepared.sanitized.elements), - appState: JSON.stringify(prepared.sanitized.appState), - files: JSON.stringify(prepared.sanitized.files || {}), - preview: prepared.sanitized.preview ?? null, - version: prepared.version ?? 1, - userId: req.user!.id, - collectionId: targetCollectionId, - }, - }); - drawingsCreated += 1; - drawingIdConflicts += 1; - } - - return { - collections: { created: collectionsCreated, updated: collectionsUpdated, idConflicts: collectionIdConflicts }, - drawings: { created: drawingsCreated, updated: drawingsUpdated, idConflicts: drawingIdConflicts }, - }; - }); - - invalidateDrawingsCache(); - return res.json({ success: true, message: "Backup imported successfully", ...result }); - } finally { - await removeFileIfExists(stagedPath); - } - })); - - app.post("/import/sqlite/legacy/verify", requireAuth, upload.single("db"), asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - if (!req.file) return res.status(400).json({ error: "No file uploaded" }); - - let stagedPath: string; - try { - stagedPath = await resolveSafeUploadedFilePath( - { filename: req.file.filename }, - uploadDir - ); - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid upload", message: error.message }); - } - throw error; - } - try { - const isValid = await verifyDatabaseIntegrityAsync(stagedPath); - if (!isValid) return res.status(400).json({ error: "Invalid database format" }); - - let db: any | null = null; - try { - db = openReadonlySqliteDb(stagedPath); - const tables: string[] = db - .prepare("SELECT name FROM sqlite_master WHERE type='table'") - .all() - .map((row: any) => String(row.name)); - - const drawingTable = findSqliteTable(tables, ["Drawing", "drawings"]); - const collectionTable = findSqliteTable(tables, ["Collection", "collections"]); - if (!drawingTable) { - return res.status(400).json({ error: "Invalid legacy DB", message: "Missing Drawing table" }); - } - - const drawingsCount = Number(db.prepare(`SELECT COUNT(1) as c FROM "${drawingTable}"`).get()?.c ?? 0); - const collectionsCount = collectionTable - ? Number(db.prepare(`SELECT COUNT(1) as c FROM "${collectionTable}"`).get()?.c ?? 0) - : 0; - if (drawingsCount > MAX_IMPORT_DRAWINGS) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, - }); - } - if (collectionsCount > MAX_IMPORT_COLLECTIONS) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, - }); - } - - const duplicateDrawingIdRow = db - .prepare( - `SELECT id FROM "${drawingTable}" WHERE id IS NOT NULL GROUP BY id HAVING COUNT(1) > 1 LIMIT 1` - ) - .get(); - if (duplicateDrawingIdRow?.id) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Duplicate drawing id in legacy DB: ${String(duplicateDrawingIdRow.id)}`, - }); - } - if (collectionTable) { - const duplicateCollectionIdRow = db - .prepare( - `SELECT id FROM "${collectionTable}" WHERE id IS NOT NULL GROUP BY id HAVING COUNT(1) > 1 LIMIT 1` - ) - .get(); - if (duplicateCollectionIdRow?.id) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Duplicate collection id in legacy DB: ${String(duplicateCollectionIdRow.id)}`, - }); - } - } - - let latestMigration: string | null = null; - const migrationsTable = findSqliteTable(tables, ["_prisma_migrations"]); - if (migrationsTable) { - try { - const row = db - .prepare( - `SELECT migration_name as name, finished_at as finishedAt FROM "${migrationsTable}" ORDER BY finished_at DESC LIMIT 1` - ) - .get(); - if (row?.name) latestMigration = String(row.name); - } catch { - latestMigration = null; - } - } - - return res.json({ - valid: true, - drawings: drawingsCount, - collections: collectionsCount, - latestMigration, - currentLatestMigration: await getCurrentLatestPrismaMigrationName(backendRoot), - }); - } catch { - return res.status(500).json({ - error: "Legacy DB support unavailable", - message: - "Failed to open the SQLite database for inspection. If you're on Node < 22, you may need to rebuild native dependencies (e.g. `cd backend && npm rebuild better-sqlite3`).", - }); - } finally { - try { - db?.close?.(); - } catch {} - } - } finally { - await removeFileIfExists(stagedPath); - } - })); - - app.post("/import/sqlite/legacy", requireAuth, upload.single("db"), asyncHandler(async (req, res) => { - if (!req.user) return res.status(401).json({ error: "Unauthorized" }); - if (!req.file) return res.status(400).json({ error: "No file uploaded" }); - - let stagedPath: string; - try { - stagedPath = await resolveSafeUploadedFilePath( - { filename: req.file.filename }, - uploadDir - ); - } catch (error) { - if (error instanceof ImportValidationError) { - return res.status(error.status).json({ error: "Invalid upload", message: error.message }); - } - throw error; - } - try { - const isValid = await verifyDatabaseIntegrityAsync(stagedPath); - if (!isValid) return res.status(400).json({ error: "Invalid database format" }); - - let legacyDb: any | null = null; - try { - legacyDb = openReadonlySqliteDb(stagedPath); - const tables: string[] = legacyDb - .prepare("SELECT name FROM sqlite_master WHERE type='table'") - .all() - .map((row: any) => String(row.name)); - - const drawingTable = findSqliteTable(tables, ["Drawing", "drawings"]); - const collectionTable = findSqliteTable(tables, ["Collection", "collections"]); - if (!drawingTable) { - return res.status(400).json({ error: "Invalid legacy DB", message: "Missing Drawing table" }); - } - - const importedCollections: any[] = collectionTable - ? legacyDb.prepare(`SELECT * FROM "${collectionTable}"`).all() - : []; - const importedDrawings: any[] = legacyDb.prepare(`SELECT * FROM "${drawingTable}"`).all(); - - if (importedCollections.length > MAX_IMPORT_COLLECTIONS) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, - }); - } - if (importedDrawings.length > MAX_IMPORT_DRAWINGS) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, - }); - } - - const importedCollectionIds = importedCollections - .map((c) => normalizeNonEmptyId(c?.id)) - .filter((id): id is string => id !== null); - const duplicateCollectionId = findFirstDuplicate(importedCollectionIds); - if (duplicateCollectionId) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Duplicate collection id in legacy DB: ${duplicateCollectionId}`, - }); - } - - const importedDrawingIds = importedDrawings - .map((d) => normalizeNonEmptyId(d?.id)) - .filter((id): id is string => id !== null); - const duplicateDrawingId = findFirstDuplicate(importedDrawingIds); - if (duplicateDrawingId) { - return res.status(400).json({ - error: "Invalid legacy DB", - message: `Duplicate drawing id in legacy DB: ${duplicateDrawingId}`, - }); - } - - type PreparedLegacyDrawing = { - importedId: string | null; - name: string; - sanitized: ReturnType; - collectionIdRaw: unknown; - collectionNameRaw: unknown; - versionRaw: unknown; - }; - - const preparedDrawings: PreparedLegacyDrawing[] = []; - for (const d of importedDrawings) { - const importPayload = { - name: typeof d.name === "string" ? d.name : "Untitled Drawing", - elements: parseOptionalJson(d.elements, []), - appState: parseOptionalJson>(d.appState, {}), - files: parseOptionalJson>(d.files, {}), - preview: typeof d.preview === "string" ? d.preview : null, - collectionId: null as string | null, - }; - - if (!validateImportedDrawing(importPayload)) { - return res.status(400).json({ - error: "Invalid imported drawing", - message: "Legacy database contains invalid drawing data", - }); - } - - preparedDrawings.push({ - importedId: typeof d.id === "string" ? d.id : null, - name: sanitizeText(importPayload.name, 255) || "Untitled Drawing", - sanitized: sanitizeDrawingData(importPayload), - collectionIdRaw: d.collectionId, - collectionNameRaw: d.collectionName, - versionRaw: d.version, - }); - } - - const result = await prisma.$transaction(async (tx) => { - const trashCollectionId = getUserTrashCollectionId(req.user!.id); - const hasTrash = importedDrawings.some((d) => String(d.collectionId || "") === "trash"); - if (hasTrash) await ensureTrashCollection(tx, req.user!.id); - - const collectionIdMap = new Map(); - let collectionsCreated = 0; - let collectionsUpdated = 0; - let collectionIdConflicts = 0; - let drawingsCreated = 0; - let drawingsUpdated = 0; - let drawingIdConflicts = 0; - - for (const c of importedCollections) { - const importedId = typeof c.id === "string" ? c.id : null; - const name = typeof c.name === "string" ? c.name : "Collection"; - - if (importedId === "trash" || name === "Trash") { - collectionIdMap.set(importedId || "trash", trashCollectionId); - continue; - } - - if (!importedId) { - const newId = uuidv4(); - await tx.collection.create({ - data: { id: newId, name: sanitizeText(name, 100) || "Collection", userId: req.user!.id }, - }); - collectionIdMap.set(`__name:${name}`, newId); - collectionsCreated += 1; - continue; - } - - const existing = await tx.collection.findUnique({ where: { id: importedId } }); - if (!existing) { - await tx.collection.create({ - data: { id: importedId, name: sanitizeText(name, 100) || "Collection", userId: req.user!.id }, - }); - collectionIdMap.set(importedId, importedId); - collectionsCreated += 1; - continue; - } - if (existing.userId === req.user!.id) { - await tx.collection.update({ - where: { id: importedId }, - data: { name: sanitizeText(name, 100) || "Collection" }, - }); - collectionIdMap.set(importedId, importedId); - collectionsUpdated += 1; - continue; - } - - const newId = uuidv4(); - await tx.collection.create({ - data: { id: newId, name: sanitizeText(name, 100) || "Collection", userId: req.user!.id }, - }); - collectionIdMap.set(importedId, newId); - collectionsCreated += 1; - collectionIdConflicts += 1; - } - - const resolveImportedCollectionId = ( - rawCollectionId: unknown, - rawCollectionName: unknown - ): string | null => { - const id = typeof rawCollectionId === "string" ? rawCollectionId : null; - const name = typeof rawCollectionName === "string" ? rawCollectionName : null; - - if (id === "trash" || name === "Trash") return trashCollectionId; - if (id && collectionIdMap.has(id)) return collectionIdMap.get(id)!; - if (name && collectionIdMap.has(`__name:${name}`)) return collectionIdMap.get(`__name:${name}`)!; - return null; - }; - - for (const d of preparedDrawings) { - const resolvedCollectionId = resolveImportedCollectionId(d.collectionIdRaw, d.collectionNameRaw); - const existing = d.importedId ? await tx.drawing.findUnique({ where: { id: d.importedId } }) : null; - - if (!existing) { - const idToUse = d.importedId || uuidv4(); - await tx.drawing.create({ - data: { - id: idToUse, - name: d.name, - elements: JSON.stringify(d.sanitized.elements), - appState: JSON.stringify(d.sanitized.appState), - files: JSON.stringify(d.sanitized.files || {}), - preview: d.sanitized.preview ?? null, - version: Number.isFinite(Number(d.versionRaw)) ? Number(d.versionRaw) : 1, - userId: req.user!.id, - collectionId: resolvedCollectionId ?? null, - }, - }); - drawingsCreated += 1; - continue; - } - - if (existing.userId === req.user!.id) { - await tx.drawing.update({ - where: { id: existing.id }, - data: { - name: d.name, - elements: JSON.stringify(d.sanitized.elements), - appState: JSON.stringify(d.sanitized.appState), - files: JSON.stringify(d.sanitized.files || {}), - preview: d.sanitized.preview ?? null, - version: Number.isFinite(Number(d.versionRaw)) ? Number(d.versionRaw) : existing.version, - collectionId: resolvedCollectionId ?? null, - }, - }); - drawingsUpdated += 1; - continue; - } - - const newId = uuidv4(); - await tx.drawing.create({ - data: { - id: newId, - name: d.name, - elements: JSON.stringify(d.sanitized.elements), - appState: JSON.stringify(d.sanitized.appState), - files: JSON.stringify(d.sanitized.files || {}), - preview: d.sanitized.preview ?? null, - version: Number.isFinite(Number(d.versionRaw)) ? Number(d.versionRaw) : 1, - userId: req.user!.id, - collectionId: resolvedCollectionId ?? null, - }, - }); - drawingsCreated += 1; - drawingIdConflicts += 1; - } - - return { - collections: { created: collectionsCreated, updated: collectionsUpdated, idConflicts: collectionIdConflicts }, - drawings: { created: drawingsCreated, updated: drawingsUpdated, idConflicts: drawingIdConflicts }, - }; - }); - - invalidateDrawingsCache(); - return res.json({ success: true, ...result }); - } catch { - return res.status(500).json({ - error: "Legacy DB support unavailable", - message: - "Failed to open the SQLite database for import. If you're on Node < 22, you may need to rebuild native dependencies (e.g. `cd backend && npm rebuild better-sqlite3`).", - }); - } finally { - try { - legacyDb?.close?.(); - } catch {} - } - } finally { - await removeFileIfExists(stagedPath); - } - })); -}; +export { registerImportExportRoutes } from "./importExport/index"; +export type { RegisterImportExportDeps } from "./importExport/index"; diff --git a/backend/src/routes/importExport/excalidashImportRoutes.ts b/backend/src/routes/importExport/excalidashImportRoutes.ts new file mode 100644 index 0000000..780cc49 --- /dev/null +++ b/backend/src/routes/importExport/excalidashImportRoutes.ts @@ -0,0 +1,432 @@ +import { promises as fsPromises } from "fs"; +import JSZip from "jszip"; +import { v4 as uuidv4 } from "uuid"; +import { + RegisterImportExportDeps, + ImportValidationError, + assertSafeZipArchive, + excalidashManifestSchemaV1, + findFirstDuplicate, + getSafeZipEntry, + getUserTrashCollectionId, + resolveSafeUploadedFilePath, + sanitizeDrawingData, +} from "./shared"; + +export const registerExcalidashImportRoutes = (deps: RegisterImportExportDeps) => { + const { + app, + prisma, + requireAuth, + asyncHandler, + upload, + uploadDir, + sanitizeText, + validateImportedDrawing, + ensureTrashCollection, + invalidateDrawingsCache, + removeFileIfExists, + MAX_IMPORT_ARCHIVE_ENTRIES, + MAX_IMPORT_COLLECTIONS, + MAX_IMPORT_DRAWINGS, + MAX_IMPORT_MANIFEST_BYTES, + MAX_IMPORT_DRAWING_BYTES, + MAX_IMPORT_TOTAL_EXTRACTED_BYTES, + } = deps; + + app.post("/import/excalidash/verify", requireAuth, upload.single("archive"), asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + if (!req.file) return res.status(400).json({ error: "No file uploaded" }); + + let stagedPath: string; + try { + stagedPath = await resolveSafeUploadedFilePath( + { filename: req.file.filename }, + uploadDir + ); + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid upload", message: error.message }); + } + throw error; + } + try { + const buffer = await fsPromises.readFile(stagedPath); + const zip = await JSZip.loadAsync(buffer); + try { + assertSafeZipArchive(zip, MAX_IMPORT_ARCHIVE_ENTRIES); + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid backup", message: error.message }); + } + throw error; + } + + const manifestFile = getSafeZipEntry(zip, "excalidash.manifest.json"); + if (!manifestFile) { + return res.status(400).json({ error: "Invalid backup", message: "Missing excalidash.manifest.json" }); + } + const rawManifest = await manifestFile.async("string"); + if (Buffer.byteLength(rawManifest, "utf8") > MAX_IMPORT_MANIFEST_BYTES) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: "excalidash.manifest.json is too large", + }); + } + + let manifestJson: unknown; + try { + manifestJson = JSON.parse(rawManifest); + } catch { + return res.status(400).json({ + error: "Invalid backup manifest", + message: "excalidash.manifest.json is not valid JSON", + }); + } + const parsed = excalidashManifestSchemaV1.safeParse(manifestJson); + if (!parsed.success) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: "Malformed excalidash.manifest.json", + }); + } + const manifest = parsed.data; + if (manifest.collections.length > MAX_IMPORT_COLLECTIONS) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, + }); + } + if (manifest.drawings.length > MAX_IMPORT_DRAWINGS) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, + }); + } + + const duplicateCollectionId = findFirstDuplicate(manifest.collections.map((c) => c.id)); + if (duplicateCollectionId) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Duplicate collection id in manifest: ${duplicateCollectionId}`, + }); + } + const duplicateDrawingId = findFirstDuplicate(manifest.drawings.map((d) => d.id)); + if (duplicateDrawingId) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Duplicate drawing id in manifest: ${duplicateDrawingId}`, + }); + } + const duplicateDrawingPath = findFirstDuplicate(manifest.drawings.map((d) => d.filePath)); + if (duplicateDrawingPath) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Duplicate drawing file path in manifest: ${duplicateDrawingPath}`, + }); + } + for (const drawing of manifest.drawings) { + if (!getSafeZipEntry(zip, drawing.filePath)) { + return res.status(400).json({ + error: "Invalid backup", + message: `Missing drawing file: ${drawing.filePath}`, + }); + } + } + + return res.json({ + valid: true, + formatVersion: manifest.formatVersion, + exportedAt: manifest.exportedAt, + excalidashBackendVersion: manifest.excalidashBackendVersion || null, + collections: manifest.collections.length, + drawings: manifest.drawings.length, + }); + } finally { + await removeFileIfExists(stagedPath); + } + })); + + app.post("/import/excalidash", requireAuth, upload.single("archive"), asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + if (!req.file) return res.status(400).json({ error: "No file uploaded" }); + + let stagedPath: string; + try { + stagedPath = await resolveSafeUploadedFilePath( + { filename: req.file.filename }, + uploadDir + ); + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid upload", message: error.message }); + } + throw error; + } + try { + const buffer = await fsPromises.readFile(stagedPath); + const zip = await JSZip.loadAsync(buffer); + try { + assertSafeZipArchive(zip, MAX_IMPORT_ARCHIVE_ENTRIES); + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid backup", message: error.message }); + } + throw error; + } + + const manifestFile = getSafeZipEntry(zip, "excalidash.manifest.json"); + if (!manifestFile) { + return res.status(400).json({ error: "Invalid backup", message: "Missing excalidash.manifest.json" }); + } + const rawManifest = await manifestFile.async("string"); + if (Buffer.byteLength(rawManifest, "utf8") > MAX_IMPORT_MANIFEST_BYTES) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: "excalidash.manifest.json is too large", + }); + } + + let manifestJson: unknown; + try { + manifestJson = JSON.parse(rawManifest); + } catch { + return res.status(400).json({ + error: "Invalid backup manifest", + message: "excalidash.manifest.json is not valid JSON", + }); + } + const parsed = excalidashManifestSchemaV1.safeParse(manifestJson); + if (!parsed.success) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: "Malformed excalidash.manifest.json", + }); + } + const manifest = parsed.data; + + if (manifest.collections.length > MAX_IMPORT_COLLECTIONS) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, + }); + } + if (manifest.drawings.length > MAX_IMPORT_DRAWINGS) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, + }); + } + + const duplicateCollectionId = findFirstDuplicate(manifest.collections.map((c) => c.id)); + if (duplicateCollectionId) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Duplicate collection id in manifest: ${duplicateCollectionId}`, + }); + } + const duplicateDrawingId = findFirstDuplicate(manifest.drawings.map((d) => d.id)); + if (duplicateDrawingId) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Duplicate drawing id in manifest: ${duplicateDrawingId}`, + }); + } + const duplicateDrawingPath = findFirstDuplicate(manifest.drawings.map((d) => d.filePath)); + if (duplicateDrawingPath) { + return res.status(400).json({ + error: "Invalid backup manifest", + message: `Duplicate drawing file path in manifest: ${duplicateDrawingPath}`, + }); + } + + type PreparedImportDrawing = { + id: string; + name: string; + version: number | undefined; + collectionId: string | null; + sanitized: ReturnType; + }; + const preparedDrawings: PreparedImportDrawing[] = []; + let extractedBytes = Buffer.byteLength(rawManifest, "utf8"); + try { + for (const d of manifest.drawings) { + const entry = getSafeZipEntry(zip, d.filePath); + if (!entry) throw new ImportValidationError(`Missing drawing file: ${d.filePath}`); + + const raw = await entry.async("string"); + const rawSize = Buffer.byteLength(raw, "utf8"); + if (rawSize > MAX_IMPORT_DRAWING_BYTES) { + throw new ImportValidationError(`Drawing is too large: ${d.filePath}`); + } + extractedBytes += rawSize; + if (extractedBytes > MAX_IMPORT_TOTAL_EXTRACTED_BYTES) { + throw new ImportValidationError("Backup contents exceed maximum import size"); + } + + let parsedJson: any; + try { + parsedJson = JSON.parse(raw) as any; + } catch { + throw new ImportValidationError(`Drawing JSON is invalid: ${d.filePath}`); + } + + const imported = { + name: d.name, + elements: Array.isArray(parsedJson?.elements) ? parsedJson.elements : [], + appState: + typeof parsedJson?.appState === "object" && parsedJson.appState !== null + ? parsedJson.appState + : {}, + files: + typeof parsedJson?.files === "object" && parsedJson.files !== null + ? parsedJson.files + : {}, + preview: null as string | null, + collectionId: d.collectionId, + }; + + if (!validateImportedDrawing(imported)) { + throw new ImportValidationError(`Drawing failed validation: ${d.filePath}`); + } + + preparedDrawings.push({ + id: d.id, + name: sanitizeText(imported.name, 255) || "Untitled Drawing", + version: typeof d.version === "number" ? d.version : undefined, + collectionId: d.collectionId, + sanitized: sanitizeDrawingData(imported), + }); + } + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid backup", message: error.message }); + } + throw error; + } + + const result = await prisma.$transaction(async (tx) => { + const trashCollectionId = getUserTrashCollectionId(req.user!.id); + const collectionIdMap = new Map(); + let collectionsCreated = 0; + let collectionsUpdated = 0; + let collectionIdConflicts = 0; + let drawingsCreated = 0; + let drawingsUpdated = 0; + let drawingIdConflicts = 0; + + const needsTrash = + manifest.collections.some((c) => c.id === "trash") || + preparedDrawings.some((d) => d.collectionId === "trash"); + if (needsTrash) await ensureTrashCollection(tx, req.user!.id); + + for (const c of manifest.collections) { + if (c.id === "trash") { + collectionIdMap.set("trash", trashCollectionId); + continue; + } + + const existing = await tx.collection.findUnique({ where: { id: c.id } }); + if (!existing) { + await tx.collection.create({ + data: { id: c.id, name: sanitizeText(c.name, 100) || "Collection", userId: req.user!.id }, + }); + collectionIdMap.set(c.id, c.id); + collectionsCreated += 1; + continue; + } + + if (existing.userId === req.user!.id) { + await tx.collection.update({ + where: { id: c.id }, + data: { name: sanitizeText(c.name, 100) || "Collection" }, + }); + collectionIdMap.set(c.id, c.id); + collectionsUpdated += 1; + continue; + } + + const newId = uuidv4(); + await tx.collection.create({ + data: { id: newId, name: sanitizeText(c.name, 100) || "Collection", userId: req.user!.id }, + }); + collectionIdMap.set(c.id, newId); + collectionsCreated += 1; + collectionIdConflicts += 1; + } + + const resolveCollectionId = (collectionId: string | null): string | null => { + if (!collectionId) return null; + if (collectionId === "trash") return trashCollectionId; + return collectionIdMap.get(collectionId) || null; + }; + + for (const prepared of preparedDrawings) { + const targetCollectionId = resolveCollectionId(prepared.collectionId); + const existing = await tx.drawing.findUnique({ where: { id: prepared.id } }); + if (!existing) { + await tx.drawing.create({ + data: { + id: prepared.id, + name: prepared.name, + elements: JSON.stringify(prepared.sanitized.elements), + appState: JSON.stringify(prepared.sanitized.appState), + files: JSON.stringify(prepared.sanitized.files || {}), + preview: prepared.sanitized.preview ?? null, + version: prepared.version ?? 1, + userId: req.user!.id, + collectionId: targetCollectionId, + }, + }); + drawingsCreated += 1; + continue; + } + + if (existing.userId === req.user!.id) { + await tx.drawing.update({ + where: { id: prepared.id }, + data: { + name: prepared.name, + elements: JSON.stringify(prepared.sanitized.elements), + appState: JSON.stringify(prepared.sanitized.appState), + files: JSON.stringify(prepared.sanitized.files || {}), + preview: prepared.sanitized.preview ?? null, + version: prepared.version ?? existing.version, + collectionId: targetCollectionId, + }, + }); + drawingsUpdated += 1; + continue; + } + + const newId = uuidv4(); + await tx.drawing.create({ + data: { + id: newId, + name: prepared.name, + elements: JSON.stringify(prepared.sanitized.elements), + appState: JSON.stringify(prepared.sanitized.appState), + files: JSON.stringify(prepared.sanitized.files || {}), + preview: prepared.sanitized.preview ?? null, + version: prepared.version ?? 1, + userId: req.user!.id, + collectionId: targetCollectionId, + }, + }); + drawingsCreated += 1; + drawingIdConflicts += 1; + } + + return { + collections: { created: collectionsCreated, updated: collectionsUpdated, idConflicts: collectionIdConflicts }, + drawings: { created: drawingsCreated, updated: drawingsUpdated, idConflicts: drawingIdConflicts }, + }; + }); + + invalidateDrawingsCache(); + return res.json({ success: true, message: "Backup imported successfully", ...result }); + } finally { + await removeFileIfExists(stagedPath); + } + })); +}; diff --git a/backend/src/routes/importExport/exportRoutes.ts b/backend/src/routes/importExport/exportRoutes.ts new file mode 100644 index 0000000..fabf2e3 --- /dev/null +++ b/backend/src/routes/importExport/exportRoutes.ts @@ -0,0 +1,163 @@ +import archiver from "archiver"; +import { Prisma } from "../../generated/client"; +import { + RegisterImportExportDeps, + assertSafeArchivePath, + getUserTrashCollectionId, + isTrashCollectionId, + makeUniqueName, + sanitizePathSegment, + toPublicTrashCollectionId, +} from "./shared"; + +export const registerExcalidashExportRoute = (deps: RegisterImportExportDeps) => { + const { + app, + prisma, + requireAuth, + asyncHandler, + getBackendVersion, + parseJsonField, + } = deps; + + app.get("/export/excalidash", requireAuth, asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + const trashCollectionId = getUserTrashCollectionId(req.user.id); + + const extParam = typeof req.query.ext === "string" ? req.query.ext.toLowerCase() : ""; + const zipSuffix = extParam === "zip"; + const date = new Date().toISOString().split("T")[0]; + const filename = zipSuffix + ? `excalidash-backup-${date}.excalidash.zip` + : `excalidash-backup-${date}.excalidash`; + + const exportedAt = new Date().toISOString(); + const drawings = await prisma.drawing.findMany({ + where: { userId: req.user.id }, + include: { collection: true }, + }); + const userCollections = await prisma.collection.findMany({ + where: { userId: req.user.id }, + }); + + const hasInternalTrashCollection = userCollections.some((collection) => collection.id === trashCollectionId); + const normalizedUserCollections = userCollections.filter( + (collection) => !(hasInternalTrashCollection && collection.id === "trash") + ); + const hasTrashDrawings = drawings.some((drawing) => + isTrashCollectionId(drawing.collectionId, req.user!.id) + ); + const collectionsToExport = [...normalizedUserCollections]; + if ( + hasTrashDrawings && + !collectionsToExport.some((collection) => + isTrashCollectionId(collection.id, req.user!.id) + ) + ) { + const trash = await prisma.collection.findFirst({ + where: { userId: req.user.id, id: { in: [trashCollectionId, "trash"] } }, + }); + if (trash) collectionsToExport.push(trash); + } + + const exportSource = `${req.protocol}://${req.get("host")}`; + const usedFolderNames = new Set(); + const unorganizedFolder = makeUniqueName("Unorganized", usedFolderNames); + const folderByCollectionId = new Map(); + for (const collection of collectionsToExport) { + const base = sanitizePathSegment(collection.name, "Collection"); + const folder = makeUniqueName(base, usedFolderNames); + folderByCollectionId.set(collection.id, folder); + } + + type DrawingWithCollection = Prisma.DrawingGetPayload<{ include: { collection: true } }>; + const drawingsManifest = drawings.map((drawing: DrawingWithCollection) => { + const folder = drawing.collectionId + ? folderByCollectionId.get(drawing.collectionId) || unorganizedFolder + : unorganizedFolder; + const fileNameBase = sanitizePathSegment(drawing.name, "Untitled"); + const fileName = `${fileNameBase}__${drawing.id.slice(0, 8)}.excalidraw`; + return { + id: drawing.id, + name: drawing.name, + filePath: `${folder}/${fileName}`, + collectionId: toPublicTrashCollectionId(drawing.collectionId, req.user!.id), + version: drawing.version, + createdAt: drawing.createdAt.toISOString(), + updatedAt: drawing.updatedAt.toISOString(), + }; + }); + + const manifestCollections = collectionsToExport + .map((collection) => ({ + id: toPublicTrashCollectionId(collection.id, req.user!.id) || collection.id, + name: isTrashCollectionId(collection.id, req.user!.id) ? "Trash" : collection.name, + folder: folderByCollectionId.get(collection.id) || sanitizePathSegment(collection.name, "Collection"), + createdAt: collection.createdAt.toISOString(), + updatedAt: collection.updatedAt.toISOString(), + })) + .filter((collection, index, all) => all.findIndex((c) => c.id === collection.id) === index); + + const manifest = { + format: "excalidash" as const, + formatVersion: 1 as const, + exportedAt, + excalidashBackendVersion: getBackendVersion(), + userId: req.user.id, + unorganizedFolder, + collections: manifestCollections, + drawings: drawingsManifest, + }; + + res.setHeader("Content-Type", "application/zip"); + res.setHeader("Content-Disposition", `attachment; filename="${filename}"`); + + const archive = archiver("zip", { zlib: { level: 9 } }); + archive.on("error", (err) => { + console.error("Archive error:", err); + res.status(500).json({ error: "Failed to create archive" }); + }); + archive.pipe(res); + + archive.append(JSON.stringify(manifest, null, 2), { name: "excalidash.manifest.json" }); + + const drawingsManifestById = new Map(drawingsManifest.map((d) => [d.id, d])); + for (const drawing of drawings) { + const meta = drawingsManifestById.get(drawing.id); + if (!meta) continue; + const drawingData = { + type: "excalidraw" as const, + version: 2 as const, + source: exportSource, + elements: parseJsonField(drawing.elements, [] as unknown[]), + appState: parseJsonField(drawing.appState, {} as Record), + files: parseJsonField(drawing.files, {} as Record), + excalidash: { + drawingId: drawing.id, + collectionId: drawing.collectionId ?? null, + exportedAt, + }, + }; + assertSafeArchivePath(meta.filePath); + archive.append(JSON.stringify(drawingData, null, 2), { name: meta.filePath }); + } + + const readme = `ExcaliDash Backup (.excalidash) + +This file is a zip archive containing a versioned ExcaliDash manifest and your drawings, +organized into folders by collection. + +Files: +- excalidash.manifest.json (required) +- /*.excalidraw + +ExportedAt: ${exportedAt} +FormatVersion: 1 +BackendVersion: ${getBackendVersion()} +Collections: ${collectionsToExport.length} +Drawings: ${drawings.length} +`; + archive.append(readme, { name: "README.txt" }); + await archive.finalize(); + })); +}; diff --git a/backend/src/routes/importExport/index.ts b/backend/src/routes/importExport/index.ts new file mode 100644 index 0000000..4265f9f --- /dev/null +++ b/backend/src/routes/importExport/index.ts @@ -0,0 +1,12 @@ +import { registerExcalidashImportRoutes } from "./excalidashImportRoutes"; +import { registerExcalidashExportRoute } from "./exportRoutes"; +import { registerLegacySqliteImportRoutes } from "./legacySqliteImportRoutes"; +import { RegisterImportExportDeps } from "./shared"; + +export const registerImportExportRoutes = (deps: RegisterImportExportDeps) => { + registerExcalidashExportRoute(deps); + registerExcalidashImportRoutes(deps); + registerLegacySqliteImportRoutes(deps); +}; + +export type { RegisterImportExportDeps } from "./shared"; diff --git a/backend/src/routes/importExport/legacySqliteImportRoutes.ts b/backend/src/routes/importExport/legacySqliteImportRoutes.ts new file mode 100644 index 0000000..ab509ba --- /dev/null +++ b/backend/src/routes/importExport/legacySqliteImportRoutes.ts @@ -0,0 +1,414 @@ +import { v4 as uuidv4 } from "uuid"; +import { + RegisterImportExportDeps, + ImportValidationError, + findFirstDuplicate, + findSqliteTable, + getCurrentLatestPrismaMigrationName, + getUserTrashCollectionId, + normalizeNonEmptyId, + openReadonlySqliteDb, + parseOptionalJson, + resolveSafeUploadedFilePath, + sanitizeDrawingData, +} from "./shared"; + +export const registerLegacySqliteImportRoutes = (deps: RegisterImportExportDeps) => { + const { + app, + prisma, + requireAuth, + asyncHandler, + upload, + uploadDir, + backendRoot, + sanitizeText, + validateImportedDrawing, + ensureTrashCollection, + invalidateDrawingsCache, + removeFileIfExists, + verifyDatabaseIntegrityAsync, + MAX_IMPORT_COLLECTIONS, + MAX_IMPORT_DRAWINGS, + } = deps; + + app.post("/import/sqlite/legacy/verify", requireAuth, upload.single("db"), asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + if (!req.file) return res.status(400).json({ error: "No file uploaded" }); + + let stagedPath: string; + try { + stagedPath = await resolveSafeUploadedFilePath( + { filename: req.file.filename }, + uploadDir + ); + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid upload", message: error.message }); + } + throw error; + } + try { + const isValid = await verifyDatabaseIntegrityAsync(stagedPath); + if (!isValid) return res.status(400).json({ error: "Invalid database format" }); + + let db: any | null = null; + try { + db = openReadonlySqliteDb(stagedPath); + const tables: string[] = db + .prepare("SELECT name FROM sqlite_master WHERE type='table'") + .all() + .map((row: any) => String(row.name)); + + const drawingTable = findSqliteTable(tables, ["Drawing", "drawings"]); + const collectionTable = findSqliteTable(tables, ["Collection", "collections"]); + if (!drawingTable) { + return res.status(400).json({ error: "Invalid legacy DB", message: "Missing Drawing table" }); + } + + const drawingsCount = Number(db.prepare(`SELECT COUNT(1) as c FROM "${drawingTable}"`).get()?.c ?? 0); + const collectionsCount = collectionTable + ? Number(db.prepare(`SELECT COUNT(1) as c FROM "${collectionTable}"`).get()?.c ?? 0) + : 0; + if (drawingsCount > MAX_IMPORT_DRAWINGS) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, + }); + } + if (collectionsCount > MAX_IMPORT_COLLECTIONS) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, + }); + } + + const duplicateDrawingIdRow = db + .prepare( + `SELECT id FROM "${drawingTable}" WHERE id IS NOT NULL GROUP BY id HAVING COUNT(1) > 1 LIMIT 1` + ) + .get(); + if (duplicateDrawingIdRow?.id) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Duplicate drawing id in legacy DB: ${String(duplicateDrawingIdRow.id)}`, + }); + } + if (collectionTable) { + const duplicateCollectionIdRow = db + .prepare( + `SELECT id FROM "${collectionTable}" WHERE id IS NOT NULL GROUP BY id HAVING COUNT(1) > 1 LIMIT 1` + ) + .get(); + if (duplicateCollectionIdRow?.id) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Duplicate collection id in legacy DB: ${String(duplicateCollectionIdRow.id)}`, + }); + } + } + + let latestMigration: string | null = null; + const migrationsTable = findSqliteTable(tables, ["_prisma_migrations"]); + if (migrationsTable) { + try { + const row = db + .prepare( + `SELECT migration_name as name, finished_at as finishedAt FROM "${migrationsTable}" ORDER BY finished_at DESC LIMIT 1` + ) + .get(); + if (row?.name) latestMigration = String(row.name); + } catch { + latestMigration = null; + } + } + + return res.json({ + valid: true, + drawings: drawingsCount, + collections: collectionsCount, + latestMigration, + currentLatestMigration: await getCurrentLatestPrismaMigrationName(backendRoot), + }); + } catch { + return res.status(500).json({ + error: "Legacy DB support unavailable", + message: + "Failed to open the SQLite database for inspection. If you're on Node < 22, you may need to rebuild native dependencies (e.g. `cd backend && npm rebuild better-sqlite3`).", + }); + } finally { + try { + db?.close?.(); + } catch {} + } + } finally { + await removeFileIfExists(stagedPath); + } + })); + + app.post("/import/sqlite/legacy", requireAuth, upload.single("db"), asyncHandler(async (req, res) => { + if (!req.user) return res.status(401).json({ error: "Unauthorized" }); + if (!req.file) return res.status(400).json({ error: "No file uploaded" }); + + let stagedPath: string; + try { + stagedPath = await resolveSafeUploadedFilePath( + { filename: req.file.filename }, + uploadDir + ); + } catch (error) { + if (error instanceof ImportValidationError) { + return res.status(error.status).json({ error: "Invalid upload", message: error.message }); + } + throw error; + } + try { + const isValid = await verifyDatabaseIntegrityAsync(stagedPath); + if (!isValid) return res.status(400).json({ error: "Invalid database format" }); + + let legacyDb: any | null = null; + try { + legacyDb = openReadonlySqliteDb(stagedPath); + const tables: string[] = legacyDb + .prepare("SELECT name FROM sqlite_master WHERE type='table'") + .all() + .map((row: any) => String(row.name)); + + const drawingTable = findSqliteTable(tables, ["Drawing", "drawings"]); + const collectionTable = findSqliteTable(tables, ["Collection", "collections"]); + if (!drawingTable) { + return res.status(400).json({ error: "Invalid legacy DB", message: "Missing Drawing table" }); + } + + const importedCollections: any[] = collectionTable + ? legacyDb.prepare(`SELECT * FROM "${collectionTable}"`).all() + : []; + const importedDrawings: any[] = legacyDb.prepare(`SELECT * FROM "${drawingTable}"`).all(); + + if (importedCollections.length > MAX_IMPORT_COLLECTIONS) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Too many collections (max ${MAX_IMPORT_COLLECTIONS})`, + }); + } + if (importedDrawings.length > MAX_IMPORT_DRAWINGS) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Too many drawings (max ${MAX_IMPORT_DRAWINGS})`, + }); + } + + const importedCollectionIds = importedCollections + .map((c) => normalizeNonEmptyId(c?.id)) + .filter((id): id is string => id !== null); + const duplicateCollectionId = findFirstDuplicate(importedCollectionIds); + if (duplicateCollectionId) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Duplicate collection id in legacy DB: ${duplicateCollectionId}`, + }); + } + + const importedDrawingIds = importedDrawings + .map((d) => normalizeNonEmptyId(d?.id)) + .filter((id): id is string => id !== null); + const duplicateDrawingId = findFirstDuplicate(importedDrawingIds); + if (duplicateDrawingId) { + return res.status(400).json({ + error: "Invalid legacy DB", + message: `Duplicate drawing id in legacy DB: ${duplicateDrawingId}`, + }); + } + + type PreparedLegacyDrawing = { + importedId: string | null; + name: string; + sanitized: ReturnType; + collectionIdRaw: unknown; + collectionNameRaw: unknown; + versionRaw: unknown; + }; + + const preparedDrawings: PreparedLegacyDrawing[] = []; + for (const d of importedDrawings) { + const importPayload = { + name: typeof d.name === "string" ? d.name : "Untitled Drawing", + elements: parseOptionalJson(d.elements, []), + appState: parseOptionalJson>(d.appState, {}), + files: parseOptionalJson>(d.files, {}), + preview: typeof d.preview === "string" ? d.preview : null, + collectionId: null as string | null, + }; + + if (!validateImportedDrawing(importPayload)) { + return res.status(400).json({ + error: "Invalid imported drawing", + message: "Legacy database contains invalid drawing data", + }); + } + + preparedDrawings.push({ + importedId: typeof d.id === "string" ? d.id : null, + name: sanitizeText(importPayload.name, 255) || "Untitled Drawing", + sanitized: sanitizeDrawingData(importPayload), + collectionIdRaw: d.collectionId, + collectionNameRaw: d.collectionName, + versionRaw: d.version, + }); + } + + const result = await prisma.$transaction(async (tx) => { + const trashCollectionId = getUserTrashCollectionId(req.user!.id); + const hasTrash = importedDrawings.some((d) => String(d.collectionId || "") === "trash"); + if (hasTrash) await ensureTrashCollection(tx, req.user!.id); + + const collectionIdMap = new Map(); + let collectionsCreated = 0; + let collectionsUpdated = 0; + let collectionIdConflicts = 0; + let drawingsCreated = 0; + let drawingsUpdated = 0; + let drawingIdConflicts = 0; + + for (const c of importedCollections) { + const importedId = typeof c.id === "string" ? c.id : null; + const name = typeof c.name === "string" ? c.name : "Collection"; + + if (importedId === "trash" || name === "Trash") { + collectionIdMap.set(importedId || "trash", trashCollectionId); + continue; + } + + if (!importedId) { + const newId = uuidv4(); + await tx.collection.create({ + data: { id: newId, name: sanitizeText(name, 100) || "Collection", userId: req.user!.id }, + }); + collectionIdMap.set(`__name:${name}`, newId); + collectionsCreated += 1; + continue; + } + + const existing = await tx.collection.findUnique({ where: { id: importedId } }); + if (!existing) { + await tx.collection.create({ + data: { id: importedId, name: sanitizeText(name, 100) || "Collection", userId: req.user!.id }, + }); + collectionIdMap.set(importedId, importedId); + collectionsCreated += 1; + continue; + } + if (existing.userId === req.user!.id) { + await tx.collection.update({ + where: { id: importedId }, + data: { name: sanitizeText(name, 100) || "Collection" }, + }); + collectionIdMap.set(importedId, importedId); + collectionsUpdated += 1; + continue; + } + + const newId = uuidv4(); + await tx.collection.create({ + data: { id: newId, name: sanitizeText(name, 100) || "Collection", userId: req.user!.id }, + }); + collectionIdMap.set(importedId, newId); + collectionsCreated += 1; + collectionIdConflicts += 1; + } + + const resolveImportedCollectionId = ( + rawCollectionId: unknown, + rawCollectionName: unknown + ): string | null => { + const id = typeof rawCollectionId === "string" ? rawCollectionId : null; + const name = typeof rawCollectionName === "string" ? rawCollectionName : null; + + if (id === "trash" || name === "Trash") return trashCollectionId; + if (id && collectionIdMap.has(id)) return collectionIdMap.get(id)!; + if (name && collectionIdMap.has(`__name:${name}`)) return collectionIdMap.get(`__name:${name}`)!; + return null; + }; + + for (const d of preparedDrawings) { + const resolvedCollectionId = resolveImportedCollectionId(d.collectionIdRaw, d.collectionNameRaw); + const existing = d.importedId ? await tx.drawing.findUnique({ where: { id: d.importedId } }) : null; + + if (!existing) { + const idToUse = d.importedId || uuidv4(); + await tx.drawing.create({ + data: { + id: idToUse, + name: d.name, + elements: JSON.stringify(d.sanitized.elements), + appState: JSON.stringify(d.sanitized.appState), + files: JSON.stringify(d.sanitized.files || {}), + preview: d.sanitized.preview ?? null, + version: Number.isFinite(Number(d.versionRaw)) ? Number(d.versionRaw) : 1, + userId: req.user!.id, + collectionId: resolvedCollectionId ?? null, + }, + }); + drawingsCreated += 1; + continue; + } + + if (existing.userId === req.user!.id) { + await tx.drawing.update({ + where: { id: existing.id }, + data: { + name: d.name, + elements: JSON.stringify(d.sanitized.elements), + appState: JSON.stringify(d.sanitized.appState), + files: JSON.stringify(d.sanitized.files || {}), + preview: d.sanitized.preview ?? null, + version: Number.isFinite(Number(d.versionRaw)) ? Number(d.versionRaw) : existing.version, + collectionId: resolvedCollectionId ?? null, + }, + }); + drawingsUpdated += 1; + continue; + } + + const newId = uuidv4(); + await tx.drawing.create({ + data: { + id: newId, + name: d.name, + elements: JSON.stringify(d.sanitized.elements), + appState: JSON.stringify(d.sanitized.appState), + files: JSON.stringify(d.sanitized.files || {}), + preview: d.sanitized.preview ?? null, + version: Number.isFinite(Number(d.versionRaw)) ? Number(d.versionRaw) : 1, + userId: req.user!.id, + collectionId: resolvedCollectionId ?? null, + }, + }); + drawingsCreated += 1; + drawingIdConflicts += 1; + } + + return { + collections: { created: collectionsCreated, updated: collectionsUpdated, idConflicts: collectionIdConflicts }, + drawings: { created: drawingsCreated, updated: drawingsUpdated, idConflicts: drawingIdConflicts }, + }; + }); + + invalidateDrawingsCache(); + return res.json({ success: true, ...result }); + } catch { + return res.status(500).json({ + error: "Legacy DB support unavailable", + message: + "Failed to open the SQLite database for import. If you're on Node < 22, you may need to rebuild native dependencies (e.g. `cd backend && npm rebuild better-sqlite3`).", + }); + } finally { + try { + legacyDb?.close?.(); + } catch {} + } + } finally { + await removeFileIfExists(stagedPath); + } + })); +}; diff --git a/backend/src/routes/importExport/shared.ts b/backend/src/routes/importExport/shared.ts new file mode 100644 index 0000000..26e7cf1 --- /dev/null +++ b/backend/src/routes/importExport/shared.ts @@ -0,0 +1,255 @@ +import express from "express"; +import path from "path"; +import { promises as fsPromises } from "fs"; +import JSZip from "jszip"; +import { z } from "zod"; +import { Prisma, PrismaClient } from "../../generated/client"; +import { sanitizeDrawingData } from "../../security"; + +export class ImportValidationError extends Error { + status: number; + + constructor(message: string, status = 400) { + super(message); + this.name = "ImportValidationError"; + this.status = status; + } +} + +export const excalidashManifestSchemaV1 = z.object({ + format: z.literal("excalidash"), + formatVersion: z.literal(1), + exportedAt: z.string().min(1), + excalidashBackendVersion: z.string().optional(), + userId: z.string().optional(), + unorganizedFolder: z.string().min(1), + collections: z.array( + z.object({ + id: z.string().min(1), + name: z.string(), + folder: z.string().min(1), + createdAt: z.string().optional(), + updatedAt: z.string().optional(), + }) + ), + drawings: z.array( + z.object({ + id: z.string().min(1), + name: z.string(), + filePath: z.string().min(1), + collectionId: z.string().nullable(), + version: z.number().int().optional(), + createdAt: z.string().optional(), + updatedAt: z.string().optional(), + }) + ), +}); + +export type RegisterImportExportDeps = { + app: express.Express; + prisma: PrismaClient; + requireAuth: express.RequestHandler; + asyncHandler: ( + fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise + ) => express.RequestHandler; + upload: any; + uploadDir: string; + backendRoot: string; + getBackendVersion: () => string; + parseJsonField: (rawValue: string | null | undefined, fallback: T) => T; + sanitizeText: (input: unknown, maxLength?: number) => string; + validateImportedDrawing: (data: unknown) => boolean; + ensureTrashCollection: ( + db: Prisma.TransactionClient | PrismaClient, + userId: string + ) => Promise; + invalidateDrawingsCache: () => void; + removeFileIfExists: (filePath?: string) => Promise; + verifyDatabaseIntegrityAsync: (filePath: string) => Promise; + MAX_IMPORT_ARCHIVE_ENTRIES: number; + MAX_IMPORT_COLLECTIONS: number; + MAX_IMPORT_DRAWINGS: number; + MAX_IMPORT_MANIFEST_BYTES: number; + MAX_IMPORT_DRAWING_BYTES: number; + MAX_IMPORT_TOTAL_EXTRACTED_BYTES: number; +}; + +const getZipEntries = (zip: JSZip) => Object.values(zip.files).filter((entry) => !entry.dir); + +export const normalizeArchivePath = (filePath: string): string => + path.posix.normalize(filePath.replace(/\\/g, "/")); + +export const assertSafeArchivePath = (filePath: string) => { + const normalized = normalizeArchivePath(filePath); + if ( + normalized.length === 0 || + path.posix.isAbsolute(normalized) || + normalized === ".." || + normalized.startsWith("../") || + normalized.includes("\0") + ) { + throw new ImportValidationError(`Unsafe archive path: ${filePath}`); + } +}; + +export const assertSafeZipArchive = (zip: JSZip, maxEntries: number) => { + const entries = getZipEntries(zip); + if (entries.length > maxEntries) { + throw new ImportValidationError("Archive contains too many files"); + } + for (const entry of entries) { + assertSafeArchivePath(entry.name); + } +}; + +export const getSafeZipEntry = (zip: JSZip, filePath: string) => { + const normalizedPath = normalizeArchivePath(filePath); + assertSafeArchivePath(normalizedPath); + return zip.file(normalizedPath); +}; + +export const sanitizePathSegment = (input: string, fallback: string): string => { + const value = typeof input === "string" ? input.trim() : ""; + const cleaned = value + .replace(/[<>:"/\\|?*\x00-\x1F]/g, "_") + .replace(/\s+/g, " ") + .slice(0, 120) + .trim(); + return cleaned.length > 0 ? cleaned : fallback; +}; + +export const makeUniqueName = (base: string, used: Set): string => { + let candidate = base; + let n = 2; + while (used.has(candidate)) { + candidate = `${base}__${n}`; + n += 1; + } + used.add(candidate); + return candidate; +}; + +export const findFirstDuplicate = (values: string[]): string | null => { + const seen = new Set(); + for (const value of values) { + if (seen.has(value)) return value; + seen.add(value); + } + return null; +}; + +export const normalizeNonEmptyId = (value: unknown): string | null => { + if (typeof value !== "string") return null; + const trimmed = value.trim(); + return trimmed.length > 0 ? trimmed : null; +}; + +export const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`; + +export const isTrashCollectionId = ( + collectionId: string | null | undefined, + userId: string +): boolean => + Boolean(collectionId) && + (collectionId === "trash" || collectionId === getUserTrashCollectionId(userId)); + +export const toPublicTrashCollectionId = ( + collectionId: string | null | undefined, + userId: string +): string | null => + isTrashCollectionId(collectionId, userId) ? "trash" : collectionId ?? null; + +export const findSqliteTable = (tables: string[], candidates: string[]): string | null => { + const byLower = new Map(tables.map((t) => [t.toLowerCase(), t])); + for (const candidate of candidates) { + const found = byLower.get(candidate.toLowerCase()); + if (found) return found; + } + return null; +}; + +export const parseOptionalJson = (raw: unknown, fallback: T): T => { + if (typeof raw === "string") { + try { + return JSON.parse(raw) as T; + } catch { + return fallback; + } + } + if (typeof raw === "object" && raw !== null) { + return raw as T; + } + return fallback; +}; + +const isPathInsideDirectory = (candidatePath: string, rootDir: string): boolean => { + const relativePath = path.relative(rootDir, candidatePath); + return ( + relativePath === "" || + (!relativePath.startsWith("..") && !path.isAbsolute(relativePath)) + ); +}; + +const isSafeMulterTempFilename = (value: string): boolean => + /^[a-f0-9]{32}$/.test(value); + +export const resolveSafeUploadedFilePath = async ( + fileMeta: { filename?: unknown }, + uploadRoot: string +): Promise => { + const absoluteUploadRoot = path.resolve(uploadRoot); + let canonicalUploadRoot = absoluteUploadRoot; + + try { + canonicalUploadRoot = await fsPromises.realpath(absoluteUploadRoot); + } catch { + throw new ImportValidationError("Invalid upload path"); + } + + const filename = typeof fileMeta.filename === "string" ? fileMeta.filename : ""; + if (!isSafeMulterTempFilename(filename)) { + throw new ImportValidationError("Invalid upload path"); + } + + const joinedPath = path.resolve(canonicalUploadRoot, filename); + if (!isPathInsideDirectory(joinedPath, canonicalUploadRoot)) { + throw new ImportValidationError("Invalid upload path"); + } + + return joinedPath; +}; + +export const openReadonlySqliteDb = (filePath: string): any => { + try { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const { DatabaseSync } = require("node:sqlite") as any; + return new DatabaseSync(filePath, { + readOnly: true, + enableForeignKeyConstraints: false, + }); + } catch { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const Database = require("better-sqlite3") as any; + return new Database(filePath, { readonly: true, fileMustExist: true }); + } +}; + +export const getCurrentLatestPrismaMigrationName = async ( + backendRoot: string +): Promise => { + try { + const migrationsDir = path.resolve(backendRoot, "prisma/migrations"); + const entries = await fsPromises.readdir(migrationsDir, { withFileTypes: true }); + const dirs = entries + .filter((e) => e.isDirectory()) + .map((e) => e.name) + .filter((name) => !name.startsWith(".")); + if (dirs.length === 0) return null; + dirs.sort(); + return dirs[dirs.length - 1] || null; + } catch { + return null; + } +}; + +export { sanitizeDrawingData }; diff --git a/backend/src/security/csrfClient.ts b/backend/src/security/csrfClient.ts new file mode 100644 index 0000000..3cc7da5 --- /dev/null +++ b/backend/src/security/csrfClient.ts @@ -0,0 +1,47 @@ +import { Request } from "express"; + +export const CSRF_CLIENT_COOKIE_NAME = "excalidash-csrf-client"; + +export const parseCookies = (cookieHeader: string | undefined): Record => { + if (!cookieHeader) return {}; + const cookies: Record = {}; + for (const part of cookieHeader.split(";")) { + const [rawKey, ...rawValueParts] = part.split("="); + const key = rawKey?.trim(); + if (!key) continue; + const rawValue = rawValueParts.join("=").trim(); + try { + cookies[key] = decodeURIComponent(rawValue); + } catch { + cookies[key] = rawValue; + } + } + return cookies; +}; + +export const getCsrfClientCookieValue = (req: Request): string | null => { + const cookies = parseCookies(req.headers.cookie); + const value = cookies[CSRF_CLIENT_COOKIE_NAME]; + if (!value) return null; + if (!/^[A-Za-z0-9_-]{16,128}$/.test(value)) return null; + return value; +}; + +export const getLegacyClientId = (req: Request): string => { + const ip = req.ip || req.connection.remoteAddress || "unknown"; + const userAgent = req.headers["user-agent"] || "unknown"; + return `${ip}:${userAgent}`.slice(0, 256); +}; + +export const getCsrfValidationClientIds = (req: Request): string[] => { + const candidates: string[] = []; + const cookieValue = getCsrfClientCookieValue(req); + if (cookieValue) { + candidates.push(`cookie:${cookieValue}`); + } + const legacyClientId = getLegacyClientId(req); + if (!candidates.includes(legacyClientId)) { + candidates.push(legacyClientId); + } + return candidates; +}; diff --git a/backend/src/server/csrf.ts b/backend/src/server/csrf.ts new file mode 100644 index 0000000..9b09360 --- /dev/null +++ b/backend/src/server/csrf.ts @@ -0,0 +1,201 @@ +import express from "express"; +import crypto from "crypto"; +import { + createCsrfToken, + getCsrfTokenHeader, + getOriginFromReferer, + validateCsrfToken, +} from "../security"; +import { + CSRF_CLIENT_COOKIE_NAME, + getCsrfClientCookieValue, + getCsrfValidationClientIds, + getLegacyClientId, +} from "../security/csrfClient"; + +const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days +const CSRF_RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute + +type RegisterCsrfProtectionDeps = { + app: express.Express; + isAllowedOrigin: (origin?: string) => boolean; + maxRequestsPerWindow: number; + enableDebugLogging?: boolean; +}; + +export const registerCsrfProtection = ({ + app, + isAllowedOrigin, + maxRequestsPerWindow, + enableDebugLogging, +}: RegisterCsrfProtectionDeps) => { + const requestUsesHttps = (req: express.Request): boolean => { + if (req.secure) return true; + const forwardedProto = req.headers["x-forwarded-proto"]; + const raw = Array.isArray(forwardedProto) ? forwardedProto[0] : forwardedProto; + const firstHop = String(raw || "") + .split(",")[0] + .trim() + .toLowerCase(); + return firstHop === "https"; + }; + + const setCsrfClientCookie = (req: express.Request, res: express.Response, value: string): void => { + const secure = requestUsesHttps(req) ? "; Secure" : ""; + res.append( + "Set-Cookie", + `${CSRF_CLIENT_COOKIE_NAME}=${encodeURIComponent( + value + )}; Path=/; HttpOnly; SameSite=Lax; Max-Age=${CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS}${secure}` + ); + }; + + const getClientIdForTokenIssue = ( + req: express.Request, + res: express.Response + ): { clientId: string; strategy: "cookie" | "legacy-bootstrap" } => { + const existingCookieValue = getCsrfClientCookieValue(req); + if (existingCookieValue) { + return { + clientId: `cookie:${existingCookieValue}`, + strategy: "cookie", + }; + } + + const generatedCookieValue = crypto.randomUUID().replace(/-/g, ""); + setCsrfClientCookie(req, res, generatedCookieValue); + return { + clientId: getLegacyClientId(req), + strategy: "legacy-bootstrap", + }; + }; + + const getClientIdForTokenIssueDebug = ( + req: express.Request, + res: express.Response + ): string => { + const { clientId, strategy } = getClientIdForTokenIssue(req, res); + + if (enableDebugLogging) { + const validationCandidates = getCsrfValidationClientIds(req); + const ip = req.ip || req.connection.remoteAddress || "unknown"; + console.log("[CSRF DEBUG] getClientId", { + method: req.method, + path: req.path, + ip, + remoteAddress: req.connection.remoteAddress, + "x-forwarded-for": req.headers["x-forwarded-for"], + "x-real-ip": req.headers["x-real-ip"], + hasCsrfCookie: Boolean(getCsrfClientCookieValue(req)), + clientIdPreview: clientId.slice(0, 60) + "...", + trustProxySetting: req.app.get("trust proxy"), + strategy, + validationCandidatesPreview: validationCandidates.map((candidate) => + `${candidate.slice(0, 60)}...` + ), + }); + } + + return clientId; + }; + + const csrfRateLimit = new Map(); + let csrfCleanupCounter = 0; + + app.get("/csrf-token", (req, res) => { + const ip = req.ip || req.connection.remoteAddress || "unknown"; + const now = Date.now(); + const clientLimit = csrfRateLimit.get(ip); + + if (clientLimit && now < clientLimit.resetTime) { + if (clientLimit.count >= maxRequestsPerWindow) { + return res.status(429).json({ + error: "Rate limit exceeded", + message: "Too many CSRF token requests", + }); + } + clientLimit.count++; + } else { + csrfRateLimit.set(ip, { count: 1, resetTime: now + CSRF_RATE_LIMIT_WINDOW }); + } + + csrfCleanupCounter += 1; + if (csrfCleanupCounter % 100 === 0) { + for (const [key, data] of csrfRateLimit.entries()) { + if (now > data.resetTime) csrfRateLimit.delete(key); + } + } + + const clientId = getClientIdForTokenIssueDebug(req, res); + const token = createCsrfToken(clientId); + + res.json({ + token, + header: getCsrfTokenHeader(), + }); + }); + + const csrfProtectionMiddleware = ( + req: express.Request, + res: express.Response, + next: express.NextFunction + ) => { + const safeMethods = ["GET", "HEAD", "OPTIONS"]; + if (safeMethods.includes(req.method)) { + return next(); + } + + const origin = req.headers["origin"]; + const referer = req.headers["referer"]; + const originValue = Array.isArray(origin) ? origin[0] : origin; + const refererValue = Array.isArray(referer) ? referer[0] : referer; + + if (originValue) { + if (!isAllowedOrigin(originValue)) { + return res.status(403).json({ + error: "CSRF origin mismatch", + message: "Origin not allowed", + }); + } + } else if (refererValue) { + const refererOrigin = getOriginFromReferer(refererValue); + if (!refererOrigin || !isAllowedOrigin(refererOrigin)) { + return res.status(403).json({ + error: "CSRF referer mismatch", + message: "Referer not allowed", + }); + } + } + + const clientIdCandidates = getCsrfValidationClientIds(req); + const headerName = getCsrfTokenHeader(); + const tokenHeader = req.headers[headerName]; + const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader; + + if (!token) { + return res.status(403).json({ + error: "CSRF token missing", + message: `Missing ${headerName} header`, + }); + } + + const isValidToken = clientIdCandidates.some((clientId) => + validateCsrfToken(clientId, token) + ); + if (!isValidToken) { + return res.status(403).json({ + error: "CSRF token invalid", + message: "Invalid or expired CSRF token. Please refresh and try again.", + }); + } + + next(); + }; + + app.use((req, res, next) => { + if (req.path.startsWith("/auth/")) { + return next(); + } + csrfProtectionMiddleware(req, res, next); + }); +}; diff --git a/backend/src/server/drawingsCache.ts b/backend/src/server/drawingsCache.ts new file mode 100644 index 0000000..98582d5 --- /dev/null +++ b/backend/src/server/drawingsCache.ts @@ -0,0 +1,63 @@ +type DrawingsCacheEntry = { body: Buffer; expiresAt: number }; + +export type DrawingsCacheKeyParts = { + userId: string; + searchTerm: string; + collectionFilter: string; + includeData: boolean; + sortField: "name" | "createdAt" | "updatedAt"; + sortDirection: "asc" | "desc"; +}; + +export const createDrawingsCacheStore = (ttlMs: number) => { + const drawingsCache = new Map(); + + const buildDrawingsCacheKey = (keyParts: DrawingsCacheKeyParts) => + JSON.stringify([ + keyParts.userId, + keyParts.searchTerm, + keyParts.collectionFilter, + keyParts.includeData ? "full" : "summary", + keyParts.sortField, + keyParts.sortDirection, + ]); + + const getCachedDrawingsBody = (key: string): Buffer | null => { + const entry = drawingsCache.get(key); + if (!entry) return null; + if (Date.now() > entry.expiresAt) { + drawingsCache.delete(key); + return null; + } + return entry.body; + }; + + const cacheDrawingsResponse = (key: string, payload: unknown): Buffer => { + const body = Buffer.from(JSON.stringify(payload)); + drawingsCache.set(key, { + body, + expiresAt: Date.now() + ttlMs, + }); + return body; + }; + + const invalidateDrawingsCache = () => { + drawingsCache.clear(); + }; + + setInterval(() => { + const now = Date.now(); + for (const [key, entry] of drawingsCache.entries()) { + if (now > entry.expiresAt) { + drawingsCache.delete(key); + } + } + }, 60_000).unref(); + + return { + buildDrawingsCacheKey, + getCachedDrawingsBody, + cacheDrawingsResponse, + invalidateDrawingsCache, + }; +}; diff --git a/backend/src/server/socket.ts b/backend/src/server/socket.ts new file mode 100644 index 0000000..42f2d87 --- /dev/null +++ b/backend/src/server/socket.ts @@ -0,0 +1,221 @@ +import jwt from "jsonwebtoken"; +import { Server } from "socket.io"; +import { PrismaClient } from "../generated/client"; +import { AuthModeService } from "../auth/authMode"; + +interface User { + id: string; + name: string; + initials: string; + color: string; + socketId: string; + isActive: boolean; +} + +type RegisterSocketHandlersDeps = { + io: Server; + prisma: PrismaClient; + authModeService: AuthModeService; + jwtSecret: string; +}; + +export const registerSocketHandlers = ({ + io, + prisma, + authModeService, + jwtSecret, +}: RegisterSocketHandlersDeps) => { + const roomUsers = new Map(); + const socketUserMap = new Map(); + + const toPresenceName = (value: unknown): string => { + if (typeof value !== "string") return "User"; + const trimmed = value.trim().slice(0, 120); + return trimmed.length > 0 ? trimmed : "User"; + }; + + const toPresenceInitials = (name: string): string => { + const words = name + .split(/\s+/) + .map((part) => part.trim()) + .filter((part) => part.length > 0); + if (words.length === 0) return "U"; + const first = words[0]?.[0] ?? ""; + const second = words.length > 1 ? words[1]?.[0] ?? "" : ""; + const initials = `${first}${second}`.toUpperCase().slice(0, 2); + return initials.length > 0 ? initials : "U"; + }; + + const toPresenceColor = (value: unknown): string => { + if (typeof value !== "string") return "#4f46e5"; + const trimmed = value.trim(); + if (/^#[0-9a-fA-F]{3,8}$/.test(trimmed)) { + return trimmed; + } + return "#4f46e5"; + }; + + const getSocketAuthUserId = async (token?: string): Promise => { + const authEnabled = await authModeService.getAuthEnabled(); + if (!authEnabled) { + return "bootstrap-admin"; + } + + if (!token) return null; + + try { + const decoded = jwt.verify(token, jwtSecret) as Record; + if ( + typeof decoded.userId !== "string" || + typeof decoded.email !== "string" || + decoded.type !== "access" + ) { + return null; + } + + const user = await prisma.user.findUnique({ + where: { id: decoded.userId }, + select: { id: true, isActive: true }, + }); + + if (!user || !user.isActive) return null; + return user.id; + } catch { + return null; + } + }; + + io.use(async (socket, next) => { + try { + const token = socket.handshake.auth?.token as string | undefined; + const userId = await getSocketAuthUserId(token); + + if (!userId) { + return next(new Error("Authentication required")); + } + + socketUserMap.set(socket.id, userId); + next(); + } catch { + next(new Error("Authentication failed")); + } + }); + + io.on("connection", (socket) => { + const authenticatedUserId = socketUserMap.get(socket.id); + const authorizedDrawingIds = new Set(); + + socket.on( + "join-room", + async ({ + drawingId, + user, + }: { + drawingId: string; + user: Omit; + }) => { + try { + if (authenticatedUserId) { + const drawing = await prisma.drawing.findFirst({ + where: { id: drawingId, userId: authenticatedUserId }, + select: { id: true }, + }); + + if (!drawing) { + socket.emit("error", { message: "You do not have access to this drawing" }); + return; + } + } + + const roomId = `drawing_${drawingId}`; + socket.join(roomId); + authorizedDrawingIds.add(drawingId); + + let trustedUserId = + typeof user?.id === "string" && user.id.trim().length > 0 + ? user.id.trim().slice(0, 200) + : socket.id; + let trustedName = toPresenceName(user?.name); + + if (authenticatedUserId && authenticatedUserId !== "bootstrap-admin") { + const account = await prisma.user.findUnique({ + where: { id: authenticatedUserId }, + select: { id: true, name: true }, + }); + if (account) { + trustedUserId = account.id; + trustedName = toPresenceName(account.name); + } + } + + const newUser: User = { + id: trustedUserId, + name: trustedName, + initials: toPresenceInitials(trustedName), + color: toPresenceColor(user?.color), + socketId: socket.id, + isActive: true, + }; + + const currentUsers = roomUsers.get(roomId) || []; + const filteredUsers = currentUsers.filter((u) => u.id !== newUser.id); + filteredUsers.push(newUser); + roomUsers.set(roomId, filteredUsers); + + io.to(roomId).emit("presence-update", filteredUsers); + } catch (err) { + console.error("Error in join-room handler:", err); + socket.emit("error", { message: "Failed to join room" }); + } + } + ); + + socket.on("cursor-move", (data) => { + const drawingId = typeof data?.drawingId === "string" ? data.drawingId : null; + if (!drawingId || !authorizedDrawingIds.has(drawingId)) { + return; + } + const roomId = `drawing_${drawingId}`; + socket.volatile.to(roomId).emit("cursor-move", data); + }); + + socket.on("element-update", (data) => { + const drawingId = typeof data?.drawingId === "string" ? data.drawingId : null; + if (!drawingId || !authorizedDrawingIds.has(drawingId)) { + return; + } + const roomId = `drawing_${drawingId}`; + socket.to(roomId).emit("element-update", data); + }); + + socket.on( + "user-activity", + ({ drawingId, isActive }: { drawingId: string; isActive: boolean }) => { + if (!authorizedDrawingIds.has(drawingId)) { + return; + } + const roomId = `drawing_${drawingId}`; + const users = roomUsers.get(roomId); + if (users) { + const user = users.find((u) => u.socketId === socket.id); + if (user) { + user.isActive = isActive; + io.to(roomId).emit("presence-update", users); + } + } + } + ); + + socket.on("disconnect", () => { + socketUserMap.delete(socket.id); + roomUsers.forEach((users, roomId) => { + const index = users.findIndex((u) => u.socketId === socket.id); + if (index !== -1) { + users.splice(index, 1); + roomUsers.set(roomId, users); + io.to(roomId).emit("presence-update", users); + } + }); + }); + }); +}; diff --git a/backend/src/utils/__tests__/audit.test.ts b/backend/src/utils/__tests__/audit.test.ts index 7d20f3a..25f9cb4 100644 --- a/backend/src/utils/__tests__/audit.test.ts +++ b/backend/src/utils/__tests__/audit.test.ts @@ -7,7 +7,12 @@ import { describe, it, expect, beforeAll, afterAll, beforeEach } from "vitest"; import { getTestPrisma, setupTestDb, initTestDb, createTestUser } from "../../__tests__/testUtils"; -import { logAuditEvent, getAuditLogs, type AuditLogData } from "../audit"; +import { + logAuditEvent, + getAuditLogs, + setAuditPrismaProvider, + type AuditLogData, +} from "../audit"; describe("Audit Logging", () => { const prisma = getTestPrisma(); @@ -16,11 +21,13 @@ describe("Audit Logging", () => { beforeAll(async () => { setupTestDb(); testUser = await initTestDb(prisma); + setAuditPrismaProvider(() => prisma); // Enable audit logging for tests process.env.ENABLE_AUDIT_LOGGING = "true"; }); afterAll(async () => { + setAuditPrismaProvider(null); await prisma.$disconnect(); delete process.env.ENABLE_AUDIT_LOGGING; }); diff --git a/backend/src/utils/audit.ts b/backend/src/utils/audit.ts index 888af2b..ae52228 100644 --- a/backend/src/utils/audit.ts +++ b/backend/src/utils/audit.ts @@ -1,13 +1,12 @@ /** * Audit logging utility for security events */ -import { PrismaClient } from "../generated/client"; +import { prisma } from "../db/prisma"; -let prisma: PrismaClient | null = null; -const getPrisma = () => { - if (prisma) return prisma; - prisma = new PrismaClient(); - return prisma; +let prismaProvider: () => typeof prisma = () => prisma; + +export const setAuditPrismaProvider = (provider: (() => typeof prisma) | null): void => { + prismaProvider = provider ?? (() => prisma); }; export interface AuditLogData { @@ -44,7 +43,7 @@ export const logAuditEvent = async (data: AuditLogData): Promise => { return; // Feature disabled, silently skip } - await getPrisma().auditLog.create({ + await prismaProvider().auditLog.create({ data: { userId: data.userId || null, action: data.action, @@ -79,7 +78,7 @@ export const getAuditLogs = async ( return []; // Feature disabled, return empty array } - const logs = await getPrisma().auditLog.findMany({ + const logs = await prismaProvider().auditLog.findMany({ where: userId ? { userId } : undefined, orderBy: { createdAt: "desc" }, take: limit, diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 647dba8..720a280 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -73,6 +73,84 @@ export const clearCsrfToken = (): void => { csrfToken = null; }; +export interface AuthStatusResponse { + authEnabled?: boolean; + enabled?: boolean; + bootstrapRequired?: boolean; +} + +export interface AuthUser { + id: string; + username?: string | null; + email: string; + name: string; + role?: string; + mustResetPassword?: boolean; +} + +export const authStatus = async (): Promise => { + const response = await axios.get( + `${API_URL}/auth/status`, + { withCredentials: true } + ); + return response.data; +}; + +export const authMe = async (accessToken: string): Promise<{ user: AuthUser }> => { + const response = await axios.get<{ user: AuthUser }>(`${API_URL}/auth/me`, { + headers: { Authorization: `Bearer ${accessToken}` }, + withCredentials: true, + }); + return response.data; +}; + +export const authRefresh = async ( + refreshToken: string +): Promise<{ accessToken: string; refreshToken?: string }> => { + const response = await axios.post<{ accessToken: string; refreshToken?: string }>( + `${API_URL}/auth/refresh`, + { refreshToken }, + { withCredentials: true } + ); + return response.data; +}; + +export const authLogin = async ( + email: string, + password: string +): Promise<{ user: AuthUser; accessToken: string; refreshToken: string }> => { + const response = await axios.post<{ user: AuthUser; accessToken: string; refreshToken: string }>( + `${API_URL}/auth/login`, + { email, password }, + { withCredentials: true } + ); + return response.data; +}; + +export const authRegister = async ( + email: string, + password: string, + name: string +): Promise<{ user: AuthUser; accessToken: string; refreshToken: string }> => { + const response = await axios.post<{ user: AuthUser; accessToken: string; refreshToken: string }>( + `${API_URL}/auth/register`, + { email, password, name }, + { withCredentials: true } + ); + return response.data; +}; + +export const authPasswordResetConfirm = async ( + token: string, + password: string +): Promise => { + await axios.post( + `${API_URL}/auth/password-reset-confirm`, + { token, password }, + { withCredentials: true } + ); +}; + const clearStoredAuth = () => { localStorage.removeItem(TOKEN_KEY); localStorage.removeItem(REFRESH_TOKEN_KEY); @@ -100,15 +178,12 @@ const getAuthEnabledStatus = async (): Promise => { } try { - const response = await axios.get<{ authEnabled?: boolean; enabled?: boolean }>( - `${API_URL}/auth/status`, - { withCredentials: true } - ); + const response = await authStatus(); const enabled = - typeof response.data?.authEnabled === "boolean" - ? response.data.authEnabled - : typeof response.data?.enabled === "boolean" - ? response.data.enabled + typeof response?.authEnabled === "boolean" + ? response.authEnabled + : typeof response?.enabled === "boolean" + ? response.enabled : true; cacheAuthEnabled(enabled); return enabled; @@ -135,22 +210,16 @@ const refreshAccessToken = async (): Promise => { throw new Error("Missing refresh token"); } - const refreshResponse = await axios.post( - `${API_URL}/auth/refresh`, - { - refreshToken, - }, - { withCredentials: true } - ); + const refreshResponse = await authRefresh(refreshToken); - const nextAccessToken = String(refreshResponse.data.accessToken || ""); + const nextAccessToken = String(refreshResponse.accessToken || ""); if (!nextAccessToken) { throw new Error("Missing access token in refresh response"); } localStorage.setItem(TOKEN_KEY, nextAccessToken); - if (refreshResponse.data.refreshToken) { - localStorage.setItem(REFRESH_TOKEN_KEY, refreshResponse.data.refreshToken); + if (refreshResponse.refreshToken) { + localStorage.setItem(REFRESH_TOKEN_KEY, refreshResponse.refreshToken); } return nextAccessToken; diff --git a/frontend/src/context/AuthContext.tsx b/frontend/src/context/AuthContext.tsx index 9602002..a484642 100644 --- a/frontend/src/context/AuthContext.tsx +++ b/frontend/src/context/AuthContext.tsx @@ -1,9 +1,14 @@ import React, { createContext, useContext, useState, useEffect } from 'react'; import type { ReactNode } from 'react'; import { useNavigate } from 'react-router-dom'; -import axios from 'axios'; - -const API_URL = import.meta.env.VITE_API_URL || "/api"; +import { + authStatus, + authMe, + authRefresh, + authLogin, + authRegister, + isAxiosError, +} from '../api'; interface User { id: string; @@ -39,24 +44,21 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => const [bootstrapRequired, setBootstrapRequired] = useState(false); const navigate = useNavigate(); - // Load user from localStorage on mount useEffect(() => { const loadUser = async () => { try { - // Determine auth mode first (single-user mode vs multi-user auth). try { - const statusResponse = await axios.get(`${API_URL}/auth/status`); + const statusResponse = await authStatus(); const enabled = - typeof statusResponse.data?.authEnabled === "boolean" - ? statusResponse.data.authEnabled - : typeof statusResponse.data?.enabled === "boolean" - ? statusResponse.data.enabled + typeof statusResponse?.authEnabled === "boolean" + ? statusResponse.authEnabled + : typeof statusResponse?.enabled === "boolean" + ? statusResponse.enabled : true; setAuthEnabled(enabled); localStorage.setItem(AUTH_ENABLED_CACHE_KEY, String(enabled)); - setBootstrapRequired(Boolean(statusResponse.data?.bootstrapRequired)); + setBootstrapRequired(Boolean(statusResponse?.bootstrapRequired)); - // In single-user mode, do not require login. if (!enabled) { localStorage.removeItem(TOKEN_KEY); localStorage.removeItem(REFRESH_TOKEN_KEY); @@ -75,7 +77,6 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => setUser(null); return; } - // If status fails and no cached mode exists, default to auth-enabled mode. setAuthEnabled(true); setBootstrapRequired(false); } @@ -86,39 +87,28 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => if (storedUser && storedToken) { const userData = JSON.parse(storedUser); setUser(userData); - - // Verify token is still valid by fetching user info + try { - const response = await axios.get(`${API_URL}/auth/me`, { - headers: { - Authorization: `Bearer ${storedToken}`, - }, - }); - setUser(response.data.user); - } catch (error) { - // Token invalid, try refresh + const response = await authMe(storedToken); + setUser(response.user); + } catch { const refreshToken = localStorage.getItem(REFRESH_TOKEN_KEY); if (refreshToken) { try { - const refreshResponse = await axios.post(`${API_URL}/auth/refresh`, { - refreshToken, - }); - localStorage.setItem(TOKEN_KEY, refreshResponse.data.accessToken); - const userResponse = await axios.get(`${API_URL}/auth/me`, { - headers: { - Authorization: `Bearer ${refreshResponse.data.accessToken}`, - }, - }); - setUser(userResponse.data.user); + const refreshResponse = await authRefresh(refreshToken); + localStorage.setItem(TOKEN_KEY, refreshResponse.accessToken); + if (refreshResponse.refreshToken) { + localStorage.setItem(REFRESH_TOKEN_KEY, refreshResponse.refreshToken); + } + const userResponse = await authMe(refreshResponse.accessToken); + setUser(userResponse.user); } catch { - // Refresh failed, clear auth but don't navigate during initial load localStorage.removeItem(TOKEN_KEY); localStorage.removeItem(REFRESH_TOKEN_KEY); localStorage.removeItem(USER_KEY); setUser(null); } } else { - // No refresh token, clear auth localStorage.removeItem(TOKEN_KEY); localStorage.removeItem(REFRESH_TOKEN_KEY); localStorage.removeItem(USER_KEY); @@ -128,7 +118,6 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => } } catch (error) { console.error('Failed to load user:', error); - // Clear auth on error localStorage.removeItem(TOKEN_KEY); localStorage.removeItem(REFRESH_TOKEN_KEY); localStorage.removeItem(USER_KEY); @@ -146,12 +135,9 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => if (authEnabled === false) { throw new Error("Authentication is disabled"); } - const response = await axios.post(`${API_URL}/auth/login`, { - email, - password, - }); + const response = await authLogin(email, password); - const { user: userData, accessToken, refreshToken } = response.data; + const { user: userData, accessToken, refreshToken } = response; localStorage.setItem(TOKEN_KEY, accessToken); localStorage.setItem(REFRESH_TOKEN_KEY, refreshToken); @@ -159,8 +145,8 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => setUser(userData); } catch (error: unknown) { - if (axios.isAxiosError(error)) { - const message = + if (isAxiosError(error)) { + const message = typeof error.response?.data === 'object' && error.response.data !== null && 'message' in error.response.data && @@ -178,13 +164,9 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => if (authEnabled === false) { throw new Error("Authentication is disabled"); } - const response = await axios.post(`${API_URL}/auth/register`, { - email, - password, - name, - }); + const response = await authRegister(email, password, name); - const { user: userData, accessToken, refreshToken } = response.data; + const { user: userData, accessToken, refreshToken } = response; localStorage.setItem(TOKEN_KEY, accessToken); localStorage.setItem(REFRESH_TOKEN_KEY, refreshToken); @@ -192,8 +174,8 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => setUser(userData); } catch (error: unknown) { - if (axios.isAxiosError(error)) { - const message = + if (isAxiosError(error)) { + const message = typeof error.response?.data === 'object' && error.response.data !== null && 'message' in error.response.data && @@ -211,7 +193,6 @@ export const AuthProvider: React.FC<{ children: ReactNode }> = ({ children }) => localStorage.removeItem(REFRESH_TOKEN_KEY); localStorage.removeItem(USER_KEY); setUser(null); - // Navigate to login - use setTimeout to ensure Router is ready setTimeout(() => { navigate('/login'); }, 0); diff --git a/frontend/src/pages/Dashboard.tsx b/frontend/src/pages/Dashboard.tsx index 74ad12e..a38a840 100644 --- a/frontend/src/pages/Dashboard.tsx +++ b/frontend/src/pages/Dashboard.tsx @@ -4,14 +4,13 @@ import { DrawingCard } from '../components/DrawingCard'; import { Plus, Search, Loader2, Inbox, Trash2, Folder, ArrowRight, Copy, Upload, CheckSquare, Square, ArrowUp, ArrowDown, ChevronDown, FileText, Calendar, Clock } from 'lucide-react'; import { useNavigate, useSearchParams, useLocation } from 'react-router-dom'; import * as api from '../api'; -import type { DrawingSummary, Collection } from '../types'; import type { DrawingSortField, SortDirection } from '../api'; import { useDebounce } from '../hooks/useDebounce'; import clsx from 'clsx'; import { ConfirmModal } from '../components/ConfirmModal'; import { useUpload } from '../context/UploadContext'; import { DragOverlayPortal, getSelectionBounds, type Point, type SelectionBounds } from './dashboard/shared'; -import { isLatestRequest, mergeUniqueDrawings } from './dashboard/pagination'; +import { useDashboardData } from './dashboard/useDashboardData'; const PAGE_SIZE = 24; @@ -19,10 +18,6 @@ export const Dashboard: React.FC = () => { const [searchParams] = useSearchParams(); const location = useLocation(); const navigate = useNavigate(); - const [drawings, setDrawings] = useState([]); - const [collections, setCollections] = useState([]); - const [totalCount, setTotalCount] = useState(0); - const [isFetchingMore, setIsFetchingMore] = useState(false); const selectedCollectionId = React.useMemo(() => { if (location.pathname === '/') return undefined; @@ -73,73 +68,29 @@ export const Dashboard: React.FC = () => { direction: 'desc' }); - const [isLoading, setIsLoading] = useState(false); - const listRequestVersionRef = useRef(0); - const { uploadFiles } = useUpload(); - - const hasMore = drawings.length < totalCount; - - const refreshData = useCallback(async () => { - const requestVersion = ++listRequestVersionRef.current; - setIsLoading(true); - try { - const [drawingsRes, collectionsData] = await Promise.all([ - api.getDrawings(debouncedSearch, selectedCollectionId, { - limit: PAGE_SIZE, - offset: 0, - sortField: sortConfig.field, - sortDirection: sortConfig.direction, - }), - api.getCollections() - ]); - if (!isLatestRequest(requestVersion, listRequestVersionRef.current)) return; - setDrawings(drawingsRes.drawings); - setTotalCount(drawingsRes.totalCount); - setCollections(collectionsData); - setSelectedIds(new Set()); - } catch (err) { - console.error('Failed to fetch data:', err); - } finally { - if (isLatestRequest(requestVersion, listRequestVersionRef.current)) { - setIsLoading(false); - } - } - }, [debouncedSearch, selectedCollectionId, sortConfig.field, sortConfig.direction]); - - const fetchMore = useCallback(async () => { - if (isFetchingMore || !hasMore || isLoading) return; - const requestVersion = listRequestVersionRef.current; - setIsFetchingMore(true); - try { - const drawingsRes = await api.getDrawings(debouncedSearch, selectedCollectionId, { - limit: PAGE_SIZE, - offset: drawings.length, - sortField: sortConfig.field, - sortDirection: sortConfig.direction, - }); - if (!isLatestRequest(requestVersion, listRequestVersionRef.current)) return; - setDrawings(prev => mergeUniqueDrawings(prev, drawingsRes.drawings)); - setTotalCount(drawingsRes.totalCount); - } catch (err) { - console.error('Failed to fetch more data:', err); - } finally { - setIsFetchingMore(false); - } - }, [ + const resetSelection = useCallback(() => { + setSelectedIds(new Set()); + }, []); + const { + drawings, + setDrawings, + collections, + setCollections, + setTotalCount, isFetchingMore, - hasMore, isLoading, + hasMore, + refreshData, + fetchMore, + } = useDashboardData({ debouncedSearch, selectedCollectionId, - drawings.length, - sortConfig.field, - sortConfig.direction, - ]); - - useEffect(() => { - refreshData(); - }, [refreshData]); + sortField: sortConfig.field, + sortDirection: sortConfig.direction, + pageSize: PAGE_SIZE, + onRefreshSuccess: resetSelection, + }); // Infinite scroll observer useEffect(() => { diff --git a/frontend/src/pages/Editor.tsx b/frontend/src/pages/Editor.tsx index 4b388ca..f691bb5 100644 --- a/frontend/src/pages/Editor.tsx +++ b/frontend/src/pages/Editor.tsx @@ -7,7 +7,7 @@ import debounce from 'lodash/debounce'; import throttle from 'lodash/throttle'; import { Toaster, toast } from 'sonner'; import { io, Socket } from 'socket.io-client'; -import { getUserIdentity, type UserIdentity } from '../utils/identity'; +import type { UserIdentity } from '../utils/identity'; import { useAuth } from '../context/AuthContext'; import { reconcileElements } from '../utils/sync'; import { exportFromEditor } from '../utils/exportUtils'; @@ -15,9 +15,7 @@ import * as api from '../api'; import { useTheme } from '../context/ThemeContext'; import { UIOptions, - getColorFromString, getFilesDelta, - getInitialsFromName, hasRenderableElements, haveSameElements, isSuspiciousEmptySnapshot, @@ -25,6 +23,8 @@ import { isStaleNonRenderableSnapshot, } from './editor/shared'; import type { ElementVersionInfo } from './editor/shared'; +import { useEditorChrome } from './editor/useEditorChrome'; +import { useEditorIdentity } from './editor/useEditorIdentity'; interface Peer extends UserIdentity { isActive: boolean; @@ -49,75 +49,13 @@ export const Editor: React.FC = () => { const [isSceneLoading, setIsSceneLoading] = useState(true); const [loadError, setLoadError] = useState(null); const [isSavingOnLeave, setIsSavingOnLeave] = useState(false); - const [isHeaderVisible, setIsHeaderVisible] = useState(true); const [autoHideEnabled, setAutoHideEnabled] = useState(true); - - useEffect(() => { - document.title = `${drawingName} - ExcaliDash`; - return () => { - document.title = 'ExcaliDash'; - }; - }, [drawingName]); - - // Auto-hide header based on mouse movement - useEffect(() => { - if (!autoHideEnabled || isRenaming) { - setIsHeaderVisible(true); - return; - } - - let hideTimeout: ReturnType | null = null; - let isInTriggerZone = false; - - const handleMouseMove = throttle((e: MouseEvent) => { - const wasInTriggerZone = isInTriggerZone; - isInTriggerZone = e.clientY < 5; - - if (isInTriggerZone) { - // Mouse is in trigger zone - show header - setIsHeaderVisible(true); - if (hideTimeout !== null) { - clearTimeout(hideTimeout); - hideTimeout = null; - } - } else if (wasInTriggerZone) { - // Mouse just left trigger zone - start hide timer - if (hideTimeout !== null) clearTimeout(hideTimeout); - hideTimeout = setTimeout(() => { - setIsHeaderVisible(false); - }, 2000); - } - // If mouse is already out of trigger zone and moving, don't reset timer - }, 100); - - // Show header initially - setIsHeaderVisible(true); - - // Hide after initial delay if mouse doesn't move to top - hideTimeout = setTimeout(() => { - setIsHeaderVisible(false); - }, 3000); - - window.addEventListener('mousemove', handleMouseMove, { passive: true }); - - return () => { - window.removeEventListener('mousemove', handleMouseMove); - if (hideTimeout !== null) clearTimeout(hideTimeout); - }; - }, [autoHideEnabled, isRenaming]); - - // Use authenticated user identity or fallback to generated identity - const [me] = useState(() => { - if (user) { - return { - id: user.id, - name: user.name, - initials: getInitialsFromName(user.name), - color: getColorFromString(user.id), - }; - } - return getUserIdentity(); + const { isHeaderVisible, setIsHeaderVisible } = useEditorChrome({ + drawingName, + autoHideEnabled, + isRenaming, }); + const me: UserIdentity = useEditorIdentity(user); const [peers, setPeers] = useState([]); const [isReady, setIsReady] = useState(false); diff --git a/frontend/src/pages/PasswordResetConfirm.tsx b/frontend/src/pages/PasswordResetConfirm.tsx index 46ac145..379a075 100644 --- a/frontend/src/pages/PasswordResetConfirm.tsx +++ b/frontend/src/pages/PasswordResetConfirm.tsx @@ -1,9 +1,7 @@ import React, { useState, useEffect } from 'react'; import { useSearchParams, useNavigate, Link } from 'react-router-dom'; -import axios from 'axios'; import { Logo } from '../components/Logo'; - -const API_URL = import.meta.env.VITE_API_URL || "/api"; +import { authPasswordResetConfirm, isAxiosError } from '../api'; export const PasswordResetConfirm: React.FC = () => { const [searchParams] = useSearchParams(); @@ -44,17 +42,14 @@ export const PasswordResetConfirm: React.FC = () => { setLoading(true); try { - await axios.post(`${API_URL}/auth/password-reset-confirm`, { - token, - password, - }); + await authPasswordResetConfirm(token, password); setSuccess(true); setTimeout(() => { navigate('/login'); }, 3000); } catch (err: unknown) { let message = 'Failed to reset password'; - if (axios.isAxiosError(err)) { + if (isAxiosError(err)) { if (err.response?.status === 404) { message = 'Password reset feature is not enabled on this server'; } else if (err.response?.data?.message) { diff --git a/frontend/src/pages/dashboard/useDashboardData.ts b/frontend/src/pages/dashboard/useDashboardData.ts new file mode 100644 index 0000000..fd81a91 --- /dev/null +++ b/frontend/src/pages/dashboard/useDashboardData.ts @@ -0,0 +1,117 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import * as api from '../../api'; +import type { DrawingSortField, SortDirection } from '../../api'; +import type { Collection, DrawingSummary } from '../../types'; +import { isLatestRequest, mergeUniqueDrawings } from './pagination'; + +type SelectedCollectionId = string | null | undefined; + +type UseDashboardDataOptions = { + debouncedSearch: string; + selectedCollectionId: SelectedCollectionId; + sortField: DrawingSortField; + sortDirection: SortDirection; + pageSize: number; + onRefreshSuccess?: () => void; +}; + +export const useDashboardData = ({ + debouncedSearch, + selectedCollectionId, + sortField, + sortDirection, + pageSize, + onRefreshSuccess, +}: UseDashboardDataOptions) => { + const [drawings, setDrawings] = useState([]); + const [collections, setCollections] = useState([]); + const [totalCount, setTotalCount] = useState(0); + const [isFetchingMore, setIsFetchingMore] = useState(false); + const [isLoading, setIsLoading] = useState(false); + const listRequestVersionRef = useRef(0); + + const hasMore = drawings.length < totalCount; + + const refreshData = useCallback(async () => { + const requestVersion = ++listRequestVersionRef.current; + setIsLoading(true); + try { + const [drawingsRes, collectionsData] = await Promise.all([ + api.getDrawings(debouncedSearch, selectedCollectionId, { + limit: pageSize, + offset: 0, + sortField, + sortDirection, + }), + api.getCollections(), + ]); + if (!isLatestRequest(requestVersion, listRequestVersionRef.current)) return; + setDrawings(drawingsRes.drawings); + setTotalCount(drawingsRes.totalCount); + setCollections(collectionsData); + onRefreshSuccess?.(); + } catch (err) { + console.error('Failed to fetch data:', err); + } finally { + if (isLatestRequest(requestVersion, listRequestVersionRef.current)) { + setIsLoading(false); + } + } + }, [ + debouncedSearch, + selectedCollectionId, + pageSize, + sortField, + sortDirection, + onRefreshSuccess, + ]); + + const fetchMore = useCallback(async () => { + if (isFetchingMore || !hasMore || isLoading) return; + const requestVersion = listRequestVersionRef.current; + setIsFetchingMore(true); + try { + const drawingsRes = await api.getDrawings(debouncedSearch, selectedCollectionId, { + limit: pageSize, + offset: drawings.length, + sortField, + sortDirection, + }); + if (!isLatestRequest(requestVersion, listRequestVersionRef.current)) return; + setDrawings((prev) => mergeUniqueDrawings(prev, drawingsRes.drawings)); + setTotalCount(drawingsRes.totalCount); + } catch (err) { + console.error('Failed to fetch more data:', err); + } finally { + setIsFetchingMore(false); + } + }, [ + isFetchingMore, + hasMore, + isLoading, + debouncedSearch, + selectedCollectionId, + pageSize, + drawings.length, + sortField, + sortDirection, + ]); + + useEffect(() => { + refreshData(); + }, [refreshData]); + + return { + drawings, + setDrawings, + collections, + setCollections, + totalCount, + setTotalCount, + isFetchingMore, + isLoading, + hasMore, + refreshData, + fetchMore, + }; +}; diff --git a/frontend/src/pages/editor/useEditorChrome.ts b/frontend/src/pages/editor/useEditorChrome.ts new file mode 100644 index 0000000..ffb85d6 --- /dev/null +++ b/frontend/src/pages/editor/useEditorChrome.ts @@ -0,0 +1,68 @@ +import { useEffect, useState } from 'react'; +import throttle from 'lodash/throttle'; + +type UseEditorChromeOptions = { + drawingName: string; + autoHideEnabled: boolean; + isRenaming: boolean; +}; + +export const useEditorChrome = ({ + drawingName, + autoHideEnabled, + isRenaming, +}: UseEditorChromeOptions) => { + const [isHeaderVisible, setIsHeaderVisible] = useState(true); + + useEffect(() => { + document.title = `${drawingName} - ExcaliDash`; + return () => { + document.title = 'ExcaliDash'; + }; + }, [drawingName]); + + useEffect(() => { + if (!autoHideEnabled || isRenaming) { + setIsHeaderVisible(true); + return; + } + + let hideTimeout: ReturnType | null = null; + let isInTriggerZone = false; + + const handleMouseMove = throttle((e: MouseEvent) => { + const wasInTriggerZone = isInTriggerZone; + isInTriggerZone = e.clientY < 5; + + if (isInTriggerZone) { + setIsHeaderVisible(true); + if (hideTimeout !== null) { + clearTimeout(hideTimeout); + hideTimeout = null; + } + } else if (wasInTriggerZone) { + if (hideTimeout !== null) clearTimeout(hideTimeout); + hideTimeout = setTimeout(() => { + setIsHeaderVisible(false); + }, 2000); + } + }, 100); + + setIsHeaderVisible(true); + hideTimeout = setTimeout(() => { + setIsHeaderVisible(false); + }, 3000); + + window.addEventListener('mousemove', handleMouseMove, { passive: true }); + + return () => { + window.removeEventListener('mousemove', handleMouseMove); + if (hideTimeout !== null) clearTimeout(hideTimeout); + }; + }, [autoHideEnabled, isRenaming]); + + return { + isHeaderVisible, + setIsHeaderVisible, + }; +}; diff --git a/frontend/src/pages/editor/useEditorIdentity.ts b/frontend/src/pages/editor/useEditorIdentity.ts new file mode 100644 index 0000000..38f6292 --- /dev/null +++ b/frontend/src/pages/editor/useEditorIdentity.ts @@ -0,0 +1,25 @@ +import { useMemo } from 'react'; +import { getUserIdentity, type UserIdentity } from '../../utils/identity'; +import { + getColorFromString, + getInitialsFromName, +} from './shared'; + +type AuthUser = { + id: string; + name: string; +} | null | undefined; + +export const useEditorIdentity = (user: AuthUser): UserIdentity => { + return useMemo(() => { + if (user) { + return { + id: user.id, + name: user.name, + initials: getInitialsFromName(user.name), + color: getColorFromString(user.id), + }; + } + return getUserIdentity(); + }, [user]); +};