feat(security): implement CSRF protection

This commit is contained in:
AdrianAcala
2025-12-21 02:47:14 -08:00
parent e75b727a5a
commit 8a78b2bb2e
25 changed files with 1157 additions and 580 deletions
+2
View File
@@ -31,7 +31,9 @@ backend/dist/
# E2E Testing
e2e/node_modules/
e2e/test-results/
e2e/test-results-user/
e2e/playwright-report/
e2e/playwright-report-user/
e2e/.playwright/
# Temporary files
+1 -1
View File
@@ -148,7 +148,7 @@ ExcaliDash/
**Backend (.env):**
```bash
DATABASE_URL="file:./prisma/dev.db"
DATABASE_URL="file:./dev.db"
PORT=8000
NODE_ENV=development
```
+1 -1
View File
@@ -75,7 +75,7 @@ See [release notes](https://github.com/ZimengXiong/ExcaliDash/releases) for a sp
# Installation
> [!CAUTION]
> NOT for production use. While attempts have been made at hardening (XSS/dompurify, CORS, rate-limiting, sanitization), they are inadequate for public deployment. Do not expose any ports. Currently lacking CSRF.
> NOT for production use. While attempts have been made at hardening (XSS/dompurify, CORS, rate-limiting, sanitization), they are inadequate for public deployment. Do not expose any ports.
> [!CAUTION]
> ExcaliDash is in BETA. Please backup your data regularly (e.g. with cron).
+197 -371
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -16,7 +16,7 @@
"dependencies": {
"@prisma/client": "^5.22.0",
"@types/archiver": "^7.0.0",
"@types/jsdom": "^27.0.0",
"@types/jsdom": "^21.1.7",
"@types/multer": "^2.0.0",
"@types/socket.io": "^3.0.1",
"archiver": "^7.0.1",
@@ -25,7 +25,7 @@
"dompurify": "^3.3.0",
"dotenv": "^17.2.3",
"express": "^5.1.0",
"jsdom": "^27.2.0",
"jsdom": "^22.1.0",
"multer": "^2.0.2",
"prisma": "^5.22.0",
"socket.io": "^4.8.1",
+168
View File
@@ -0,0 +1,168 @@
/**
* CSRF Tests - Horizontal Scaling (K8s) Validation
*
* PR #20 review concern:
* "Worried that in memory token store might not work on horizontal scaling"
*
* Fix:
* - CSRF tokens are now stateless and HMAC-signed using a shared `CSRF_SECRET`.
* - Any pod can validate any token as long as all pods share the same secret.
*
* These tests prove:
* - Tokens validate correctly for the issuing client id
* - Tokens do NOT validate for a different client id
* - Tokens expire after 24 hours
* - Tokens validate across separate module instances (simulated pods)
*/
import { describe, it, expect, beforeAll, afterEach, vi } from "vitest";
const SHARED_SECRET = "test-shared-csrf-secret";
beforeAll(() => {
// Must be shared across instances/pods for horizontal scaling.
process.env.CSRF_SECRET = SHARED_SECRET;
});
afterEach(() => {
vi.useRealTimers();
});
describe("CSRF - stateless HMAC tokens", () => {
it("creates a token in payload.signature format and validates for same client id", async () => {
const { createCsrfToken, validateCsrfToken } = await import("../security");
const clientId = "test-client-1";
const token = createCsrfToken(clientId);
expect(typeof token).toBe("string");
// base64url(payload).base64url(signature)
expect(token).toMatch(/^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$/);
expect(validateCsrfToken(clientId, token)).toBe(true);
});
it("rejects validation for a different client id (token binding)", async () => {
const { createCsrfToken, validateCsrfToken } = await import("../security");
const token = createCsrfToken("client-a");
expect(validateCsrfToken("client-b", token)).toBe(false);
});
it("rejects malformed tokens", async () => {
const { validateCsrfToken } = await import("../security");
expect(validateCsrfToken("client", "not-a-token")).toBe(false);
expect(validateCsrfToken("client", "a.b.c")).toBe(false);
expect(validateCsrfToken("client", "")).toBe(false);
});
it("revokeCsrfToken is a no-op for stateless tokens (does not break callers)", async () => {
const { createCsrfToken, validateCsrfToken, revokeCsrfToken } = await import(
"../security"
);
const clientId = "client-revoke";
const token = createCsrfToken(clientId);
expect(validateCsrfToken(clientId, token)).toBe(true);
revokeCsrfToken(clientId);
// Stateless token remains valid until expiry
expect(validateCsrfToken(clientId, token)).toBe(true);
});
it("expires tokens after 24 hours", async () => {
vi.useFakeTimers();
vi.setSystemTime(new Date("2025-01-01T00:00:00.000Z"));
const { createCsrfToken, validateCsrfToken } = await import("../security");
const clientId = "client-expiry";
const token = createCsrfToken(clientId);
expect(validateCsrfToken(clientId, token)).toBe(true);
// 24h + 1ms later
vi.setSystemTime(new Date("2025-01-02T00:00:00.001Z"));
expect(validateCsrfToken(clientId, token)).toBe(false);
});
});
describe("CSRF - horizontal scaling (simulated pods)", () => {
it("validates across module instances (pod A issues, pod B validates)", async () => {
const clientId = "user-123";
vi.resetModules();
const podA = await import("../security");
const token = podA.createCsrfToken(clientId);
// Simulate a different pod (new Node.js process / fresh module state)
vi.resetModules();
const podB = await import("../security");
expect(podB.validateCsrfToken(clientId, token)).toBe(true);
});
it("has 0% failure rate under round-robin validation across 3 pods", async () => {
const clientId = "user-round-robin";
const pods: Array<{
createCsrfToken: (clientId: string) => string;
validateCsrfToken: (clientId: string, token: string) => boolean;
}> = [];
for (let i = 0; i < 3; i++) {
vi.resetModules();
pods.push(await import("../security"));
}
// Token issued on one pod
const token = pods[0].createCsrfToken(clientId);
// Validate on alternating pods (simulates a non-sticky load balancer)
const attempts = 60;
let failures = 0;
for (let i = 0; i < attempts; i++) {
const pod = pods[i % pods.length];
if (!pod.validateCsrfToken(clientId, token)) failures++;
}
expect(failures).toBe(0);
});
});
describe("CSRF - referer origin parsing", () => {
it("extracts exact origin from a referer URL", async () => {
const { getOriginFromReferer } = await import("../security");
expect(getOriginFromReferer("https://example.com/path?x=1")).toBe(
"https://example.com"
);
expect(getOriginFromReferer("http://localhost:5173/some/page")).toBe(
"http://localhost:5173"
);
});
it("does not allow prefix tricks (origin must be parsed)", async () => {
const { getOriginFromReferer } = await import("../security");
expect(
getOriginFromReferer("https://example.com.evil.com/anything")
).toBe("https://example.com.evil.com");
// `startsWith("https://example.com")` would incorrectly allow this.
expect(getOriginFromReferer("https://example.com@evil.com/anything")).toBe(
"https://evil.com"
);
});
it("returns null for invalid or non-http(s) referers", async () => {
const { getOriginFromReferer } = await import("../security");
expect(getOriginFromReferer("")).toBeNull();
expect(getOriginFromReferer("not a url")).toBeNull();
expect(getOriginFromReferer("file:///etc/passwd")).toBeNull();
expect(getOriginFromReferer(null)).toBeNull();
});
});
+153 -6
View File
@@ -18,6 +18,10 @@ import {
sanitizeSvg,
elementSchema,
appStateSchema,
createCsrfToken,
validateCsrfToken,
getCsrfTokenHeader,
getOriginFromReferer,
} from "./security";
dotenv.config();
@@ -34,9 +38,22 @@ const resolveDatabaseUrl = (rawUrl?: string) => {
}
const filePath = rawUrl.replace(/^file:/, "");
// Prisma treats relative SQLite paths as relative to the schema directory
// (i.e. `backend/prisma/schema.prisma`). Historically this project used
// `file:./prisma/dev.db`, which Prisma interprets as `prisma/prisma/dev.db`.
// To keep runtime and migrations aligned:
// - Prefer resolving relative paths against `backend/prisma`
// - But if the path already includes a leading `prisma/`, resolve from repo root
const prismaDir = path.resolve(backendRoot, "prisma");
const normalizedRelative = filePath.replace(/^\.\/?/, "");
const hasLeadingPrismaDir =
normalizedRelative === "prisma" ||
normalizedRelative.startsWith("prisma/");
const absolutePath = path.isAbsolute(filePath)
? filePath
: path.resolve(backendRoot, filePath);
: path.resolve(hasLeadingPrismaDir ? backendRoot : prismaDir, normalizedRelative);
return `file:${absolutePath}`;
};
@@ -63,11 +80,15 @@ const normalizeOrigins = (rawOrigins?: string | null): string[] => {
const ensureProtocol = (origin: string) =>
/^https?:\/\//i.test(origin) ? origin : `http://${origin}`;
const removeTrailingSlash = (origin: string) =>
origin.endsWith("/") ? origin.slice(0, -1) : origin;
const parsed = rawOrigins
.split(",")
.map((origin) => origin.trim())
.filter((origin) => origin.length > 0)
.map(ensureProtocol);
.map(ensureProtocol)
.map(removeTrailingSlash);
return parsed.length > 0 ? parsed : [fallback];
};
@@ -211,6 +232,8 @@ app.use(
cors({
origin: allowedOrigins,
credentials: true,
allowedHeaders: ["Content-Type", "Authorization", "x-csrf-token"],
exposedHeaders: ["x-csrf-token"],
})
);
app.use(express.json({ limit: "50mb" }));
@@ -296,6 +319,132 @@ app.use((req, res, next) => {
next();
});
// CSRF Protection Middleware
// Generates a unique client ID based on IP and User-Agent for token association
const getClientId = (req: express.Request): string => {
const ip = req.ip || req.connection.remoteAddress || "unknown";
const userAgent = req.headers["user-agent"] || "unknown";
// Create a simple hash for client identification
// In production, you might use a session ID instead
return `${ip}:${userAgent}`.slice(0, 256);
};
// Rate limiter specifically for CSRF token generation to prevent store exhaustion
const csrfRateLimit = new Map<string, { count: number; resetTime: number }>();
const CSRF_RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute
const CSRF_MAX_REQUESTS = (() => {
const parsed = Number(process.env.CSRF_MAX_REQUESTS);
if (!Number.isFinite(parsed) || parsed <= 0) {
return 60; // 1 per second average
}
return parsed;
})();
// CSRF token endpoint - clients should call this to get a token
app.get("/csrf-token", (req, res) => {
const ip = req.ip || req.connection.remoteAddress || "unknown";
const now = Date.now();
const clientLimit = csrfRateLimit.get(ip);
if (clientLimit && now < clientLimit.resetTime) {
if (clientLimit.count >= CSRF_MAX_REQUESTS) {
return res.status(429).json({
error: "Rate limit exceeded",
message: "Too many CSRF token requests",
});
}
clientLimit.count++;
} else {
csrfRateLimit.set(ip, { count: 1, resetTime: now + CSRF_RATE_LIMIT_WINDOW });
}
// Cleanup old rate limit entries occasionally
if (Math.random() < 0.01) {
for (const [key, data] of csrfRateLimit.entries()) {
if (now > data.resetTime) csrfRateLimit.delete(key);
}
}
const clientId = getClientId(req);
const token = createCsrfToken(clientId);
res.json({
token,
header: getCsrfTokenHeader()
});
});
// CSRF validation middleware for state-changing requests
const csrfProtectionMiddleware = (
req: express.Request,
res: express.Response,
next: express.NextFunction
) => {
// Skip CSRF validation for safe methods (GET, HEAD, OPTIONS)
const safeMethods = ["GET", "HEAD", "OPTIONS"];
if (safeMethods.includes(req.method)) {
return next();
}
// Skip CSRF for the CSRF token endpoint itself
if (req.path === "/csrf-token") {
return next();
}
// Origin/Referer check for defense in depth
const origin = req.headers["origin"];
const referer = req.headers["referer"];
// If Origin is present, it must match allowed origins
const originValue = Array.isArray(origin) ? origin[0] : origin;
const refererValue = Array.isArray(referer) ? referer[0] : referer;
if (originValue) {
if (!allowedOrigins.includes(originValue)) {
return res.status(403).json({
error: "CSRF origin mismatch",
message: "Origin not allowed",
});
}
} else if (refererValue) {
// If no Origin but Referer exists, validate its *origin* (avoid prefix bypass)
const refererOrigin = getOriginFromReferer(refererValue);
if (!refererOrigin || !allowedOrigins.includes(refererOrigin)) {
return res.status(403).json({
error: "CSRF referer mismatch",
message: "Referer not allowed",
});
}
}
// Note: If neither Origin nor Referer is present, we proceed to token check.
// Some legitimate clients/proxies might strip these, so we don't block strictly on their absence,
// but relying on the token is the primary defense.
const clientId = getClientId(req);
const headerName = getCsrfTokenHeader();
const tokenHeader = req.headers[headerName];
const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader;
if (!token) {
return res.status(403).json({
error: "CSRF token missing",
message: `Missing ${headerName} header`,
});
}
if (!validateCsrfToken(clientId, token)) {
return res.status(403).json({
error: "CSRF token invalid",
message: "Invalid or expired CSRF token. Please refresh and try again.",
});
}
next();
};
// Apply CSRF protection to all routes
app.use(csrfProtectionMiddleware);
const filesFieldSchema = z
.union([z.record(z.string(), z.any()), z.null()])
.optional()
@@ -922,8 +1071,7 @@ app.get("/export", async (req, res) => {
res.setHeader("Content-Type", "application/octet-stream");
res.setHeader(
"Content-Disposition",
`attachment; filename="excalidash-db-${
new Date().toISOString().split("T")[0]
`attachment; filename="excalidash-db-${new Date().toISOString().split("T")[0]
}.${extension}"`
);
@@ -946,8 +1094,7 @@ app.get("/export/json", async (req, res) => {
res.setHeader("Content-Type", "application/zip");
res.setHeader(
"Content-Disposition",
`attachment; filename="excalidraw-drawings-${
new Date().toISOString().split("T")[0]
`attachment; filename="excalidraw-drawings-${new Date().toISOString().split("T")[0]
}.zip"`
);
+188
View File
@@ -1,6 +1,10 @@
/**
* Security utilities for XSS prevention, data sanitization, and CSRF protection
*/
import { z } from "zod";
import DOMPurify from "dompurify";
import { JSDOM } from "jsdom";
import crypto from "crypto";
// Create a DOM environment for DOMPurify (Node.js compatibility)
const window = new JSDOM("").window;
@@ -523,3 +527,187 @@ export const validateImportedDrawing = (data: any): boolean => {
return false;
}
};
// ============================================================================
// CSRF Protection
// ============================================================================
const CSRF_TOKEN_LENGTH = 32;
const CSRF_TOKEN_HEADER = "x-csrf-token";
const CSRF_TOKEN_EXPIRY_MS = 24 * 60 * 60 * 1000; // 24 hours
const CSRF_TOKEN_FUTURE_SKEW_MS = 5 * 60 * 1000; // 5 minutes clock skew tolerance
const CSRF_NONCE_BYTES = 16;
const CSRF_TOKEN_MAX_LENGTH = 2048; // sanity limit against abuse
/**
* IMPORTANT (Horizontal Scaling / K8s)
* -----------------------------------
* CSRF tokens must validate across multiple stateless instances.
*
* The prior in-memory Map-based token store breaks under horizontal scaling
* because each pod has its own memory. This implementation is stateless:
*
* - Token payload: { ts, nonce }
* - Signature: HMAC_SHA256(secret, `${clientId}|${ts}|${nonce}`)
*
* As long as all pods share the same `CSRF_SECRET`, any pod can validate
* any token without shared state (works on Kubernetes).
*/
let cachedCsrfSecret: Buffer | null = null;
const getCsrfSecret = (): Buffer => {
if (cachedCsrfSecret) return cachedCsrfSecret;
const secretFromEnv = process.env.CSRF_SECRET;
if (secretFromEnv && secretFromEnv.trim().length > 0) {
cachedCsrfSecret = Buffer.from(secretFromEnv, "utf8");
return cachedCsrfSecret;
}
// If not configured, generate an ephemeral secret for this process.
// This keeps single-instance deployments working out of the box, but:
// - Horizontal scaling will BREAK unless CSRF_SECRET is set and shared.
cachedCsrfSecret = crypto.randomBytes(32);
const envLabel = process.env.NODE_ENV ? ` (${process.env.NODE_ENV})` : "";
console.warn(
`[security] CSRF_SECRET is not set${envLabel}. Using an ephemeral per-process secret. ` +
"For horizontal scaling (k8s), set CSRF_SECRET to the same value on all instances."
);
return cachedCsrfSecret;
};
const base64UrlEncode = (input: Buffer | string): string => {
const buf = typeof input === "string" ? Buffer.from(input, "utf8") : input;
return buf
.toString("base64")
.replace(/\+/g, "-")
.replace(/\//g, "_")
.replace(/=+$/g, "");
};
const base64UrlDecode = (input: string): Buffer => {
const normalized = input.replace(/-/g, "+").replace(/_/g, "/");
const padded = normalized + "=".repeat((4 - (normalized.length % 4)) % 4);
return Buffer.from(padded, "base64");
};
type CsrfTokenPayload = {
/** Issued-at timestamp (ms since epoch) */
ts: number;
/** Random nonce (base64url) */
nonce: string;
};
const signCsrfToken = (clientId: string, payload: CsrfTokenPayload): Buffer => {
const secret = getCsrfSecret();
const data = `${clientId}|${payload.ts}|${payload.nonce}`;
return crypto.createHmac("sha256", secret).update(data, "utf8").digest();
};
/**
* Generate a cryptographically secure CSRF token
*/
export const generateCsrfToken = (): string => {
return crypto.randomBytes(CSRF_TOKEN_LENGTH).toString("hex");
};
/**
* Create and store a new CSRF token for a client
* Returns the token to be sent to the client
*/
export const createCsrfToken = (clientId: string): string => {
const payload: CsrfTokenPayload = {
ts: Date.now(),
nonce: base64UrlEncode(crypto.randomBytes(CSRF_NONCE_BYTES)),
};
const payloadJson = JSON.stringify(payload);
const payloadB64 = base64UrlEncode(payloadJson);
const sigB64 = base64UrlEncode(signCsrfToken(clientId, payload));
return `${payloadB64}.${sigB64}`;
};
/**
* Validate a CSRF token for a client
* Uses timing-safe comparison to prevent timing attacks
*/
export const validateCsrfToken = (clientId: string, token: string): boolean => {
if (!token || typeof token !== "string") {
return false;
}
if (token.length > CSRF_TOKEN_MAX_LENGTH) {
return false;
}
try {
const parts = token.split(".");
if (parts.length !== 2) return false;
const [payloadB64, sigB64] = parts;
const payloadJson = base64UrlDecode(payloadB64).toString("utf8");
const payload = JSON.parse(payloadJson) as Partial<CsrfTokenPayload>;
if (
typeof payload.ts !== "number" ||
!Number.isFinite(payload.ts) ||
typeof payload.nonce !== "string" ||
payload.nonce.length < 8
) {
return false;
}
const now = Date.now();
// Expiry check
if (now - payload.ts > CSRF_TOKEN_EXPIRY_MS) return false;
// Future skew check (clock mismatch)
if (payload.ts - now > CSRF_TOKEN_FUTURE_SKEW_MS) return false;
const expectedSig = signCsrfToken(clientId, {
ts: payload.ts,
nonce: payload.nonce,
});
const providedSig = base64UrlDecode(sigB64);
if (providedSig.length !== expectedSig.length) return false;
return crypto.timingSafeEqual(providedSig, expectedSig);
} catch {
return false;
}
};
/**
* Revoke a CSRF token (e.g., on logout or token refresh)
*/
export const revokeCsrfToken = (clientId: string): void => {
// Stateless CSRF tokens cannot be selectively revoked without shared state.
// If revocation is required, implement token blacklisting in a shared store
// (e.g., Redis) or rotate CSRF_SECRET.
void clientId;
};
/**
* Get the CSRF token header name
*/
export const getCsrfTokenHeader = (): string => {
return CSRF_TOKEN_HEADER;
};
export const getOriginFromReferer = (referer: unknown): string | null => {
if (typeof referer !== "string" || referer.trim().length === 0) {
return null;
}
try {
const url = new URL(referer);
if (url.protocol !== "http:" && url.protocol !== "https:") {
return null;
}
return `${url.protocol}//${url.host}`;
} catch {
return null;
}
};
+2 -4
View File
@@ -1,6 +1,4 @@
import { defineConfig } from "vitest/config";
export default defineConfig({
export default {
test: {
globals: true,
environment: "node",
@@ -20,4 +18,4 @@ export default defineConfig({
},
},
},
});
};
+2
View File
@@ -6,6 +6,8 @@ services:
- DATABASE_URL=file:/app/prisma/dev.db
- PORT=8000
- NODE_ENV=production
# Required for horizontal scaling (k8s): must be the same across all instances
- CSRF_SECRET=${CSRF_SECRET}
volumes:
- backend-data:/app/prisma
networks:
+2
View File
@@ -8,6 +8,8 @@ services:
- DATABASE_URL=file:/app/prisma/dev.db
- PORT=8000
- NODE_ENV=production
# Required for horizontal scaling (k8s): must be the same across all instances
- CSRF_SECRET=${CSRF_SECRET}
volumes:
- backend-data:/app/prisma
networks:
+1 -1
View File
@@ -1,5 +1,5 @@
# Playwright E2E Test Runner
FROM mcr.microsoft.com/playwright:v1.52.0-noble
FROM mcr.microsoft.com/playwright:v1.57.0-noble
WORKDIR /app
+13 -8
View File
@@ -17,14 +17,18 @@ services:
context: ../backend
dockerfile: Dockerfile
environment:
- DATABASE_URL=file:./prisma/e2e-test.db
# Use an absolute sqlite path so Prisma CLI + the running app always point
# at the same DB file (avoids schema being applied to a different relative path).
- DATABASE_URL=file:/app/prisma/e2e-test.db
- PORT=8000
- NODE_ENV=test
- FRONTEND_URL=http://frontend:80,http://localhost:5173
# Include both with and without :80 because browsers omit default ports in Origin.
- FRONTEND_URL=http://frontend,http://frontend:80,http://localhost:5173
ports:
- "8000:8000"
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:8000/health"]
# Use IPv4 loopback explicitly to avoid IPv6 localhost resolution issues.
test: ["CMD", "wget", "-q", "--spider", "http://127.0.0.1:8000/health"]
interval: 5s
timeout: 5s
retries: 10
@@ -35,17 +39,18 @@ services:
# Frontend web server
frontend:
build:
context: ../frontend
dockerfile: Dockerfile
args:
- VITE_API_URL=http://backend:8000
# Use the repo root as build context because `frontend/Dockerfile` expects
# `frontend/...` paths (same as production `docker-compose.yml`).
context: ..
dockerfile: frontend/Dockerfile
ports:
- "5173:80"
depends_on:
backend:
condition: service_healthy
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:80"]
# Use IPv4 loopback explicitly to avoid IPv6 localhost resolution issues.
test: ["CMD", "wget", "-q", "--spider", "http://127.0.0.1:80"]
interval: 5s
timeout: 5s
retries: 10
+13 -3
View File
@@ -33,11 +33,18 @@ export default defineConfig({
// Reporter configuration
reporter: [
["list"],
["html", { outputFolder: "playwright-report" }],
[
"html",
{
// Useful when a previous Docker run produced root-owned artifacts.
// Allows local runs to redirect output without editing the config.
outputFolder: process.env.PLAYWRIGHT_REPORT_DIR || "playwright-report",
},
],
],
// Output folder for test artifacts
outputDir: "test-results",
outputDir: process.env.PLAYWRIGHT_OUTPUT_DIR || "test-results",
// Global timeout for each test
timeout: 60000,
@@ -85,8 +92,11 @@ export default defineConfig({
stdout: "pipe",
stderr: "pipe",
env: {
DATABASE_URL: "file:./prisma/dev.db",
// Prisma resolves relative SQLite paths from the schema directory (backend/prisma).
// Using `file:./dev.db` avoids accidentally creating `prisma/prisma/dev.db`.
DATABASE_URL: "file:./dev.db",
FRONTEND_URL,
CSRF_MAX_REQUESTS: "1000",
},
},
{
+2 -3
View File
@@ -1,6 +1,5 @@
import { test, expect, type BrowserContext, type Page } from "@playwright/test";
import { test, expect } from "@playwright/test";
import {
API_URL,
createDrawing,
deleteDrawing,
getDrawing,
@@ -22,7 +21,7 @@ test.describe("Real-time Collaboration", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
+2 -2
View File
@@ -45,7 +45,7 @@ test.describe("Dashboard Workflows", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (error) {
} catch {
// Ignore cleanup failures to keep tests resilient
}
}
@@ -54,7 +54,7 @@ test.describe("Dashboard Workflows", () => {
for (const id of createdCollectionIds) {
try {
await deleteCollection(request, id);
} catch (error) {
} catch {
// Ignore cleanup failures to keep tests resilient
}
}
+5 -6
View File
@@ -2,7 +2,6 @@ import { test, expect } from "@playwright/test";
import * as path from "path";
import * as fs from "fs";
import {
API_URL,
createDrawing,
deleteDrawing,
listDrawings,
@@ -27,7 +26,7 @@ test.describe("Drag and Drop - Collections", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -36,7 +35,7 @@ test.describe("Drag and Drop - Collections", () => {
for (const id of createdCollectionIds) {
try {
await deleteCollection(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -186,7 +185,7 @@ test.describe("Drag and Drop - File Import", () => {
for (const drawing of drawings) {
try {
await deleteDrawing(request, drawing.id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -194,7 +193,7 @@ test.describe("Drag and Drop - File Import", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -255,7 +254,7 @@ test.describe("Drag and Drop - File Import", () => {
}
});
test("should import excalidraw file via file input", async ({ page, request }, testInfo) => {
test("should import excalidraw file via file input", async ({ page }, testInfo) => {
await page.goto("/");
await page.waitForLoadState("networkidle");
+5 -4
View File
@@ -1,5 +1,6 @@
import { test, expect } from "@playwright/test";
import {
API_URL,
createDrawing,
deleteDrawing,
getDrawing,
@@ -24,7 +25,7 @@ test.describe("Drawing Creation", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -150,7 +151,7 @@ test.describe("Drawing Editing", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -320,7 +321,7 @@ test.describe("Drawing Deletion", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -388,7 +389,7 @@ test.describe("Drawing Deletion", () => {
await expect(card).not.toBeVisible();
// Verify via API that drawing is deleted
const response = await request.get(`http://localhost:8000/drawings/${drawing.id}`);
const response = await request.get(`${API_URL}/drawings/${drawing.id}`);
expect(response.status()).toBe(404);
// Remove from cleanup list since it's already deleted
+9 -11
View File
@@ -1,12 +1,10 @@
import { test, expect } from "@playwright/test";
import * as fs from "fs";
import * as path from "path";
import {
API_URL,
createDrawing,
deleteDrawing,
getCsrfHeaders,
listDrawings,
createCollection,
deleteCollection,
} from "./helpers/api";
@@ -29,7 +27,7 @@ test.describe("Export Functionality", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -38,7 +36,7 @@ test.describe("Export Functionality", () => {
for (const id of createdCollectionIds) {
try {
await deleteCollection(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -137,7 +135,7 @@ test.describe.serial("Import Functionality", () => {
for (const drawing of testDrawings) {
try {
await deleteDrawing(request, drawing.id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -145,7 +143,7 @@ test.describe.serial("Import Functionality", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -161,7 +159,7 @@ test.describe.serial("Import Functionality", () => {
await expect(importButton).toBeVisible();
});
test("should import .excalidraw file from Dashboard", async ({ page, request }) => {
test("should import .excalidraw file from Dashboard", async ({ page }) => {
await page.goto("/");
await page.waitForLoadState("networkidle");
@@ -206,8 +204,7 @@ test.describe.serial("Import Functionality", () => {
});
// Write temp file
const tempDir = "/tmp";
const tempFile = `${tempDir}/Import_Test_${Date.now()}.excalidraw`;
// tempFile was here
// Use page.evaluate to check if we can proceed
// Actually, Playwright has setInputFiles which can handle this
@@ -237,7 +234,7 @@ test.describe.serial("Import Functionality", () => {
await expect(importedCards.first()).toBeVisible({ timeout: 10000 });
});
test("should import JSON drawing file from Dashboard", async ({ page, request }) => {
test("should import JSON drawing file from Dashboard", async ({ page }) => {
await page.goto("/");
await page.waitForLoadState("networkidle");
@@ -394,6 +391,7 @@ test.describe("Database Import Verification", () => {
// Test that the verification endpoint responds
// We don't actually import a database as that would affect the test environment
const response = await request.post(`${API_URL}/import/sqlite/verify`, {
headers: await getCsrfHeaders(request),
// Send empty form data to test endpoint exists
multipart: {
db: {
+141 -6
View File
@@ -5,6 +5,91 @@ const DEFAULT_BACKEND_PORT = 8000;
export const API_URL = process.env.API_URL || `http://localhost:${DEFAULT_BACKEND_PORT}`;
type CsrfTokenResponse = {
token: string;
header?: string;
};
type CsrfInfo = {
token: string;
headerName: string;
};
// Cache CSRF tokens per Playwright request context so parallel tests don't race.
const csrfInfoByRequest = new WeakMap<APIRequestContext, CsrfInfo>();
const csrfFetchByRequest = new WeakMap<APIRequestContext, Promise<CsrfInfo>>();
const fetchCsrfInfo = async (request: APIRequestContext): Promise<CsrfInfo> => {
const response = await request.get(`${API_URL}/csrf-token`);
if (!response.ok()) {
const text = await response.text();
throw new Error(
`Failed to fetch CSRF token: ${response.status()} ${text || "(empty response)"}`
);
}
const data = (await response.json()) as CsrfTokenResponse;
if (!data || typeof data.token !== "string" || data.token.trim().length === 0) {
throw new Error("Failed to fetch CSRF token: missing token in response");
}
const headerName =
typeof data.header === "string" && data.header.trim().length > 0
? data.header
: "x-csrf-token";
return { token: data.token, headerName };
};
const getCsrfInfo = async (request: APIRequestContext): Promise<CsrfInfo> => {
const cached = csrfInfoByRequest.get(request);
if (cached) return cached;
const inFlight = csrfFetchByRequest.get(request);
if (inFlight) return inFlight;
const promise = fetchCsrfInfo(request)
.then((info) => {
csrfInfoByRequest.set(request, info);
return info;
})
.finally(() => {
csrfFetchByRequest.delete(request);
});
csrfFetchByRequest.set(request, promise);
return promise;
};
const refreshCsrfInfo = async (request: APIRequestContext): Promise<CsrfInfo> => {
const promise = fetchCsrfInfo(request)
.then((info) => {
csrfInfoByRequest.set(request, info);
return info;
})
.finally(() => {
csrfFetchByRequest.delete(request);
});
csrfFetchByRequest.set(request, promise);
return promise;
};
export async function getCsrfHeaders(
request: APIRequestContext
): Promise<Record<string, string>> {
const info = await getCsrfInfo(request);
return { [info.headerName]: info.token };
}
const withCsrfHeaders = async (
request: APIRequestContext,
headers: Record<string, string> = {}
): Promise<Record<string, string>> => ({
...headers,
...(await getCsrfHeaders(request)),
});
export interface DrawingRecord {
id: string;
name: string;
@@ -53,10 +138,26 @@ export async function createDrawing(
overrides: CreateDrawingOptions = {}
): Promise<DrawingRecord> {
const payload = { ...defaultDrawingPayload(), ...overrides };
const response = await request.post(`${API_URL}/drawings`, {
headers: { "Content-Type": "application/json" },
const headers = await withCsrfHeaders(request, { "Content-Type": "application/json" });
let response = await request.post(`${API_URL}/drawings`, {
headers,
data: payload,
});
// Retry once with a fresh token in case it expired or the cache was primed under
// a different clientId (rare, but can happen under parallelism / CI proxies).
if (!response.ok() && response.status() === 403) {
await refreshCsrfInfo(request);
const retryHeaders = await withCsrfHeaders(request, {
"Content-Type": "application/json",
});
response = await request.post(`${API_URL}/drawings`, {
headers: retryHeaders,
data: payload,
});
}
if (!response.ok()) {
const text = await response.text();
throw new Error(`Failed to create drawing: ${response.status()} ${text}`);
@@ -77,7 +178,17 @@ export async function deleteDrawing(
request: APIRequestContext,
id: string
): Promise<void> {
const response = await request.delete(`${API_URL}/drawings/${id}`);
const headers = await withCsrfHeaders(request);
let response = await request.delete(`${API_URL}/drawings/${id}`, { headers });
if (!response.ok() && response.status() === 403) {
await refreshCsrfInfo(request);
const retryHeaders = await withCsrfHeaders(request);
response = await request.delete(`${API_URL}/drawings/${id}`, {
headers: retryHeaders,
});
}
if (!response.ok()) {
// Ignore not found to keep cleanup idempotent
if (response.status() !== 404) {
@@ -113,10 +224,24 @@ export async function createCollection(
request: APIRequestContext,
name: string
): Promise<CollectionRecord> {
const response = await request.post(`${API_URL}/collections`, {
headers: { "Content-Type": "application/json" },
const headers = await withCsrfHeaders(request, { "Content-Type": "application/json" });
let response = await request.post(`${API_URL}/collections`, {
headers,
data: { name },
});
if (!response.ok() && response.status() === 403) {
await refreshCsrfInfo(request);
const retryHeaders = await withCsrfHeaders(request, {
"Content-Type": "application/json",
});
response = await request.post(`${API_URL}/collections`, {
headers: retryHeaders,
data: { name },
});
}
expect(response.ok()).toBe(true);
return (await response.json()) as CollectionRecord;
}
@@ -133,7 +258,17 @@ export async function deleteCollection(
request: APIRequestContext,
id: string
): Promise<void> {
const response = await request.delete(`${API_URL}/collections/${id}`);
const headers = await withCsrfHeaders(request);
let response = await request.delete(`${API_URL}/collections/${id}`, { headers });
if (!response.ok() && response.status() === 403) {
await refreshCsrfInfo(request);
const retryHeaders = await withCsrfHeaders(request);
response = await request.delete(`${API_URL}/collections/${id}`, {
headers: retryHeaders,
});
}
if (!response.ok()) {
if (response.status() !== 404) {
const text = await response.text();
+17 -5
View File
@@ -1,7 +1,13 @@
import { test, expect } from "@playwright/test";
import * as fs from "fs";
import * as path from "path";
import { API_URL, createDrawing, deleteDrawing, getDrawing } from "./helpers/api";
import {
API_URL,
createDrawing,
deleteDrawing,
getCsrfHeaders,
getDrawing,
} from "./helpers/api";
/**
* E2E Browser Tests for Image Persistence - Issue #17 Regression
@@ -34,7 +40,7 @@ test.describe("Image Persistence - Browser E2E Tests", () => {
for (const id of testDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -119,7 +125,7 @@ test.describe("Image Persistence - Browser E2E Tests", () => {
await expect(editorContainer).toBeVisible({ timeout: 10000 });
});
test("should import .excalidraw file with embedded image", async ({ page, request }) => {
test("should import .excalidraw file with embedded image", async ({ request }) => {
// Load the test fixture
const fixturePath = path.join(__dirname, "..", "fixtures", "small-image.excalidraw");
const fixtureContent = fs.readFileSync(fixturePath, "utf-8");
@@ -196,6 +202,7 @@ test.describe("Security - Malicious Content Blocking", () => {
const response = await request.post(`${API_URL}/drawings`, {
headers: {
"Content-Type": "application/json",
...(await getCsrfHeaders(request)),
},
data: {
name: "Security Test - JS URL",
@@ -218,7 +225,9 @@ test.describe("Security - Malicious Content Blocking", () => {
expect(savedFiles["malicious-image"].dataURL).not.toContain("javascript:");
// Cleanup
await request.delete(`${API_URL}/drawings/${drawing.id}`);
await request.delete(`${API_URL}/drawings/${drawing.id}`, {
headers: await getCsrfHeaders(request),
});
});
test("should block script tags in image data", async ({ request }) => {
@@ -234,6 +243,7 @@ test.describe("Security - Malicious Content Blocking", () => {
const response = await request.post(`${API_URL}/drawings`, {
headers: {
"Content-Type": "application/json",
...(await getCsrfHeaders(request)),
},
data: {
name: "Security Test - Script Tag",
@@ -256,6 +266,8 @@ test.describe("Security - Malicious Content Blocking", () => {
expect(savedFiles["malicious-image"].dataURL).not.toContain("<script>");
// Cleanup
await request.delete(`${API_URL}/drawings/${drawing.id}`);
await request.delete(`${API_URL}/drawings/${drawing.id}`, {
headers: await getCsrfHeaders(request),
});
});
});
+2 -3
View File
@@ -2,7 +2,6 @@ import { test, expect } from "@playwright/test";
import {
createDrawing,
deleteDrawing,
listDrawings,
} from "./helpers/api";
/**
@@ -21,7 +20,7 @@ test.describe("Search Drawings", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
@@ -135,7 +134,7 @@ test.describe("Sort Drawings", () => {
for (const id of createdDrawingIds) {
try {
await deleteDrawing(request, id);
} catch (e) {
} catch {
// Ignore cleanup errors
}
}
+1 -1
View File
@@ -22,7 +22,7 @@ export default defineConfig({
// Locally, you may need to start them manually or use npm run dev
webServer: process.env.CI ? [
{
command: "cd ../backend && DATABASE_URL=file:./prisma/dev.db npm run dev",
command: "cd ../backend && DATABASE_URL=file:./dev.db npm run dev",
url: "http://localhost:8000/health",
reuseExistingServer: false,
timeout: 120000,
+85
View File
@@ -7,6 +7,91 @@ export const api = axios.create({
baseURL: API_URL,
});
// CSRF Token Management
let csrfToken: string | null = null;
let csrfHeaderName: string = "x-csrf-token";
let csrfTokenPromise: Promise<void> | null = null;
/**
* Fetch a fresh CSRF token from the server
*/
export const fetchCsrfToken = async (): Promise<void> => {
try {
const response = await axios.get<{ token: string; header: string }>(
`${API_URL}/csrf-token`
);
csrfToken = response.data.token;
csrfHeaderName = response.data.header || "x-csrf-token";
} catch (error) {
console.error("Failed to fetch CSRF token:", error);
throw error;
}
};
/**
* Ensure we have a valid CSRF token, fetching one if needed
*/
const ensureCsrfToken = async (): Promise<void> => {
if (csrfToken) return;
// Prevent multiple simultaneous token fetches
if (!csrfTokenPromise) {
csrfTokenPromise = fetchCsrfToken().finally(() => {
csrfTokenPromise = null;
});
}
await csrfTokenPromise;
};
/**
* Clear the cached CSRF token (useful for handling 403 errors)
*/
export const clearCsrfToken = (): void => {
csrfToken = null;
};
// Add request interceptor to include CSRF token
api.interceptors.request.use(
async (config) => {
// Only add CSRF token for state-changing methods
const method = config.method?.toUpperCase();
if (method && ["POST", "PUT", "DELETE", "PATCH"].includes(method)) {
await ensureCsrfToken();
if (csrfToken) {
config.headers[csrfHeaderName] = csrfToken;
}
}
return config;
},
(error) => Promise.reject(error)
);
// Add response interceptor to handle CSRF token errors
api.interceptors.response.use(
(response) => response,
async (error) => {
// If we get a 403 with CSRF error, clear token and retry once
if (
error.response?.status === 403 &&
error.response?.data?.error?.includes("CSRF")
) {
clearCsrfToken();
// Retry the request once with a fresh token
const originalRequest = error.config;
if (!originalRequest._csrfRetry) {
originalRequest._csrfRetry = true;
await fetchCsrfToken();
if (csrfToken) {
originalRequest.headers[csrfHeaderName] = csrfToken;
}
return api(originalRequest);
}
}
return Promise.reject(error);
}
);
const coerceTimestamp = (value: string | number | Date): number => {
if (typeof value === "number") return value;
if (value instanceof Date) return value.getTime();
+9 -8
View File
@@ -1,5 +1,5 @@
import { exportToSvg } from "@excalidraw/excalidraw";
import { API_URL } from "../api";
import { api } from "../api";
export const importDrawings = async (
files: File[],
@@ -50,21 +50,22 @@ export const importDrawings = async (
preview: svg.outerHTML,
};
const res = await fetch(`${API_URL}/drawings`, {
method: "POST",
await api.post("/drawings", payload, {
headers: {
"Content-Type": "application/json",
// Backend uses this header to apply stricter validation for imported files.
"X-Imported-File": "true",
},
body: JSON.stringify(payload),
});
if (!res.ok) throw new Error("API Error");
successCount++;
} catch (err: any) {
console.error(`Failed to import ${file.name}:`, err);
failCount++;
errors.push(`${file.name}: ${err.message}`);
const apiMessage =
err?.response?.data?.message ||
err?.response?.data?.error ||
err?.message ||
"API Error";
errors.push(`${file.name}: ${apiMessage}`);
}
})
);