sign CSRF with cookie, Login rate-limit key hardened against identifier-only lockout

This commit is contained in:
Zimeng Xiong
2026-02-07 18:52:00 -08:00
parent fd013de325
commit 70103e18fb
6 changed files with 104 additions and 24 deletions
@@ -11,6 +11,7 @@ describe("Auth Enabled Toggle Authorization", () => {
const userAgent = "vitest-auth-enabled"; const userAgent = "vitest-auth-enabled";
let prisma: PrismaClient; let prisma: PrismaClient;
let app: any; let app: any;
let agent: any;
let csrfHeaderName: string; let csrfHeaderName: string;
let csrfToken: string; let csrfToken: string;
let regularUserToken: string; let regularUserToken: string;
@@ -79,7 +80,7 @@ describe("Auth Enabled Toggle Authorization", () => {
signOptions signOptions
); );
const agent = request.agent(app); agent = request.agent(app);
const csrfRes = await agent const csrfRes = await agent
.get("/csrf-token") .get("/csrf-token")
.set("User-Agent", userAgent); .set("User-Agent", userAgent);
@@ -92,7 +93,7 @@ describe("Auth Enabled Toggle Authorization", () => {
}); });
it("rejects unauthenticated auth-enabled toggle when auth is enabled", async () => { it("rejects unauthenticated auth-enabled toggle when auth is enabled", async () => {
const response = await request(app) const response = await agent
.post("/auth/auth-enabled") .post("/auth/auth-enabled")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -102,7 +103,7 @@ describe("Auth Enabled Toggle Authorization", () => {
}); });
it("rejects non-admin auth-enabled toggle", async () => { it("rejects non-admin auth-enabled toggle", async () => {
const response = await request(app) const response = await agent
.post("/auth/auth-enabled") .post("/auth/auth-enabled")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set("Authorization", `Bearer ${regularUserToken}`) .set("Authorization", `Bearer ${regularUserToken}`)
@@ -120,7 +121,7 @@ describe("Auth Enabled Toggle Authorization", () => {
expect(warmStatusResponse.status).toBe(200); expect(warmStatusResponse.status).toBe(200);
expect(warmStatusResponse.body?.authEnabled).toBe(true); expect(warmStatusResponse.body?.authEnabled).toBe(true);
const toggleResponse = await request(app) const toggleResponse = await agent
.post("/auth/auth-enabled") .post("/auth/auth-enabled")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set("Authorization", `Bearer ${adminUserToken}`) .set("Authorization", `Bearer ${adminUserToken}`)
@@ -267,6 +267,7 @@ describe("Import compatibility (legacy exports)", () => {
const userAgent = "vitest-import-compat"; const userAgent = "vitest-import-compat";
let prisma: ReturnType<typeof getTestPrisma>; let prisma: ReturnType<typeof getTestPrisma>;
let app: any; let app: any;
let agent: any;
let csrfHeaderName: string; let csrfHeaderName: string;
let csrfToken: string; let csrfToken: string;
@@ -278,7 +279,8 @@ describe("Import compatibility (legacy exports)", () => {
// Import the server AFTER DATABASE_URL is set by setupTestDb/getTestPrisma. // Import the server AFTER DATABASE_URL is set by setupTestDb/getTestPrisma.
({ app } = await import("../index")); ({ app } = await import("../index"));
const csrfRes = await request(app).get("/csrf-token").set("User-Agent", userAgent); agent = request.agent(app);
const csrfRes = await agent.get("/csrf-token").set("User-Agent", userAgent);
csrfHeaderName = csrfRes.body.header; csrfHeaderName = csrfRes.body.header;
csrfToken = csrfRes.body.token; csrfToken = csrfRes.body.token;
expect(typeof csrfHeaderName).toBe("string"); expect(typeof csrfHeaderName).toBe("string");
@@ -301,7 +303,7 @@ describe("Import compatibility (legacy exports)", () => {
includeTrashDrawing: false, includeTrashDrawing: false,
}); });
const res = await request(app) const res = await agent
.post("/import/sqlite/legacy/verify") .post("/import/sqlite/legacy/verify")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -323,7 +325,7 @@ describe("Import compatibility (legacy exports)", () => {
includeTrashDrawing: true, includeTrashDrawing: true,
}); });
const res = await request(app) const res = await agent
.post("/import/sqlite/legacy") .post("/import/sqlite/legacy")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -359,7 +361,7 @@ describe("Import compatibility (legacy exports)", () => {
includeTrashDrawing: false, includeTrashDrawing: false,
}); });
const verify = await request(app) const verify = await agent
.post("/import/sqlite/legacy/verify") .post("/import/sqlite/legacy/verify")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -369,7 +371,7 @@ describe("Import compatibility (legacy exports)", () => {
expect(verify.body.drawings).toBe(2); expect(verify.body.drawings).toBe(2);
expect(verify.body.collections).toBe(1); expect(verify.body.collections).toBe(1);
const res = await request(app) const res = await agent
.post("/import/sqlite/legacy") .post("/import/sqlite/legacy")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -386,7 +388,7 @@ describe("Import compatibility (legacy exports)", () => {
db.exec(`CREATE TABLE "NotDrawing" (id TEXT PRIMARY KEY NOT NULL);`); db.exec(`CREATE TABLE "NotDrawing" (id TEXT PRIMARY KEY NOT NULL);`);
db.close(); db.close();
const res = await request(app) const res = await agent
.post("/import/sqlite/legacy/verify") .post("/import/sqlite/legacy/verify")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -398,7 +400,7 @@ describe("Import compatibility (legacy exports)", () => {
it("rejects .excalidash verify when manifest has duplicate drawing IDs", async () => { it("rejects .excalidash verify when manifest has duplicate drawing IDs", async () => {
const archive = await createExcalidashArchiveWithDuplicateDrawingIds(); const archive = await createExcalidashArchiveWithDuplicateDrawingIds();
const res = await request(app) const res = await agent
.post("/import/excalidash/verify") .post("/import/excalidash/verify")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -410,7 +412,7 @@ describe("Import compatibility (legacy exports)", () => {
it("rejects .excalidash import when manifest has duplicate drawing IDs", async () => { it("rejects .excalidash import when manifest has duplicate drawing IDs", async () => {
const archive = await createExcalidashArchiveWithDuplicateDrawingIds(); const archive = await createExcalidashArchiveWithDuplicateDrawingIds();
const res = await request(app) const res = await agent
.post("/import/excalidash") .post("/import/excalidash")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -422,7 +424,7 @@ describe("Import compatibility (legacy exports)", () => {
it("rejects legacy verify when DB has duplicate drawing IDs", async () => { it("rejects legacy verify when DB has duplicate drawing IDs", async () => {
const legacyDb = createLegacySqliteDbWithDuplicateDrawingIds(); const legacyDb = createLegacySqliteDbWithDuplicateDrawingIds();
const res = await request(app) const res = await agent
.post("/import/sqlite/legacy/verify") .post("/import/sqlite/legacy/verify")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
@@ -434,7 +436,7 @@ describe("Import compatibility (legacy exports)", () => {
it("rejects legacy import when DB has duplicate drawing IDs", async () => { it("rejects legacy import when DB has duplicate drawing IDs", async () => {
const legacyDb = createLegacySqliteDbWithDuplicateDrawingIds(); const legacyDb = createLegacySqliteDbWithDuplicateDrawingIds();
const res = await request(app) const res = await agent
.post("/import/sqlite/legacy") .post("/import/sqlite/legacy")
.set("User-Agent", userAgent) .set("User-Agent", userAgent)
.set(csrfHeaderName, csrfToken) .set(csrfHeaderName, csrfToken)
+40 -3
View File
@@ -81,6 +81,7 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router =>
let loginRateLimitConfig: LoginRateLimitConfig = { ...DEFAULT_LOGIN_RATE_LIMIT }; let loginRateLimitConfig: LoginRateLimitConfig = { ...DEFAULT_LOGIN_RATE_LIMIT };
let loginAttemptLimiter: ReturnType<typeof rateLimit> | null = null; let loginAttemptLimiter: ReturnType<typeof rateLimit> | null = null;
let loginLimiterInitPromise: Promise<void> | null = null; let loginLimiterInitPromise: Promise<void> | null = null;
let loginIdentifierKeyIndex = new Map<string, Set<string>>();
const parseLoginRateLimitConfig = ( const parseLoginRateLimitConfig = (
systemConfig: Awaited<ReturnType<typeof ensureSystemConfig>> systemConfig: Awaited<ReturnType<typeof ensureSystemConfig>>
@@ -114,8 +115,31 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router =>
return trimmed.length > 0 ? trimmed.slice(0, 255) : null; return trimmed.length > 0 ? trimmed.slice(0, 255) : null;
}; };
const resolveRateLimitIp = (req: Request): string =>
(req.ip || req.connection.remoteAddress || "unknown").slice(0, 255);
const trackIdentifierRateLimitKey = (identifier: string, key: string): void => {
if (!loginIdentifierKeyIndex.has(identifier) && loginIdentifierKeyIndex.size >= 5000) {
const oldestIdentifier = loginIdentifierKeyIndex.keys().next().value;
if (typeof oldestIdentifier === "string") {
loginIdentifierKeyIndex.delete(oldestIdentifier);
}
}
const existing = loginIdentifierKeyIndex.get(identifier) ?? new Set<string>();
if (existing.size >= 50) {
const oldestKey = existing.values().next().value;
if (typeof oldestKey === "string") {
existing.delete(oldestKey);
}
}
existing.add(key);
loginIdentifierKeyIndex.set(identifier, existing);
};
const buildLoginAttemptLimiter = (cfg: LoginRateLimitConfig) => { const buildLoginAttemptLimiter = (cfg: LoginRateLimitConfig) => {
const store = new MemoryStore(); const store = new MemoryStore();
loginIdentifierKeyIndex = new Map<string, Set<string>>();
const limiter = rateLimit({ const limiter = rateLimit({
windowMs: cfg.windowMs, windowMs: cfg.windowMs,
max: cfg.max, max: cfg.max,
@@ -131,8 +155,12 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router =>
store, store,
keyGenerator: (req) => { keyGenerator: (req) => {
const identifier = resolveAuthIdentifier(req as Request); const identifier = resolveAuthIdentifier(req as Request);
if (identifier) return `login:${identifier}`; const ip = resolveRateLimitIp(req as Request);
const ip = (req as Request).ip || "unknown"; if (identifier) {
const key = `login:${identifier}:ip:${ip}`;
trackIdentifierRateLimitKey(identifier, key);
return key;
}
return `login-ip:${ip}`; return `login-ip:${ip}`;
}, },
}); });
@@ -171,9 +199,18 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router =>
const resetLoginAttemptKey = async (identifier: string): Promise<void> => { const resetLoginAttemptKey = async (identifier: string): Promise<void> => {
await ensureLoginAttemptLimiter(); await ensureLoginAttemptLimiter();
const key = `login:${identifier}`; const normalizedIdentifier = identifier.trim().toLowerCase();
const keys = loginIdentifierKeyIndex.get(normalizedIdentifier);
try { try {
if (!keys || keys.size === 0) {
// Backward-compatible fallback for pre-change key format.
await loginAttemptLimiter?.resetKey(`login:${normalizedIdentifier}`);
return;
}
for (const key of keys) {
await loginAttemptLimiter?.resetKey(key); await loginAttemptLimiter?.resetKey(key);
}
loginIdentifierKeyIndex.delete(normalizedIdentifier);
} catch (error) { } catch (error) {
if (process.env.NODE_ENV === "development") { if (process.env.NODE_ENV === "development") {
console.debug("Rate limit reset skipped:", error); console.debug("Rate limit reset skipped:", error);
+41
View File
@@ -0,0 +1,41 @@
import express from "express";
import request from "supertest";
import { describe, expect, it } from "vitest";
import { registerCsrfProtection } from "./csrf";
describe("CSRF token issuance", () => {
it("binds first-issued tokens to cookie client identity", async () => {
const app = express();
app.use(express.json());
registerCsrfProtection({
app,
isAllowedOrigin: () => true,
maxRequestsPerWindow: 100,
});
app.post("/drawings", (_req, res) => {
res.status(200).json({ ok: true });
});
const agent = request.agent(app);
const csrfRes = await agent
.get("/csrf-token")
.set("User-Agent", "csrf-test-agent-a");
expect(csrfRes.status).toBe(200);
const headerName = csrfRes.body.header as string;
const token = csrfRes.body.token as string;
expect(typeof headerName).toBe("string");
expect(typeof token).toBe("string");
const postRes = await agent
.post("/drawings")
.set("User-Agent", "csrf-test-agent-b")
.set(headerName, token)
.send({ name: "test" });
expect(postRes.status).toBe(200);
expect(postRes.body.ok).toBe(true);
});
});
+3 -4
View File
@@ -10,7 +10,6 @@ import {
CSRF_CLIENT_COOKIE_NAME, CSRF_CLIENT_COOKIE_NAME,
getCsrfClientCookieValue, getCsrfClientCookieValue,
getCsrfValidationClientIds, getCsrfValidationClientIds,
getLegacyClientId,
} from "../security/csrfClient"; } from "../security/csrfClient";
const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days
@@ -53,7 +52,7 @@ export const registerCsrfProtection = ({
const getClientIdForTokenIssue = ( const getClientIdForTokenIssue = (
req: express.Request, req: express.Request,
res: express.Response res: express.Response
): { clientId: string; strategy: "cookie" | "legacy-bootstrap" } => { ): { clientId: string; strategy: "cookie" } => {
const existingCookieValue = getCsrfClientCookieValue(req); const existingCookieValue = getCsrfClientCookieValue(req);
if (existingCookieValue) { if (existingCookieValue) {
return { return {
@@ -65,8 +64,8 @@ export const registerCsrfProtection = ({
const generatedCookieValue = crypto.randomUUID().replace(/-/g, ""); const generatedCookieValue = crypto.randomUUID().replace(/-/g, "");
setCsrfClientCookie(req, res, generatedCookieValue); setCsrfClientCookie(req, res, generatedCookieValue);
return { return {
clientId: getLegacyClientId(req), clientId: `cookie:${generatedCookieValue}`,
strategy: "legacy-bootstrap", strategy: "cookie",
}; };
}; };
+2 -2
View File
@@ -76,8 +76,8 @@ export const Settings: React.FC = () => {
); );
if (response.data.authEnabled) { if (response.data.authEnabled) {
// Auth enabled -> prompt admin bootstrap via register. // Auth enabled -> bootstrap registration only when required.
window.location.href = '/register'; window.location.href = response.data.bootstrapRequired ? '/register' : '/login';
return; return;
} }