diff --git a/backend/src/__tests__/auth-enabled.integration.ts b/backend/src/__tests__/auth-enabled.integration.ts index 0ac5785..87fe864 100644 --- a/backend/src/__tests__/auth-enabled.integration.ts +++ b/backend/src/__tests__/auth-enabled.integration.ts @@ -11,6 +11,7 @@ describe("Auth Enabled Toggle Authorization", () => { const userAgent = "vitest-auth-enabled"; let prisma: PrismaClient; let app: any; + let agent: any; let csrfHeaderName: string; let csrfToken: string; let regularUserToken: string; @@ -79,7 +80,7 @@ describe("Auth Enabled Toggle Authorization", () => { signOptions ); - const agent = request.agent(app); + agent = request.agent(app); const csrfRes = await agent .get("/csrf-token") .set("User-Agent", userAgent); @@ -92,7 +93,7 @@ describe("Auth Enabled Toggle Authorization", () => { }); it("rejects unauthenticated auth-enabled toggle when auth is enabled", async () => { - const response = await request(app) + const response = await agent .post("/auth/auth-enabled") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -102,7 +103,7 @@ describe("Auth Enabled Toggle Authorization", () => { }); it("rejects non-admin auth-enabled toggle", async () => { - const response = await request(app) + const response = await agent .post("/auth/auth-enabled") .set("User-Agent", userAgent) .set("Authorization", `Bearer ${regularUserToken}`) @@ -120,7 +121,7 @@ describe("Auth Enabled Toggle Authorization", () => { expect(warmStatusResponse.status).toBe(200); expect(warmStatusResponse.body?.authEnabled).toBe(true); - const toggleResponse = await request(app) + const toggleResponse = await agent .post("/auth/auth-enabled") .set("User-Agent", userAgent) .set("Authorization", `Bearer ${adminUserToken}`) diff --git a/backend/src/__tests__/imports-compat.integration.ts b/backend/src/__tests__/imports-compat.integration.ts index 6d32e98..ace4d69 100644 --- a/backend/src/__tests__/imports-compat.integration.ts +++ b/backend/src/__tests__/imports-compat.integration.ts @@ -267,6 +267,7 @@ describe("Import compatibility (legacy exports)", () => { const userAgent = "vitest-import-compat"; let prisma: ReturnType; let app: any; + let agent: any; let csrfHeaderName: string; let csrfToken: string; @@ -278,7 +279,8 @@ describe("Import compatibility (legacy exports)", () => { // Import the server AFTER DATABASE_URL is set by setupTestDb/getTestPrisma. ({ app } = await import("../index")); - const csrfRes = await request(app).get("/csrf-token").set("User-Agent", userAgent); + agent = request.agent(app); + const csrfRes = await agent.get("/csrf-token").set("User-Agent", userAgent); csrfHeaderName = csrfRes.body.header; csrfToken = csrfRes.body.token; expect(typeof csrfHeaderName).toBe("string"); @@ -301,7 +303,7 @@ describe("Import compatibility (legacy exports)", () => { includeTrashDrawing: false, }); - const res = await request(app) + const res = await agent .post("/import/sqlite/legacy/verify") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -323,7 +325,7 @@ describe("Import compatibility (legacy exports)", () => { includeTrashDrawing: true, }); - const res = await request(app) + const res = await agent .post("/import/sqlite/legacy") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -359,7 +361,7 @@ describe("Import compatibility (legacy exports)", () => { includeTrashDrawing: false, }); - const verify = await request(app) + const verify = await agent .post("/import/sqlite/legacy/verify") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -369,7 +371,7 @@ describe("Import compatibility (legacy exports)", () => { expect(verify.body.drawings).toBe(2); expect(verify.body.collections).toBe(1); - const res = await request(app) + const res = await agent .post("/import/sqlite/legacy") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -386,7 +388,7 @@ describe("Import compatibility (legacy exports)", () => { db.exec(`CREATE TABLE "NotDrawing" (id TEXT PRIMARY KEY NOT NULL);`); db.close(); - const res = await request(app) + const res = await agent .post("/import/sqlite/legacy/verify") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -398,7 +400,7 @@ describe("Import compatibility (legacy exports)", () => { it("rejects .excalidash verify when manifest has duplicate drawing IDs", async () => { const archive = await createExcalidashArchiveWithDuplicateDrawingIds(); - const res = await request(app) + const res = await agent .post("/import/excalidash/verify") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -410,7 +412,7 @@ describe("Import compatibility (legacy exports)", () => { it("rejects .excalidash import when manifest has duplicate drawing IDs", async () => { const archive = await createExcalidashArchiveWithDuplicateDrawingIds(); - const res = await request(app) + const res = await agent .post("/import/excalidash") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -422,7 +424,7 @@ describe("Import compatibility (legacy exports)", () => { it("rejects legacy verify when DB has duplicate drawing IDs", async () => { const legacyDb = createLegacySqliteDbWithDuplicateDrawingIds(); - const res = await request(app) + const res = await agent .post("/import/sqlite/legacy/verify") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) @@ -434,7 +436,7 @@ describe("Import compatibility (legacy exports)", () => { it("rejects legacy import when DB has duplicate drawing IDs", async () => { const legacyDb = createLegacySqliteDbWithDuplicateDrawingIds(); - const res = await request(app) + const res = await agent .post("/import/sqlite/legacy") .set("User-Agent", userAgent) .set(csrfHeaderName, csrfToken) diff --git a/backend/src/auth.ts b/backend/src/auth.ts index 8171cfe..1015eef 100644 --- a/backend/src/auth.ts +++ b/backend/src/auth.ts @@ -81,6 +81,7 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router => let loginRateLimitConfig: LoginRateLimitConfig = { ...DEFAULT_LOGIN_RATE_LIMIT }; let loginAttemptLimiter: ReturnType | null = null; let loginLimiterInitPromise: Promise | null = null; + let loginIdentifierKeyIndex = new Map>(); const parseLoginRateLimitConfig = ( systemConfig: Awaited> @@ -114,8 +115,31 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router => return trimmed.length > 0 ? trimmed.slice(0, 255) : null; }; + const resolveRateLimitIp = (req: Request): string => + (req.ip || req.connection.remoteAddress || "unknown").slice(0, 255); + + const trackIdentifierRateLimitKey = (identifier: string, key: string): void => { + if (!loginIdentifierKeyIndex.has(identifier) && loginIdentifierKeyIndex.size >= 5000) { + const oldestIdentifier = loginIdentifierKeyIndex.keys().next().value; + if (typeof oldestIdentifier === "string") { + loginIdentifierKeyIndex.delete(oldestIdentifier); + } + } + + const existing = loginIdentifierKeyIndex.get(identifier) ?? new Set(); + if (existing.size >= 50) { + const oldestKey = existing.values().next().value; + if (typeof oldestKey === "string") { + existing.delete(oldestKey); + } + } + existing.add(key); + loginIdentifierKeyIndex.set(identifier, existing); + }; + const buildLoginAttemptLimiter = (cfg: LoginRateLimitConfig) => { const store = new MemoryStore(); + loginIdentifierKeyIndex = new Map>(); const limiter = rateLimit({ windowMs: cfg.windowMs, max: cfg.max, @@ -131,8 +155,12 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router => store, keyGenerator: (req) => { const identifier = resolveAuthIdentifier(req as Request); - if (identifier) return `login:${identifier}`; - const ip = (req as Request).ip || "unknown"; + const ip = resolveRateLimitIp(req as Request); + if (identifier) { + const key = `login:${identifier}:ip:${ip}`; + trackIdentifierRateLimitKey(identifier, key); + return key; + } return `login-ip:${ip}`; }, }); @@ -171,9 +199,18 @@ export const createAuthRouter = (deps: CreateAuthRouterDeps): express.Router => const resetLoginAttemptKey = async (identifier: string): Promise => { await ensureLoginAttemptLimiter(); - const key = `login:${identifier}`; + const normalizedIdentifier = identifier.trim().toLowerCase(); + const keys = loginIdentifierKeyIndex.get(normalizedIdentifier); try { - await loginAttemptLimiter?.resetKey(key); + if (!keys || keys.size === 0) { + // Backward-compatible fallback for pre-change key format. + await loginAttemptLimiter?.resetKey(`login:${normalizedIdentifier}`); + return; + } + for (const key of keys) { + await loginAttemptLimiter?.resetKey(key); + } + loginIdentifierKeyIndex.delete(normalizedIdentifier); } catch (error) { if (process.env.NODE_ENV === "development") { console.debug("Rate limit reset skipped:", error); diff --git a/backend/src/server/csrf.test.ts b/backend/src/server/csrf.test.ts new file mode 100644 index 0000000..9eb9557 --- /dev/null +++ b/backend/src/server/csrf.test.ts @@ -0,0 +1,41 @@ +import express from "express"; +import request from "supertest"; +import { describe, expect, it } from "vitest"; +import { registerCsrfProtection } from "./csrf"; + +describe("CSRF token issuance", () => { + it("binds first-issued tokens to cookie client identity", async () => { + const app = express(); + app.use(express.json()); + + registerCsrfProtection({ + app, + isAllowedOrigin: () => true, + maxRequestsPerWindow: 100, + }); + + app.post("/drawings", (_req, res) => { + res.status(200).json({ ok: true }); + }); + + const agent = request.agent(app); + const csrfRes = await agent + .get("/csrf-token") + .set("User-Agent", "csrf-test-agent-a"); + + expect(csrfRes.status).toBe(200); + const headerName = csrfRes.body.header as string; + const token = csrfRes.body.token as string; + expect(typeof headerName).toBe("string"); + expect(typeof token).toBe("string"); + + const postRes = await agent + .post("/drawings") + .set("User-Agent", "csrf-test-agent-b") + .set(headerName, token) + .send({ name: "test" }); + + expect(postRes.status).toBe(200); + expect(postRes.body.ok).toBe(true); + }); +}); diff --git a/backend/src/server/csrf.ts b/backend/src/server/csrf.ts index 9b09360..03f9640 100644 --- a/backend/src/server/csrf.ts +++ b/backend/src/server/csrf.ts @@ -10,7 +10,6 @@ import { CSRF_CLIENT_COOKIE_NAME, getCsrfClientCookieValue, getCsrfValidationClientIds, - getLegacyClientId, } from "../security/csrfClient"; const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days @@ -53,7 +52,7 @@ export const registerCsrfProtection = ({ const getClientIdForTokenIssue = ( req: express.Request, res: express.Response - ): { clientId: string; strategy: "cookie" | "legacy-bootstrap" } => { + ): { clientId: string; strategy: "cookie" } => { const existingCookieValue = getCsrfClientCookieValue(req); if (existingCookieValue) { return { @@ -65,8 +64,8 @@ export const registerCsrfProtection = ({ const generatedCookieValue = crypto.randomUUID().replace(/-/g, ""); setCsrfClientCookie(req, res, generatedCookieValue); return { - clientId: getLegacyClientId(req), - strategy: "legacy-bootstrap", + clientId: `cookie:${generatedCookieValue}`, + strategy: "cookie", }; }; diff --git a/frontend/src/pages/Settings.tsx b/frontend/src/pages/Settings.tsx index 6372a27..6eababb 100644 --- a/frontend/src/pages/Settings.tsx +++ b/frontend/src/pages/Settings.tsx @@ -76,8 +76,8 @@ export const Settings: React.FC = () => { ); if (response.data.authEnabled) { - // Auth enabled -> prompt admin bootstrap via register. - window.location.href = '/register'; + // Auth enabled -> bootstrap registration only when required. + window.location.href = response.data.bootstrapRequired ? '/register' : '/login'; return; }