diff --git a/app/sockets/types.py b/app/sockets/types.py new file mode 100644 index 0000000..a1f2d49 --- /dev/null +++ b/app/sockets/types.py @@ -0,0 +1,137 @@ +from typing import Any, Literal, NotRequired, Optional, TypedDict, get_args, get_origin + +class CreateCodeGame(TypedDict): # Client + play_as: Literal["b", "w", "r"] + time_mode: str + +class JoinCodeGame(TypedDict): # Client + code: str + +class CodeGameCreated(TypedDict): # Server + code: str + +class JoinCodeGameSuccess(TypedDict): # Server + p1_name: str + play_as: NotRequired[Literal["b", "w", "r"]] + +class JoinCodeGameFailure(TypedDict): # Server + reason: str + +class P2Connected(TypedDict): # Server + p2_name: str + ready: bool + +class UserReady(TypedDict): # Client or Server + ready: bool + +class GameStart(TypedDict): # Server + play_as: Literal["b", "w", "r"] + time_left_ms: int + +class Move_Request(TypedDict): # Client + from_square: str + to_square: str + promotion: NotRequired[str] + +class Move_Accepted(TypedDict): # Server + from_square: str + to_square: str + promotion: NotRequired[str] + time_left_ms: int + +class Move_Refused(TypedDict): # Server + reason: str + +class RequestResign(TypedDict): # Client + accepted: NotRequired[bool] # missing = offer, true = accept, false = refuse + +class RequestDraw(TypedDict): # Client + accepted: NotRequired[bool] # missing = offer, true = accept, false = refuse + +class GameEnd(TypedDict): # Server + result: Literal["win", "loss", "draw"] + reason: str + + +#todo: implement later +class _RematchRequest(TypedDict): # Client + accepted: NotRequired[bool] # missing = offer, true = accept, false = refuse + + +ClientEventSchema = { + "create_code_game": CreateCodeGame, + "join_code_game": JoinCodeGame, + "user_ready": UserReady, + "move_request": Move_Request, + "request_resign": RequestResign, + "request_draw": RequestDraw, +} + + +def _matches_annotation(value: Any, annotation: Any) -> bool: + """Minimal runtime checker for the annotations used in WS payloads.""" + origin = get_origin(annotation) + args = get_args(annotation) + + if annotation is Any: + return True + if origin is None: + return isinstance(value, annotation) + if origin is Literal: + return value in args + if origin is Optional: + inner_type = args[0] + return value is None or _matches_annotation(value, inner_type) + if origin is list: + return isinstance(value, list) and all(_matches_annotation(v, args[0]) for v in value) + if origin is dict: + key_type, val_type = args + return isinstance(value, dict) and all( + _matches_annotation(k, key_type) and _matches_annotation(v, val_type) + for k, v in value.items() + ) + if origin is tuple: + if not isinstance(value, tuple): + return False + if len(args) == 2 and args[1] is Ellipsis: + return all(_matches_annotation(v, args[0]) for v in value) + return len(value) == len(args) and all(_matches_annotation(v, t) for v, t in zip(value, args)) + if origin is set: + return isinstance(value, set) and all(_matches_annotation(v, args[0]) for v in value) + + # covers Union and "|" types + if args: + return any(_matches_annotation(value, arg) for arg in args) + + return False + + +def validate_typed_dict_payload(payload: Any, schema: type[Any]) -> tuple[bool, str | None]: + if not isinstance(payload, dict): + return False, "payload must be an object" + + required_keys = getattr(schema, "__required_keys__", set()) + optional_keys = getattr(schema, "__optional_keys__", set()) + allowed_keys = required_keys | optional_keys + + missing = required_keys - payload.keys() + if missing: + return False, f"missing required keys: {', '.join(sorted(missing))}" + + unexpected = payload.keys() - allowed_keys + if unexpected: + return False, f"unexpected keys: {', '.join(sorted(unexpected))}" + + annotations = schema.__annotations__ + for key, expected_type in annotations.items(): + if key in payload and not _matches_annotation(payload[key], expected_type): + return False, f"invalid type for '{key}'" + + return True, None + + +def validate_client_event(event: str, payload: Any) -> tuple[bool, str | None]: + schema = ClientEventSchema.get(event) + if schema is None: + return False, f"unknown event '{event}'" + return validate_typed_dict_payload(payload, schema)