prevent preview updates from overwriting drawings

This commit is contained in:
Zimeng Xiong
2026-02-07 15:51:27 -08:00
parent 02736d663a
commit 2aa749a2f0
27 changed files with 1172 additions and 2759 deletions
+24 -1
View File
@@ -2,6 +2,7 @@
set -e
JWT_SECRET_FILE="/app/prisma/.jwt_secret"
CSRF_SECRET_FILE="/app/prisma/.csrf_secret"
# Ensure JWT secret exists for production startup.
# Backward compatibility: older installs may not have JWT_SECRET configured.
@@ -25,6 +26,27 @@ fi
export JWT_SECRET
# Ensure CSRF secret exists for stable token validation across restarts.
# (Still recommend setting explicitly for multi-instance deployments.)
if [ -z "${CSRF_SECRET:-}" ]; then
echo "CSRF_SECRET not provided, resolving persisted secret..."
if [ -f "${CSRF_SECRET_FILE}" ]; then
CSRF_SECRET="$(tr -d '\r\n' < "${CSRF_SECRET_FILE}")"
fi
if [ -z "${CSRF_SECRET}" ]; then
echo "No persisted CSRF secret found. Generating a new secret..."
CSRF_SECRET="$(openssl rand -base64 32)"
umask 077
printf "%s" "${CSRF_SECRET}" > "${CSRF_SECRET_FILE}"
fi
else
umask 077
printf "%s" "${CSRF_SECRET}" > "${CSRF_SECRET_FILE}"
fi
export CSRF_SECRET
# 1. Hydrate volume if empty (Running as root)
if [ ! -f "/app/prisma/schema.prisma" ]; then
echo "Mount is empty. Hydrating /app/prisma..."
@@ -43,11 +65,12 @@ chown -R nodejs:nodejs /app/uploads
chown -R nodejs:nodejs /app/prisma
chmod 755 /app/uploads
chmod 600 "${JWT_SECRET_FILE}"
chmod 600 "${CSRF_SECRET_FILE}"
# Ensure database file has proper permissions
if [ -f "/app/prisma/dev.db" ]; then
echo "Database file found, ensuring write permissions..."
chmod 666 /app/prisma/dev.db
chmod 600 /app/prisma/dev.db
fi
# 3. Run Migrations (Drop privileges to nodejs)
@@ -0,0 +1,94 @@
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import request from "supertest";
import bcrypt from "bcrypt";
import jwt, { SignOptions } from "jsonwebtoken";
import { StringValue } from "ms";
import { PrismaClient } from "../generated/client";
import { config } from "../config";
import { getTestPrisma, setupTestDb } from "./testUtils";
describe("Auth Enabled Toggle Authorization", () => {
const userAgent = "vitest-auth-enabled";
let prisma: PrismaClient;
let app: any;
let csrfHeaderName: string;
let csrfToken: string;
let regularUserToken: string;
beforeAll(async () => {
setupTestDb();
prisma = getTestPrisma();
({ app } = await import("../index"));
await prisma.systemConfig.upsert({
where: { id: "default" },
update: {
authEnabled: true,
registrationEnabled: false,
},
create: {
id: "default",
authEnabled: true,
registrationEnabled: false,
},
});
const passwordHash = await bcrypt.hash("password123", 10);
const user = await prisma.user.create({
data: {
email: "regular-user@test.local",
passwordHash,
name: "Regular User",
role: "USER",
isActive: true,
},
select: {
id: true,
email: true,
},
});
const signOptions: SignOptions = {
expiresIn: config.jwtAccessExpiresIn as StringValue,
};
regularUserToken = jwt.sign(
{ userId: user.id, email: user.email, type: "access" },
config.jwtSecret,
signOptions
);
const agent = request.agent(app);
const csrfRes = await agent
.get("/csrf-token")
.set("User-Agent", userAgent);
csrfHeaderName = csrfRes.body.header;
csrfToken = csrfRes.body.token;
});
afterAll(async () => {
await prisma.$disconnect();
});
it("rejects unauthenticated auth-enabled toggle when auth is enabled", async () => {
const response = await request(app)
.post("/auth/auth-enabled")
.set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken)
.send({ enabled: false });
expect(response.status).toBe(401);
});
it("rejects non-admin auth-enabled toggle", async () => {
const response = await request(app)
.post("/auth/auth-enabled")
.set("User-Agent", userAgent)
.set("Authorization", `Bearer ${regularUserToken}`)
.set(csrfHeaderName, csrfToken)
.send({ enabled: false });
expect(response.status).toBe(403);
expect(response.body?.message).toContain("Admin access required");
});
});
@@ -0,0 +1,20 @@
import { describe, expect, it } from "vitest";
import { createCsrfToken, validateCsrfToken } from "../security";
describe("CSRF client identity stability", () => {
it("keeps token validation stable when using cookie-based client IDs", () => {
const cookieClientId = "cookie:fixed-client-id";
const token = createCsrfToken(cookieClientId);
expect(validateCsrfToken(cookieClientId, token)).toBe(true);
});
it("shows why legacy IP-based IDs are unstable across proxy hops", () => {
const userAgent = "Mozilla/5.0 test";
const clientIdViaProxyA = `10.0.0.5:${userAgent}`;
const clientIdViaProxyB = `10.0.0.6:${userAgent}`;
const token = createCsrfToken(clientIdViaProxyA);
expect(validateCsrfToken(clientIdViaProxyB, token)).toBe(false);
});
});
@@ -345,7 +345,9 @@ describe("Import compatibility (legacy exports)", () => {
expect.arrayContaining(["legacy-drawing-1", "legacy-drawing-2", "legacy-drawing-trash"])
);
const trash = await prisma.collection.findUnique({ where: { id: "trash" } });
const trash = await prisma.collection.findUnique({
where: { id: "trash:bootstrap-admin" },
});
expect(trash).toBeTruthy();
});
@@ -0,0 +1,55 @@
import { describe, expect, it } from "vitest";
import { sanitizeDrawingUpdateData } from "../index";
describe("sanitizeDrawingUpdateData regression", () => {
it("does not inject empty scene fields for preview-only updates", () => {
const payload: {
preview?: string | null;
elements?: unknown[];
appState?: Record<string, unknown>;
files?: Record<string, unknown>;
} = {
preview: "<svg><rect width=\"10\" height=\"10\"/></svg>",
};
const ok = sanitizeDrawingUpdateData(payload);
expect(ok).toBe(true);
expect(typeof payload.preview).toBe("string");
expect(String(payload.preview)).toContain("<svg");
expect(payload.elements).toBeUndefined();
expect(payload.appState).toBeUndefined();
expect(payload.files).toBeUndefined();
});
it("still sanitizes scene fields when scene data is provided", () => {
const payload: {
preview?: string | null;
elements?: any[];
appState?: Record<string, unknown>;
files?: Record<string, unknown>;
} = {
elements: [
{
id: "el-1",
type: "rectangle",
x: 0,
y: 0,
width: 100,
height: 100,
version: 1,
versionNonce: 1,
isDeleted: false,
},
],
appState: { viewBackgroundColor: "#ffffff" },
files: {},
preview: "<svg/>",
};
const ok = sanitizeDrawingUpdateData(payload);
expect(ok).toBe(true);
expect(Array.isArray(payload.elements)).toBe(true);
expect(typeof payload.appState).toBe("object");
});
});
+6 -7
View File
@@ -98,11 +98,9 @@ export const setupTestDb = () => {
* Clean up the test database between tests
*/
export const cleanupTestDb = async (prisma: PrismaClient) => {
// Delete all drawings and collections (except Trash)
// Delete all drawings and collections.
await prisma.drawing.deleteMany({});
await prisma.collection.deleteMany({
where: { id: { not: "trash" } },
});
await prisma.collection.deleteMany({});
};
/**
@@ -129,14 +127,15 @@ export const createTestUser = async (prisma: PrismaClient, email: string = "test
export const initTestDb = async (prisma: PrismaClient) => {
// Create a test user first
const testUser = await createTestUser(prisma);
const trashCollectionId = `trash:${testUser.id}`;
// Ensure Trash collection exists
const trash = await prisma.collection.findUnique({
where: { id: "trash" },
const trash = await prisma.collection.findFirst({
where: { id: trashCollectionId, userId: testUser.id },
});
if (!trash) {
await prisma.collection.create({
data: { id: "trash", name: "Trash", userId: testUser.id },
data: { id: trashCollectionId, name: "Trash", userId: testUser.id },
});
}
+44 -2
View File
@@ -224,12 +224,52 @@ const requireAdmin = (
return true;
};
const getClientId = (req: Request): string => {
const CSRF_CLIENT_COOKIE_NAME = "excalidash-csrf-client";
const parseCookies = (cookieHeader: string | undefined): Record<string, string> => {
if (!cookieHeader) return {};
const cookies: Record<string, string> = {};
for (const part of cookieHeader.split(";")) {
const [rawKey, ...rawValueParts] = part.split("=");
const key = rawKey?.trim();
if (!key) continue;
const rawValue = rawValueParts.join("=").trim();
try {
cookies[key] = decodeURIComponent(rawValue);
} catch {
cookies[key] = rawValue;
}
}
return cookies;
};
const getCsrfClientCookieValue = (req: Request): string | null => {
const cookies = parseCookies(req.headers.cookie);
const value = cookies[CSRF_CLIENT_COOKIE_NAME];
if (!value) return null;
if (!/^[A-Za-z0-9_-]{16,128}$/.test(value)) return null;
return value;
};
const getLegacyClientId = (req: Request): string => {
const ip = req.ip || req.connection.remoteAddress || "unknown";
const userAgent = req.headers["user-agent"] || "unknown";
return `${ip}:${userAgent}`.slice(0, 256);
};
const getCsrfValidationClientIds = (req: Request): string[] => {
const candidates: string[] = [];
const cookieValue = getCsrfClientCookieValue(req);
if (cookieValue) {
candidates.push(`cookie:${cookieValue}`);
}
const legacyClientId = getLegacyClientId(req);
if (!candidates.includes(legacyClientId)) {
candidates.push(legacyClientId);
}
return candidates;
};
const requireCsrf = (req: Request, res: Response): boolean => {
const headerName = getCsrfTokenHeader();
const tokenHeader = req.headers[headerName];
@@ -243,7 +283,9 @@ const requireCsrf = (req: Request, res: Response): boolean => {
return false;
}
if (!validateCsrfToken(getClientId(req), token)) {
const clientIds = getCsrfValidationClientIds(req);
const isValidToken = clientIds.some((clientId) => validateCsrfToken(clientId, token));
if (!isValidToken) {
res.status(403).json({
error: "CSRF token invalid",
message: "Invalid or expired CSRF token. Please refresh and try again.",
+16 -5
View File
@@ -11,6 +11,7 @@ import {
updateEmailSchema,
updateProfileSchema,
} from "./schemas";
import { getTokenLookupCandidates, hashTokenForStorage } from "./tokenSecurity";
type RegisterAccountRoutesDeps = {
router: express.Router;
@@ -81,7 +82,7 @@ export const registerAccountRoutes = (deps: RegisterAccountRoutesDeps) => {
});
await prisma.passwordResetToken.create({
data: { userId: user.id, token: resetToken, expiresAt },
data: { userId: user.id, token: hashTokenForStorage(resetToken), expiresAt },
});
if (config.enableAuditLogging) {
@@ -137,8 +138,10 @@ export const registerAccountRoutes = (deps: RegisterAccountRoutesDeps) => {
}
const { token, password } = parsed.data;
const resetToken = await prisma.passwordResetToken.findUnique({
where: { token },
const resetToken = await prisma.passwordResetToken.findFirst({
where: {
OR: getTokenLookupCandidates(token).map((candidate) => ({ token: candidate })),
},
include: { user: true },
});
@@ -348,7 +351,11 @@ export const registerAccountRoutes = (deps: RegisterAccountRoutesDeps) => {
const expiresAt = getRefreshTokenExpiresAt();
try {
await prisma.refreshToken.create({
data: { userId: updatedUser.id, token: refreshToken, expiresAt },
data: {
userId: updatedUser.id,
token: hashTokenForStorage(refreshToken),
expiresAt,
},
});
} catch {
if (process.env.NODE_ENV === "development") {
@@ -525,7 +532,11 @@ export const registerAccountRoutes = (deps: RegisterAccountRoutesDeps) => {
const expiresAt = getRefreshTokenExpiresAt();
try {
await prisma.refreshToken.create({
data: { userId: updatedUser.id, token: refreshToken, expiresAt },
data: {
userId: updatedUser.id,
token: hashTokenForStorage(refreshToken),
expiresAt,
},
});
} catch {
if (process.env.NODE_ENV === "development") {
+2 -1
View File
@@ -11,6 +11,7 @@ import {
loginRateLimitUpdateSchema,
registrationToggleSchema,
} from "./schemas";
import { hashTokenForStorage } from "./tokenSecurity";
type RegisterAdminRoutesDeps = {
router: express.Router;
@@ -610,7 +611,7 @@ export const registerAdminRoutes = (deps: RegisterAdminRoutesDeps) => {
const expiresAt = getRefreshTokenExpiresAt();
try {
await prisma.refreshToken.create({
data: { userId: target.id, token: refreshToken, expiresAt },
data: { userId: target.id, token: hashTokenForStorage(refreshToken), expiresAt },
});
} catch {
if (process.env.NODE_ENV === "development") {
+31 -26
View File
@@ -9,6 +9,7 @@ import {
loginSchema,
registerSchema,
} from "./schemas";
import { getTokenLookupCandidates, hashTokenForStorage } from "./tokenSecurity";
type RegisterCoreRoutesDeps = {
router: express.Router;
@@ -86,6 +87,7 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
bootstrapUserId,
defaultSystemConfigId,
} = deps;
const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`;
router.post("/register", loginAttemptRateLimiter, async (req: Request, res: Response) => {
try {
@@ -139,13 +141,14 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
},
});
const existingTrash = await prisma.collection.findUnique({
where: { id: "trash" },
const trashCollectionId = getUserTrashCollectionId(user.id);
const existingTrash = await prisma.collection.findFirst({
where: { id: trashCollectionId, userId: user.id },
});
if (!existingTrash) {
await prisma.collection.create({
data: {
id: "trash",
id: trashCollectionId,
name: "Trash",
userId: user.id,
},
@@ -157,7 +160,7 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
if (config.enableRefreshTokenRotation) {
const expiresAt = getRefreshTokenExpiresAt();
await prisma.refreshToken.create({
data: { userId: user.id, token: refreshToken, expiresAt },
data: { userId: user.id, token: hashTokenForStorage(refreshToken), expiresAt },
});
}
@@ -237,13 +240,14 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
},
});
const existingTrash = await prisma.collection.findUnique({
where: { id: "trash" },
const trashCollectionId = getUserTrashCollectionId(user.id);
const existingTrash = await prisma.collection.findFirst({
where: { id: trashCollectionId, userId: user.id },
});
if (!existingTrash) {
await prisma.collection.create({
data: {
id: "trash",
id: trashCollectionId,
name: "Trash",
userId: user.id,
},
@@ -259,7 +263,7 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
await prisma.refreshToken.create({
data: {
userId: user.id,
token: refreshToken,
token: hashTokenForStorage(refreshToken),
expiresAt,
},
});
@@ -372,7 +376,7 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
await prisma.refreshToken.create({
data: {
userId: user.id,
token: refreshToken,
token: hashTokenForStorage(refreshToken),
expiresAt,
},
});
@@ -464,8 +468,12 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
const expiresAt = getRefreshTokenExpiresAt();
await prisma.$transaction(async (tx) => {
const storedToken = await tx.refreshToken.findUnique({
where: { token: oldRefreshToken },
const storedToken = await tx.refreshToken.findFirst({
where: {
OR: getTokenLookupCandidates(oldRefreshToken).map((candidate) => ({
token: candidate,
})),
},
});
if (!storedToken || storedToken.userId !== user.id || storedToken.revoked) {
@@ -487,7 +495,7 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
await tx.refreshToken.create({
data: {
userId: user.id,
token: newRefreshToken,
token: hashTokenForStorage(newRefreshToken),
expiresAt,
},
});
@@ -638,9 +646,19 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
}
});
router.post("/auth-enabled", optionalAuth, async (req: Request, res: Response) => {
router.post("/auth-enabled", requireAuth, async (req: Request, res: Response) => {
try {
if (!requireCsrf(req, res)) return;
if (!req.user) {
return res
.status(401)
.json({ error: "Unauthorized", message: "User not authenticated" });
}
if (req.user.role !== "ADMIN") {
return res
.status(403)
.json({ error: "Forbidden", message: "Admin access required" });
}
const parsed = authEnabledToggleSchema.safeParse(req.body);
if (!parsed.success) {
@@ -653,19 +671,6 @@ export const registerCoreRoutes = (deps: RegisterCoreRoutesDeps) => {
const current = systemConfig.authEnabled;
const next = parsed.data.enabled;
if (current && !next) {
if (!req.user) {
return res
.status(401)
.json({ error: "Unauthorized", message: "User not authenticated" });
}
if (req.user.role !== "ADMIN") {
return res
.status(403)
.json({ error: "Forbidden", message: "Admin access required" });
}
}
if (!current && next) {
const bootstrap = await prisma.user.findUnique({
where: { id: bootstrapUserId },
+11
View File
@@ -0,0 +1,11 @@
import crypto from "crypto";
export const hashTokenForStorage = (token: string): string =>
crypto.createHash("sha256").update(token, "utf8").digest("hex");
export const getTokenLookupCandidates = (token: string): string[] => {
const candidates = new Set<string>();
candidates.add(token);
candidates.add(hashTokenForStorage(token));
return [...candidates];
};
+235 -52
View File
@@ -202,23 +202,32 @@ const invalidateDrawingsCache = () => {
drawingsCache.clear();
};
const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`;
const ensureTrashCollection = async (
db: Prisma.TransactionClient | PrismaClient,
userId: string
): Promise<void> => {
const trashCollection = await db.collection.findUnique({
where: { id: "trash" },
const trashCollectionId = getUserTrashCollectionId(userId);
const trashCollection = await db.collection.findFirst({
where: { id: trashCollectionId, userId },
});
if (!trashCollection) {
await db.collection.create({
data: {
id: "trash",
id: trashCollectionId,
name: "Trash",
userId,
},
});
}
// Legacy migration: move this user's drawings off global "trash".
await db.drawing.updateMany({
where: { userId, collectionId: "trash" },
data: { collectionId: trashCollectionId },
});
};
setInterval(() => {
@@ -375,13 +384,109 @@ app.use(generalRateLimiter);
// CSRF Protection Middleware
// Generates a unique client ID based on IP and User-Agent for token association
const getClientId = (req: express.Request): string => {
const CSRF_CLIENT_COOKIE_NAME = "excalidash-csrf-client";
const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days
const parseCookies = (cookieHeader: string | undefined): Record<string, string> => {
if (!cookieHeader) return {};
const cookies: Record<string, string> = {};
for (const part of cookieHeader.split(";")) {
const [rawKey, ...rawValueParts] = part.split("=");
const key = rawKey?.trim();
if (!key) continue;
const rawValue = rawValueParts.join("=").trim();
try {
cookies[key] = decodeURIComponent(rawValue);
} catch {
cookies[key] = rawValue;
}
}
return cookies;
};
const getCsrfClientCookieValue = (req: express.Request): string | null => {
const cookies = parseCookies(req.headers.cookie);
const value = cookies[CSRF_CLIENT_COOKIE_NAME];
if (!value) return null;
if (!/^[A-Za-z0-9_-]{16,128}$/.test(value)) return null;
return value;
};
const requestUsesHttps = (req: express.Request): boolean => {
if (req.secure) return true;
const forwardedProto = req.headers["x-forwarded-proto"];
const raw = Array.isArray(forwardedProto) ? forwardedProto[0] : forwardedProto;
const firstHop = String(raw || "")
.split(",")[0]
.trim()
.toLowerCase();
return firstHop === "https";
};
const setCsrfClientCookie = (req: express.Request, res: express.Response, value: string): void => {
const secure = requestUsesHttps(req) ? "; Secure" : "";
res.append(
"Set-Cookie",
`${CSRF_CLIENT_COOKIE_NAME}=${encodeURIComponent(
value
)}; Path=/; HttpOnly; SameSite=Lax; Max-Age=${CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS}${secure}`
);
};
const getLegacyClientId = (req: express.Request): string => {
const ip = req.ip || req.connection.remoteAddress || "unknown";
const userAgent = req.headers["user-agent"] || "unknown";
const clientId = `${ip}:${userAgent}`.slice(0, 256);
return `${ip}:${userAgent}`.slice(0, 256);
};
const getClientIdForTokenIssue = (
req: express.Request,
res: express.Response
): { clientId: string; strategy: "cookie" | "legacy-bootstrap" } => {
const existingCookieValue = getCsrfClientCookieValue(req);
if (existingCookieValue) {
return {
clientId: `cookie:${existingCookieValue}`,
strategy: "cookie",
};
}
// No cookie presented by client yet:
// - issue a token bound to legacy identity for compatibility with non-cookie clients
// - still set a cookie so subsequent browser requests can transition to cookie-bound tokens
const generatedCookieValue = uuidv4().replace(/-/g, "");
setCsrfClientCookie(req, res, generatedCookieValue);
return {
clientId: getLegacyClientId(req),
strategy: "legacy-bootstrap",
};
};
const getClientIdCandidatesForValidation = (req: express.Request): string[] => {
const candidates: string[] = [];
const cookieValue = getCsrfClientCookieValue(req);
if (cookieValue) {
candidates.push(`cookie:${cookieValue}`);
}
const legacyClientId = getLegacyClientId(req);
if (!candidates.includes(legacyClientId)) {
candidates.push(legacyClientId);
}
return candidates;
};
const getClientIdForTokenIssueDebug = (
req: express.Request,
res: express.Response
): string => {
const { clientId, strategy } = getClientIdForTokenIssue(req, res);
// Debug logging for CSRF troubleshooting (issue #38)
if (process.env.DEBUG_CSRF === "true") {
const validationCandidates = getClientIdCandidatesForValidation(req);
const ip = req.ip || req.connection.remoteAddress || "unknown";
console.log("[CSRF DEBUG] getClientId", {
method: req.method,
path: req.path,
@@ -389,9 +494,13 @@ const getClientId = (req: express.Request): string => {
remoteAddress: req.connection.remoteAddress,
"x-forwarded-for": req.headers["x-forwarded-for"],
"x-real-ip": req.headers["x-real-ip"],
userAgent: userAgent.slice(0, 100),
hasCsrfCookie: Boolean(getCsrfClientCookieValue(req)),
clientIdPreview: clientId.slice(0, 60) + "...",
trustProxySetting: req.app.get("trust proxy"),
strategy,
validationCandidatesPreview: validationCandidates.map((candidate) =>
`${candidate.slice(0, 60)}...`
),
});
}
@@ -436,7 +545,7 @@ app.get("/csrf-token", (req, res) => {
}
}
const clientId = getClientId(req);
const clientId = getClientIdForTokenIssueDebug(req, res);
const token = createCsrfToken(clientId);
res.json({
@@ -487,7 +596,7 @@ const csrfProtectionMiddleware = (
// Some legitimate clients/proxies might strip these, so we don't block strictly on their absence,
// but relying on the token is the primary defense.
const clientId = getClientId(req);
const clientIdCandidates = getClientIdCandidatesForValidation(req);
const headerName = getCsrfTokenHeader();
const tokenHeader = req.headers[headerName];
const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader;
@@ -499,7 +608,10 @@ const csrfProtectionMiddleware = (
});
}
if (!validateCsrfToken(clientId, token)) {
const isValidToken = clientIdCandidates.some((clientId) =>
validateCsrfToken(clientId, token)
);
if (!isValidToken) {
return res.status(403).json({
error: "CSRF token invalid",
message: "Invalid or expired CSRF token. Please refresh and try again.",
@@ -555,52 +667,71 @@ const drawingCreateSchema = drawingBaseSchema
}
);
const drawingUpdateSchema = drawingBaseSchema
const drawingUpdateSchemaBase = drawingBaseSchema
.extend({
elements: elementSchema.array().optional(),
appState: appStateSchema.optional(),
files: filesFieldSchema,
version: z.number().int().positive().optional(),
})
.refine(
(data) => {
const needsSanitization =
data.elements !== undefined ||
data.appState !== undefined ||
data.files !== undefined ||
data.preview !== undefined;
});
try {
const sanitizedData = { ...data };
if (needsSanitization) {
const fullData = {
elements: Array.isArray(data.elements) ? data.elements : [],
appState:
typeof data.appState === "object" && data.appState !== null
? data.appState
: {},
files: data.files || {},
preview: data.preview,
name: data.name,
collectionId: data.collectionId,
};
const sanitized = sanitizeDrawingData(fullData);
sanitizedData.elements = sanitized.elements;
sanitizedData.appState = sanitized.appState;
if (data.files !== undefined) sanitizedData.files = sanitized.files;
if (data.preview !== undefined)
sanitizedData.preview = sanitized.preview;
Object.assign(data, sanitizedData);
}
return true;
} catch (error) {
console.error("Sanitization failed:", error);
if (!needsSanitization) {
return true;
}
return false;
}
},
export const sanitizeDrawingUpdateData = (
data: {
elements?: unknown[];
appState?: Record<string, unknown>;
files?: Record<string, unknown>;
preview?: string | null;
name?: string;
collectionId?: string | null;
}
): boolean => {
const hasSceneFields =
data.elements !== undefined ||
data.appState !== undefined ||
data.files !== undefined;
const hasPreviewField = data.preview !== undefined;
const needsSanitization = hasSceneFields || hasPreviewField;
try {
const sanitizedData = { ...data };
if (hasSceneFields) {
const fullData = {
elements: Array.isArray(data.elements) ? data.elements : [],
appState:
typeof data.appState === "object" && data.appState !== null
? data.appState
: {},
files: data.files || {},
preview: data.preview,
name: data.name,
collectionId: data.collectionId,
};
const sanitized = sanitizeDrawingData(fullData);
sanitizedData.elements = sanitized.elements;
sanitizedData.appState = sanitized.appState;
if (data.files !== undefined) sanitizedData.files = sanitized.files;
if (data.preview !== undefined) sanitizedData.preview = sanitized.preview;
Object.assign(data, sanitizedData);
} else if (hasPreviewField && typeof data.preview === "string") {
// Preview-only updates must not inject default scene fields.
data.preview = sanitizeSvg(data.preview);
Object.assign(data, { ...data, preview: data.preview });
} else if (hasPreviewField && data.preview === null) {
// Explicitly allow clearing preview without touching scene data.
Object.assign(data, sanitizedData);
}
return true;
} catch (error) {
console.error("Sanitization failed:", error);
if (!needsSanitization) {
return true;
}
return false;
}
};
const drawingUpdateSchema = drawingUpdateSchemaBase.refine(
(data) => sanitizeDrawingUpdateData(data as any),
{
message: "Invalid or malicious drawing data detected",
}
@@ -726,6 +857,33 @@ const roomUsers = new Map<string, User[]>();
// Track which authenticated user owns each socket for authorization checks
const socketUserMap = new Map<string, string>();
const toPresenceName = (value: unknown): string => {
if (typeof value !== "string") return "User";
const trimmed = value.trim().slice(0, 120);
return trimmed.length > 0 ? trimmed : "User";
};
const toPresenceInitials = (name: string): string => {
const words = name
.split(/\s+/)
.map((part) => part.trim())
.filter((part) => part.length > 0);
if (words.length === 0) return "U";
const first = words[0]?.[0] ?? "";
const second = words.length > 1 ? words[1]?.[0] ?? "" : "";
const initials = `${first}${second}`.toUpperCase().slice(0, 2);
return initials.length > 0 ? initials : "U";
};
const toPresenceColor = (value: unknown): string => {
if (typeof value !== "string") return "#4f46e5";
const trimmed = value.trim();
if (/^#[0-9a-fA-F]{3,8}$/.test(trimmed)) {
return trimmed;
}
return "#4f46e5";
};
/**
* Verify JWT from Socket.io auth and check if auth is required.
* When auth is disabled (single-user mode), all connections are allowed.
@@ -815,10 +973,35 @@ io.on("connection", (socket) => {
socket.join(roomId);
authorizedDrawingIds.add(drawingId);
const newUser: User = { ...user, socketId: socket.id, isActive: true };
let trustedUserId =
typeof user?.id === "string" && user.id.trim().length > 0
? user.id.trim().slice(0, 200)
: socket.id;
let trustedName = toPresenceName(user?.name);
// In auth-enabled mode, identity should come from the authenticated account.
if (authenticatedUserId && authenticatedUserId !== "bootstrap-admin") {
const account = await prisma.user.findUnique({
where: { id: authenticatedUserId },
select: { id: true, name: true },
});
if (account) {
trustedUserId = account.id;
trustedName = toPresenceName(account.name);
}
}
const newUser: User = {
id: trustedUserId,
name: trustedName,
initials: toPresenceInitials(trustedName),
color: toPresenceColor(user?.color),
socketId: socket.id,
isActive: true,
};
const currentUsers = roomUsers.get(roomId) || [];
const filteredUsers = currentUsers.filter((u) => u.id !== user.id);
const filteredUsers = currentUsers.filter((u) => u.id !== newUser.id);
filteredUsers.push(newUser);
roomUsers.set(roomId, filteredUsers);
+71 -10
View File
@@ -79,11 +79,30 @@ export const registerDashboardRoutes = (
logAuditEvent,
} = deps;
const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`;
const isTrashCollectionId = (
collectionId: string | null | undefined,
userId: string
): boolean =>
Boolean(collectionId) &&
(collectionId === "trash" || collectionId === getUserTrashCollectionId(userId));
const toInternalTrashCollectionId = (
collectionId: string | null | undefined,
userId: string
): string | null | undefined =>
collectionId === "trash" ? getUserTrashCollectionId(userId) : collectionId;
const toPublicTrashCollectionId = (
collectionId: string | null | undefined,
userId: string
): string | null | undefined =>
isTrashCollectionId(collectionId, userId) ? "trash" : collectionId;
app.get("/drawings", requireAuth, asyncHandler(async (req, res) => {
if (!req.user) {
return res.status(401).json({ error: "Unauthorized" });
}
const trashCollectionId = getUserTrashCollectionId(req.user.id);
const { search, collectionId, includeData, limit, offset, sortField, sortDirection } = req.query;
const where: Prisma.DrawingWhereInput = { userId: req.user.id };
const searchTerm =
@@ -100,7 +119,7 @@ export const registerDashboardRoutes = (
} else if (collectionId) {
const normalizedCollectionId = String(collectionId);
if (normalizedCollectionId === "trash") {
where.collectionId = "trash";
where.collectionId = { in: [trashCollectionId, "trash"] };
collectionFilterKey = "trash";
} else {
const collection = await prisma.collection.findFirst({
@@ -113,7 +132,10 @@ export const registerDashboardRoutes = (
collectionFilterKey = `id:${normalizedCollectionId}`;
}
} else {
where.OR = [{ collectionId: { not: "trash" } }, { collectionId: null }];
where.OR = [
{ collectionId: { notIn: [trashCollectionId, "trash"] } },
{ collectionId: null },
];
}
const shouldIncludeData =
@@ -188,10 +210,16 @@ export const registerDashboardRoutes = (
if (shouldIncludeData) {
responsePayload = (drawings as any[]).map((d: any) => ({
...d,
collectionId: toPublicTrashCollectionId(d.collectionId, req.user!.id),
elements: parseJsonField(d.elements, []),
appState: parseJsonField(d.appState, {}),
files: parseJsonField(d.files, {}),
}));
} else {
responsePayload = (drawings as any[]).map((d: any) => ({
...d,
collectionId: toPublicTrashCollectionId(d.collectionId, req.user!.id),
}));
}
const finalResponse = {
@@ -223,6 +251,7 @@ export const registerDashboardRoutes = (
return res.json({
...drawing,
collectionId: toPublicTrashCollectionId(drawing.collectionId, req.user.id),
elements: parseJsonField(drawing.elements, []),
appState: parseJsonField(drawing.appState, {}),
files: parseJsonField(drawing.files, {}),
@@ -254,14 +283,16 @@ export const registerDashboardRoutes = (
files?: Record<string, unknown>;
};
const drawingName = payload.name ?? "Untitled Drawing";
const targetCollectionId = payload.collectionId === undefined ? null : payload.collectionId;
const targetCollectionIdRaw = payload.collectionId === undefined ? null : payload.collectionId;
const targetCollectionId =
toInternalTrashCollectionId(targetCollectionIdRaw, req.user.id) ?? null;
if (targetCollectionId && targetCollectionId !== "trash") {
if (targetCollectionId && !isTrashCollectionId(targetCollectionId, req.user.id)) {
const collection = await prisma.collection.findFirst({
where: { id: targetCollectionId, userId: req.user.id },
});
if (!collection) return res.status(404).json({ error: "Collection not found" });
} else if (targetCollectionId === "trash") {
} else if (targetCollectionIdRaw === "trash") {
await ensureTrashCollection(prisma, req.user.id);
}
@@ -280,6 +311,7 @@ export const registerDashboardRoutes = (
return res.json({
...newDrawing,
collectionId: toPublicTrashCollectionId(newDrawing.collectionId, req.user.id),
elements: parseJsonField(newDrawing.elements, []),
appState: parseJsonField(newDrawing.appState, {}),
files: parseJsonField(newDrawing.files, {}),
@@ -312,11 +344,14 @@ export const registerDashboardRoutes = (
files?: Record<string, unknown>;
version?: number;
};
const trashCollectionId = getUserTrashCollectionId(req.user.id);
const isSceneUpdate =
payload.elements !== undefined ||
payload.appState !== undefined ||
payload.files !== undefined;
const data: Prisma.DrawingUpdateInput = { version: { increment: 1 } };
const data: Prisma.DrawingUpdateInput = isSceneUpdate
? { version: { increment: 1 } }
: {};
if (payload.name !== undefined) data.name = payload.name;
if (payload.elements !== undefined) data.elements = JSON.stringify(payload.elements);
@@ -327,7 +362,7 @@ export const registerDashboardRoutes = (
if (payload.collectionId !== undefined) {
if (payload.collectionId === "trash") {
await ensureTrashCollection(prisma, req.user.id);
(data as Prisma.DrawingUncheckedUpdateInput).collectionId = "trash";
(data as Prisma.DrawingUncheckedUpdateInput).collectionId = trashCollectionId;
} else if (payload.collectionId) {
const collection = await prisma.collection.findFirst({
where: { id: payload.collectionId, userId: req.user.id },
@@ -374,6 +409,7 @@ export const registerDashboardRoutes = (
return res.json({
...updatedDrawing,
collectionId: toPublicTrashCollectionId(updatedDrawing.collectionId, req.user.id),
elements: parseJsonField(updatedDrawing.elements, []),
appState: parseJsonField(updatedDrawing.appState, {}),
files: parseJsonField(updatedDrawing.files, {}),
@@ -415,8 +451,10 @@ export const registerDashboardRoutes = (
const { id } = req.params;
const original = await prisma.drawing.findFirst({ where: { id, userId: req.user.id } });
if (!original) return res.status(404).json({ error: "Original drawing not found" });
if (original.collectionId === "trash") {
let duplicatedCollectionId = original.collectionId;
if (isTrashCollectionId(original.collectionId, req.user.id)) {
await ensureTrashCollection(prisma, req.user.id);
duplicatedCollectionId = getUserTrashCollectionId(req.user.id);
}
const newDrawing = await prisma.drawing.create({
@@ -426,7 +464,7 @@ export const registerDashboardRoutes = (
appState: original.appState,
files: original.files,
userId: req.user.id,
collectionId: original.collectionId,
collectionId: duplicatedCollectionId,
version: 1,
},
});
@@ -434,6 +472,7 @@ export const registerDashboardRoutes = (
return res.json({
...newDrawing,
collectionId: toPublicTrashCollectionId(newDrawing.collectionId, req.user.id),
elements: parseJsonField(newDrawing.elements, []),
appState: parseJsonField(newDrawing.appState, {}),
files: parseJsonField(newDrawing.files, {}),
@@ -442,11 +481,21 @@ export const registerDashboardRoutes = (
app.get("/collections", requireAuth, asyncHandler(async (req, res) => {
if (!req.user) return res.status(401).json({ error: "Unauthorized" });
const trashCollectionId = getUserTrashCollectionId(req.user.id);
await ensureTrashCollection(prisma, req.user.id);
const collections = await prisma.collection.findMany({
const rawCollections = await prisma.collection.findMany({
where: { userId: req.user.id },
orderBy: { createdAt: "desc" },
});
const hasInternalTrash = rawCollections.some((collection) => collection.id === trashCollectionId);
const collections = rawCollections
.filter((collection) => !(hasInternalTrash && collection.id === "trash"))
.map((collection) =>
collection.id === trashCollectionId
? { ...collection, id: "trash", name: "Trash" }
: collection
);
return res.json(collections);
}));
@@ -472,6 +521,12 @@ export const registerDashboardRoutes = (
if (!req.user) return res.status(401).json({ error: "Unauthorized" });
const { id } = req.params;
if (isTrashCollectionId(id, req.user.id)) {
return res.status(400).json({
error: "Validation error",
message: "Trash collection cannot be renamed",
});
}
const existingCollection = await prisma.collection.findFirst({
where: { id, userId: req.user.id },
});
@@ -506,6 +561,12 @@ export const registerDashboardRoutes = (
if (!req.user) return res.status(401).json({ error: "Unauthorized" });
const { id } = req.params;
if (isTrashCollectionId(id, req.user.id)) {
return res.status(400).json({
error: "Validation error",
message: "Trash collection cannot be deleted",
});
}
const collection = await prisma.collection.findFirst({
where: { id, userId: req.user.id },
});
+51 -16
View File
@@ -146,6 +146,21 @@ const normalizeNonEmptyId = (value: unknown): string | null => {
return trimmed.length > 0 ? trimmed : null;
};
const getUserTrashCollectionId = (userId: string): string => `trash:${userId}`;
const isTrashCollectionId = (
collectionId: string | null | undefined,
userId: string
): boolean =>
Boolean(collectionId) &&
(collectionId === "trash" || collectionId === getUserTrashCollectionId(userId));
const toPublicTrashCollectionId = (
collectionId: string | null | undefined,
userId: string
): string | null =>
isTrashCollectionId(collectionId, userId) ? "trash" : collectionId ?? null;
const findSqliteTable = (tables: string[], candidates: string[]): string | null => {
const byLower = new Map(tables.map((t) => [t.toLowerCase(), t]));
for (const candidate of candidates) {
@@ -264,6 +279,7 @@ export const registerImportExportRoutes = (deps: RegisterImportExportDeps) => {
app.get("/export/excalidash", requireAuth, asyncHandler(async (req, res) => {
if (!req.user) return res.status(401).json({ error: "Unauthorized" });
const trashCollectionId = getUserTrashCollectionId(req.user.id);
const extParam = typeof req.query.ext === "string" ? req.query.ext.toLowerCase() : "";
const zipSuffix = extParam === "zip";
@@ -281,10 +297,23 @@ export const registerImportExportRoutes = (deps: RegisterImportExportDeps) => {
where: { userId: req.user.id },
});
const hasTrashDrawings = drawings.some((d) => d.collectionId === "trash");
const collectionsToExport = [...userCollections];
if (hasTrashDrawings && !collectionsToExport.some((c) => c.id === "trash")) {
const trash = await prisma.collection.findUnique({ where: { id: "trash" } });
const hasInternalTrashCollection = userCollections.some((collection) => collection.id === trashCollectionId);
const normalizedUserCollections = userCollections.filter(
(collection) => !(hasInternalTrashCollection && collection.id === "trash")
);
const hasTrashDrawings = drawings.some((drawing) =>
isTrashCollectionId(drawing.collectionId, req.user!.id)
);
const collectionsToExport = [...normalizedUserCollections];
if (
hasTrashDrawings &&
!collectionsToExport.some((collection) =>
isTrashCollectionId(collection.id, req.user!.id)
)
) {
const trash = await prisma.collection.findFirst({
where: { userId: req.user.id, id: { in: [trashCollectionId, "trash"] } },
});
if (trash) collectionsToExport.push(trash);
}
@@ -309,13 +338,23 @@ export const registerImportExportRoutes = (deps: RegisterImportExportDeps) => {
id: drawing.id,
name: drawing.name,
filePath: `${folder}/${fileName}`,
collectionId: drawing.collectionId ?? null,
collectionId: toPublicTrashCollectionId(drawing.collectionId, req.user!.id),
version: drawing.version,
createdAt: drawing.createdAt.toISOString(),
updatedAt: drawing.updatedAt.toISOString(),
};
});
const manifestCollections = collectionsToExport
.map((collection) => ({
id: toPublicTrashCollectionId(collection.id, req.user!.id) || collection.id,
name: isTrashCollectionId(collection.id, req.user!.id) ? "Trash" : collection.name,
folder: folderByCollectionId.get(collection.id) || sanitizePathSegment(collection.name, "Collection"),
createdAt: collection.createdAt.toISOString(),
updatedAt: collection.updatedAt.toISOString(),
}))
.filter((collection, index, all) => all.findIndex((c) => c.id === collection.id) === index);
const manifest = {
format: "excalidash" as const,
formatVersion: 1 as const,
@@ -323,13 +362,7 @@ export const registerImportExportRoutes = (deps: RegisterImportExportDeps) => {
excalidashBackendVersion: getBackendVersion(),
userId: req.user.id,
unorganizedFolder,
collections: collectionsToExport.map((c) => ({
id: c.id,
name: c.name,
folder: folderByCollectionId.get(c.id) || sanitizePathSegment(c.name, "Collection"),
createdAt: c.createdAt.toISOString(),
updatedAt: c.updatedAt.toISOString(),
})),
collections: manifestCollections,
drawings: drawingsManifest,
};
@@ -657,6 +690,7 @@ Drawings: ${drawings.length}
}
const result = await prisma.$transaction(async (tx) => {
const trashCollectionId = getUserTrashCollectionId(req.user!.id);
const collectionIdMap = new Map<string, string>();
let collectionsCreated = 0;
let collectionsUpdated = 0;
@@ -672,7 +706,7 @@ Drawings: ${drawings.length}
for (const c of manifest.collections) {
if (c.id === "trash") {
collectionIdMap.set("trash", "trash");
collectionIdMap.set("trash", trashCollectionId);
continue;
}
@@ -707,7 +741,7 @@ Drawings: ${drawings.length}
const resolveCollectionId = (collectionId: string | null): string | null => {
if (!collectionId) return null;
if (collectionId === "trash") return "trash";
if (collectionId === "trash") return trashCollectionId;
return collectionIdMap.get(collectionId) || null;
};
@@ -1006,6 +1040,7 @@ Drawings: ${drawings.length}
}
const result = await prisma.$transaction(async (tx) => {
const trashCollectionId = getUserTrashCollectionId(req.user!.id);
const hasTrash = importedDrawings.some((d) => String(d.collectionId || "") === "trash");
if (hasTrash) await ensureTrashCollection(tx, req.user!.id);
@@ -1022,7 +1057,7 @@ Drawings: ${drawings.length}
const name = typeof c.name === "string" ? c.name : "Collection";
if (importedId === "trash" || name === "Trash") {
collectionIdMap.set(importedId || "trash", "trash");
collectionIdMap.set(importedId || "trash", trashCollectionId);
continue;
}
@@ -1071,7 +1106,7 @@ Drawings: ${drawings.length}
const id = typeof rawCollectionId === "string" ? rawCollectionId : null;
const name = typeof rawCollectionName === "string" ? rawCollectionName : null;
if (id === "trash" || name === "Trash") return "trash";
if (id === "trash" || name === "Trash") return trashCollectionId;
if (id && collectionIdMap.has(id)) return collectionIdMap.get(id)!;
if (name && collectionIdMap.has(`__name:${name}`)) return collectionIdMap.get(`__name:${name}`)!;
return null;