refactor index.ts
This commit is contained in:
@@ -0,0 +1,201 @@
|
||||
import express from "express";
|
||||
import crypto from "crypto";
|
||||
import {
|
||||
createCsrfToken,
|
||||
getCsrfTokenHeader,
|
||||
getOriginFromReferer,
|
||||
validateCsrfToken,
|
||||
} from "../security";
|
||||
import {
|
||||
CSRF_CLIENT_COOKIE_NAME,
|
||||
getCsrfClientCookieValue,
|
||||
getCsrfValidationClientIds,
|
||||
getLegacyClientId,
|
||||
} from "../security/csrfClient";
|
||||
|
||||
const CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30; // 30 days
|
||||
const CSRF_RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute
|
||||
|
||||
type RegisterCsrfProtectionDeps = {
|
||||
app: express.Express;
|
||||
isAllowedOrigin: (origin?: string) => boolean;
|
||||
maxRequestsPerWindow: number;
|
||||
enableDebugLogging?: boolean;
|
||||
};
|
||||
|
||||
export const registerCsrfProtection = ({
|
||||
app,
|
||||
isAllowedOrigin,
|
||||
maxRequestsPerWindow,
|
||||
enableDebugLogging,
|
||||
}: RegisterCsrfProtectionDeps) => {
|
||||
const requestUsesHttps = (req: express.Request): boolean => {
|
||||
if (req.secure) return true;
|
||||
const forwardedProto = req.headers["x-forwarded-proto"];
|
||||
const raw = Array.isArray(forwardedProto) ? forwardedProto[0] : forwardedProto;
|
||||
const firstHop = String(raw || "")
|
||||
.split(",")[0]
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
return firstHop === "https";
|
||||
};
|
||||
|
||||
const setCsrfClientCookie = (req: express.Request, res: express.Response, value: string): void => {
|
||||
const secure = requestUsesHttps(req) ? "; Secure" : "";
|
||||
res.append(
|
||||
"Set-Cookie",
|
||||
`${CSRF_CLIENT_COOKIE_NAME}=${encodeURIComponent(
|
||||
value
|
||||
)}; Path=/; HttpOnly; SameSite=Lax; Max-Age=${CSRF_CLIENT_COOKIE_MAX_AGE_SECONDS}${secure}`
|
||||
);
|
||||
};
|
||||
|
||||
const getClientIdForTokenIssue = (
|
||||
req: express.Request,
|
||||
res: express.Response
|
||||
): { clientId: string; strategy: "cookie" | "legacy-bootstrap" } => {
|
||||
const existingCookieValue = getCsrfClientCookieValue(req);
|
||||
if (existingCookieValue) {
|
||||
return {
|
||||
clientId: `cookie:${existingCookieValue}`,
|
||||
strategy: "cookie",
|
||||
};
|
||||
}
|
||||
|
||||
const generatedCookieValue = crypto.randomUUID().replace(/-/g, "");
|
||||
setCsrfClientCookie(req, res, generatedCookieValue);
|
||||
return {
|
||||
clientId: getLegacyClientId(req),
|
||||
strategy: "legacy-bootstrap",
|
||||
};
|
||||
};
|
||||
|
||||
const getClientIdForTokenIssueDebug = (
|
||||
req: express.Request,
|
||||
res: express.Response
|
||||
): string => {
|
||||
const { clientId, strategy } = getClientIdForTokenIssue(req, res);
|
||||
|
||||
if (enableDebugLogging) {
|
||||
const validationCandidates = getCsrfValidationClientIds(req);
|
||||
const ip = req.ip || req.connection.remoteAddress || "unknown";
|
||||
console.log("[CSRF DEBUG] getClientId", {
|
||||
method: req.method,
|
||||
path: req.path,
|
||||
ip,
|
||||
remoteAddress: req.connection.remoteAddress,
|
||||
"x-forwarded-for": req.headers["x-forwarded-for"],
|
||||
"x-real-ip": req.headers["x-real-ip"],
|
||||
hasCsrfCookie: Boolean(getCsrfClientCookieValue(req)),
|
||||
clientIdPreview: clientId.slice(0, 60) + "...",
|
||||
trustProxySetting: req.app.get("trust proxy"),
|
||||
strategy,
|
||||
validationCandidatesPreview: validationCandidates.map((candidate) =>
|
||||
`${candidate.slice(0, 60)}...`
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
return clientId;
|
||||
};
|
||||
|
||||
const csrfRateLimit = new Map<string, { count: number; resetTime: number }>();
|
||||
let csrfCleanupCounter = 0;
|
||||
|
||||
app.get("/csrf-token", (req, res) => {
|
||||
const ip = req.ip || req.connection.remoteAddress || "unknown";
|
||||
const now = Date.now();
|
||||
const clientLimit = csrfRateLimit.get(ip);
|
||||
|
||||
if (clientLimit && now < clientLimit.resetTime) {
|
||||
if (clientLimit.count >= maxRequestsPerWindow) {
|
||||
return res.status(429).json({
|
||||
error: "Rate limit exceeded",
|
||||
message: "Too many CSRF token requests",
|
||||
});
|
||||
}
|
||||
clientLimit.count++;
|
||||
} else {
|
||||
csrfRateLimit.set(ip, { count: 1, resetTime: now + CSRF_RATE_LIMIT_WINDOW });
|
||||
}
|
||||
|
||||
csrfCleanupCounter += 1;
|
||||
if (csrfCleanupCounter % 100 === 0) {
|
||||
for (const [key, data] of csrfRateLimit.entries()) {
|
||||
if (now > data.resetTime) csrfRateLimit.delete(key);
|
||||
}
|
||||
}
|
||||
|
||||
const clientId = getClientIdForTokenIssueDebug(req, res);
|
||||
const token = createCsrfToken(clientId);
|
||||
|
||||
res.json({
|
||||
token,
|
||||
header: getCsrfTokenHeader(),
|
||||
});
|
||||
});
|
||||
|
||||
const csrfProtectionMiddleware = (
|
||||
req: express.Request,
|
||||
res: express.Response,
|
||||
next: express.NextFunction
|
||||
) => {
|
||||
const safeMethods = ["GET", "HEAD", "OPTIONS"];
|
||||
if (safeMethods.includes(req.method)) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const origin = req.headers["origin"];
|
||||
const referer = req.headers["referer"];
|
||||
const originValue = Array.isArray(origin) ? origin[0] : origin;
|
||||
const refererValue = Array.isArray(referer) ? referer[0] : referer;
|
||||
|
||||
if (originValue) {
|
||||
if (!isAllowedOrigin(originValue)) {
|
||||
return res.status(403).json({
|
||||
error: "CSRF origin mismatch",
|
||||
message: "Origin not allowed",
|
||||
});
|
||||
}
|
||||
} else if (refererValue) {
|
||||
const refererOrigin = getOriginFromReferer(refererValue);
|
||||
if (!refererOrigin || !isAllowedOrigin(refererOrigin)) {
|
||||
return res.status(403).json({
|
||||
error: "CSRF referer mismatch",
|
||||
message: "Referer not allowed",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const clientIdCandidates = getCsrfValidationClientIds(req);
|
||||
const headerName = getCsrfTokenHeader();
|
||||
const tokenHeader = req.headers[headerName];
|
||||
const token = Array.isArray(tokenHeader) ? tokenHeader[0] : tokenHeader;
|
||||
|
||||
if (!token) {
|
||||
return res.status(403).json({
|
||||
error: "CSRF token missing",
|
||||
message: `Missing ${headerName} header`,
|
||||
});
|
||||
}
|
||||
|
||||
const isValidToken = clientIdCandidates.some((clientId) =>
|
||||
validateCsrfToken(clientId, token)
|
||||
);
|
||||
if (!isValidToken) {
|
||||
return res.status(403).json({
|
||||
error: "CSRF token invalid",
|
||||
message: "Invalid or expired CSRF token. Please refresh and try again.",
|
||||
});
|
||||
}
|
||||
|
||||
next();
|
||||
};
|
||||
|
||||
app.use((req, res, next) => {
|
||||
if (req.path.startsWith("/auth/")) {
|
||||
return next();
|
||||
}
|
||||
csrfProtectionMiddleware(req, res, next);
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,63 @@
|
||||
type DrawingsCacheEntry = { body: Buffer; expiresAt: number };
|
||||
|
||||
export type DrawingsCacheKeyParts = {
|
||||
userId: string;
|
||||
searchTerm: string;
|
||||
collectionFilter: string;
|
||||
includeData: boolean;
|
||||
sortField: "name" | "createdAt" | "updatedAt";
|
||||
sortDirection: "asc" | "desc";
|
||||
};
|
||||
|
||||
export const createDrawingsCacheStore = (ttlMs: number) => {
|
||||
const drawingsCache = new Map<string, DrawingsCacheEntry>();
|
||||
|
||||
const buildDrawingsCacheKey = (keyParts: DrawingsCacheKeyParts) =>
|
||||
JSON.stringify([
|
||||
keyParts.userId,
|
||||
keyParts.searchTerm,
|
||||
keyParts.collectionFilter,
|
||||
keyParts.includeData ? "full" : "summary",
|
||||
keyParts.sortField,
|
||||
keyParts.sortDirection,
|
||||
]);
|
||||
|
||||
const getCachedDrawingsBody = (key: string): Buffer | null => {
|
||||
const entry = drawingsCache.get(key);
|
||||
if (!entry) return null;
|
||||
if (Date.now() > entry.expiresAt) {
|
||||
drawingsCache.delete(key);
|
||||
return null;
|
||||
}
|
||||
return entry.body;
|
||||
};
|
||||
|
||||
const cacheDrawingsResponse = (key: string, payload: unknown): Buffer => {
|
||||
const body = Buffer.from(JSON.stringify(payload));
|
||||
drawingsCache.set(key, {
|
||||
body,
|
||||
expiresAt: Date.now() + ttlMs,
|
||||
});
|
||||
return body;
|
||||
};
|
||||
|
||||
const invalidateDrawingsCache = () => {
|
||||
drawingsCache.clear();
|
||||
};
|
||||
|
||||
setInterval(() => {
|
||||
const now = Date.now();
|
||||
for (const [key, entry] of drawingsCache.entries()) {
|
||||
if (now > entry.expiresAt) {
|
||||
drawingsCache.delete(key);
|
||||
}
|
||||
}
|
||||
}, 60_000).unref();
|
||||
|
||||
return {
|
||||
buildDrawingsCacheKey,
|
||||
getCachedDrawingsBody,
|
||||
cacheDrawingsResponse,
|
||||
invalidateDrawingsCache,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,221 @@
|
||||
import jwt from "jsonwebtoken";
|
||||
import { Server } from "socket.io";
|
||||
import { PrismaClient } from "../generated/client";
|
||||
import { AuthModeService } from "../auth/authMode";
|
||||
|
||||
interface User {
|
||||
id: string;
|
||||
name: string;
|
||||
initials: string;
|
||||
color: string;
|
||||
socketId: string;
|
||||
isActive: boolean;
|
||||
}
|
||||
|
||||
type RegisterSocketHandlersDeps = {
|
||||
io: Server;
|
||||
prisma: PrismaClient;
|
||||
authModeService: AuthModeService;
|
||||
jwtSecret: string;
|
||||
};
|
||||
|
||||
export const registerSocketHandlers = ({
|
||||
io,
|
||||
prisma,
|
||||
authModeService,
|
||||
jwtSecret,
|
||||
}: RegisterSocketHandlersDeps) => {
|
||||
const roomUsers = new Map<string, User[]>();
|
||||
const socketUserMap = new Map<string, string>();
|
||||
|
||||
const toPresenceName = (value: unknown): string => {
|
||||
if (typeof value !== "string") return "User";
|
||||
const trimmed = value.trim().slice(0, 120);
|
||||
return trimmed.length > 0 ? trimmed : "User";
|
||||
};
|
||||
|
||||
const toPresenceInitials = (name: string): string => {
|
||||
const words = name
|
||||
.split(/\s+/)
|
||||
.map((part) => part.trim())
|
||||
.filter((part) => part.length > 0);
|
||||
if (words.length === 0) return "U";
|
||||
const first = words[0]?.[0] ?? "";
|
||||
const second = words.length > 1 ? words[1]?.[0] ?? "" : "";
|
||||
const initials = `${first}${second}`.toUpperCase().slice(0, 2);
|
||||
return initials.length > 0 ? initials : "U";
|
||||
};
|
||||
|
||||
const toPresenceColor = (value: unknown): string => {
|
||||
if (typeof value !== "string") return "#4f46e5";
|
||||
const trimmed = value.trim();
|
||||
if (/^#[0-9a-fA-F]{3,8}$/.test(trimmed)) {
|
||||
return trimmed;
|
||||
}
|
||||
return "#4f46e5";
|
||||
};
|
||||
|
||||
const getSocketAuthUserId = async (token?: string): Promise<string | null> => {
|
||||
const authEnabled = await authModeService.getAuthEnabled();
|
||||
if (!authEnabled) {
|
||||
return "bootstrap-admin";
|
||||
}
|
||||
|
||||
if (!token) return null;
|
||||
|
||||
try {
|
||||
const decoded = jwt.verify(token, jwtSecret) as Record<string, unknown>;
|
||||
if (
|
||||
typeof decoded.userId !== "string" ||
|
||||
typeof decoded.email !== "string" ||
|
||||
decoded.type !== "access"
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const user = await prisma.user.findUnique({
|
||||
where: { id: decoded.userId },
|
||||
select: { id: true, isActive: true },
|
||||
});
|
||||
|
||||
if (!user || !user.isActive) return null;
|
||||
return user.id;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
io.use(async (socket, next) => {
|
||||
try {
|
||||
const token = socket.handshake.auth?.token as string | undefined;
|
||||
const userId = await getSocketAuthUserId(token);
|
||||
|
||||
if (!userId) {
|
||||
return next(new Error("Authentication required"));
|
||||
}
|
||||
|
||||
socketUserMap.set(socket.id, userId);
|
||||
next();
|
||||
} catch {
|
||||
next(new Error("Authentication failed"));
|
||||
}
|
||||
});
|
||||
|
||||
io.on("connection", (socket) => {
|
||||
const authenticatedUserId = socketUserMap.get(socket.id);
|
||||
const authorizedDrawingIds = new Set<string>();
|
||||
|
||||
socket.on(
|
||||
"join-room",
|
||||
async ({
|
||||
drawingId,
|
||||
user,
|
||||
}: {
|
||||
drawingId: string;
|
||||
user: Omit<User, "socketId" | "isActive">;
|
||||
}) => {
|
||||
try {
|
||||
if (authenticatedUserId) {
|
||||
const drawing = await prisma.drawing.findFirst({
|
||||
where: { id: drawingId, userId: authenticatedUserId },
|
||||
select: { id: true },
|
||||
});
|
||||
|
||||
if (!drawing) {
|
||||
socket.emit("error", { message: "You do not have access to this drawing" });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const roomId = `drawing_${drawingId}`;
|
||||
socket.join(roomId);
|
||||
authorizedDrawingIds.add(drawingId);
|
||||
|
||||
let trustedUserId =
|
||||
typeof user?.id === "string" && user.id.trim().length > 0
|
||||
? user.id.trim().slice(0, 200)
|
||||
: socket.id;
|
||||
let trustedName = toPresenceName(user?.name);
|
||||
|
||||
if (authenticatedUserId && authenticatedUserId !== "bootstrap-admin") {
|
||||
const account = await prisma.user.findUnique({
|
||||
where: { id: authenticatedUserId },
|
||||
select: { id: true, name: true },
|
||||
});
|
||||
if (account) {
|
||||
trustedUserId = account.id;
|
||||
trustedName = toPresenceName(account.name);
|
||||
}
|
||||
}
|
||||
|
||||
const newUser: User = {
|
||||
id: trustedUserId,
|
||||
name: trustedName,
|
||||
initials: toPresenceInitials(trustedName),
|
||||
color: toPresenceColor(user?.color),
|
||||
socketId: socket.id,
|
||||
isActive: true,
|
||||
};
|
||||
|
||||
const currentUsers = roomUsers.get(roomId) || [];
|
||||
const filteredUsers = currentUsers.filter((u) => u.id !== newUser.id);
|
||||
filteredUsers.push(newUser);
|
||||
roomUsers.set(roomId, filteredUsers);
|
||||
|
||||
io.to(roomId).emit("presence-update", filteredUsers);
|
||||
} catch (err) {
|
||||
console.error("Error in join-room handler:", err);
|
||||
socket.emit("error", { message: "Failed to join room" });
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
socket.on("cursor-move", (data) => {
|
||||
const drawingId = typeof data?.drawingId === "string" ? data.drawingId : null;
|
||||
if (!drawingId || !authorizedDrawingIds.has(drawingId)) {
|
||||
return;
|
||||
}
|
||||
const roomId = `drawing_${drawingId}`;
|
||||
socket.volatile.to(roomId).emit("cursor-move", data);
|
||||
});
|
||||
|
||||
socket.on("element-update", (data) => {
|
||||
const drawingId = typeof data?.drawingId === "string" ? data.drawingId : null;
|
||||
if (!drawingId || !authorizedDrawingIds.has(drawingId)) {
|
||||
return;
|
||||
}
|
||||
const roomId = `drawing_${drawingId}`;
|
||||
socket.to(roomId).emit("element-update", data);
|
||||
});
|
||||
|
||||
socket.on(
|
||||
"user-activity",
|
||||
({ drawingId, isActive }: { drawingId: string; isActive: boolean }) => {
|
||||
if (!authorizedDrawingIds.has(drawingId)) {
|
||||
return;
|
||||
}
|
||||
const roomId = `drawing_${drawingId}`;
|
||||
const users = roomUsers.get(roomId);
|
||||
if (users) {
|
||||
const user = users.find((u) => u.socketId === socket.id);
|
||||
if (user) {
|
||||
user.isActive = isActive;
|
||||
io.to(roomId).emit("presence-update", users);
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
socket.on("disconnect", () => {
|
||||
socketUserMap.delete(socket.id);
|
||||
roomUsers.forEach((users, roomId) => {
|
||||
const index = users.findIndex((u) => u.socketId === socket.id);
|
||||
if (index !== -1) {
|
||||
users.splice(index, 1);
|
||||
roomUsers.set(roomId, users);
|
||||
io.to(roomId).emit("presence-update", users);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
Reference in New Issue
Block a user