add WS types and basic runtime validation
This commit is contained in:
@@ -0,0 +1,137 @@
|
|||||||
|
from typing import Any, Literal, NotRequired, Optional, TypedDict, get_args, get_origin
|
||||||
|
|
||||||
|
class CreateCodeGame(TypedDict): # Client
|
||||||
|
play_as: Literal["b", "w", "r"]
|
||||||
|
time_mode: str
|
||||||
|
|
||||||
|
class JoinCodeGame(TypedDict): # Client
|
||||||
|
code: str
|
||||||
|
|
||||||
|
class CodeGameCreated(TypedDict): # Server
|
||||||
|
code: str
|
||||||
|
|
||||||
|
class JoinCodeGameSuccess(TypedDict): # Server
|
||||||
|
p1_name: str
|
||||||
|
play_as: NotRequired[Literal["b", "w", "r"]]
|
||||||
|
|
||||||
|
class JoinCodeGameFailure(TypedDict): # Server
|
||||||
|
reason: str
|
||||||
|
|
||||||
|
class P2Connected(TypedDict): # Server
|
||||||
|
p2_name: str
|
||||||
|
ready: bool
|
||||||
|
|
||||||
|
class UserReady(TypedDict): # Client or Server
|
||||||
|
ready: bool
|
||||||
|
|
||||||
|
class GameStart(TypedDict): # Server
|
||||||
|
play_as: Literal["b", "w", "r"]
|
||||||
|
time_left_ms: int
|
||||||
|
|
||||||
|
class Move_Request(TypedDict): # Client
|
||||||
|
from_square: str
|
||||||
|
to_square: str
|
||||||
|
promotion: NotRequired[str]
|
||||||
|
|
||||||
|
class Move_Accepted(TypedDict): # Server
|
||||||
|
from_square: str
|
||||||
|
to_square: str
|
||||||
|
promotion: NotRequired[str]
|
||||||
|
time_left_ms: int
|
||||||
|
|
||||||
|
class Move_Refused(TypedDict): # Server
|
||||||
|
reason: str
|
||||||
|
|
||||||
|
class RequestResign(TypedDict): # Client
|
||||||
|
accepted: NotRequired[bool] # missing = offer, true = accept, false = refuse
|
||||||
|
|
||||||
|
class RequestDraw(TypedDict): # Client
|
||||||
|
accepted: NotRequired[bool] # missing = offer, true = accept, false = refuse
|
||||||
|
|
||||||
|
class GameEnd(TypedDict): # Server
|
||||||
|
result: Literal["win", "loss", "draw"]
|
||||||
|
reason: str
|
||||||
|
|
||||||
|
|
||||||
|
#todo: implement later
|
||||||
|
class _RematchRequest(TypedDict): # Client
|
||||||
|
accepted: NotRequired[bool] # missing = offer, true = accept, false = refuse
|
||||||
|
|
||||||
|
|
||||||
|
ClientEventSchema = {
|
||||||
|
"create_code_game": CreateCodeGame,
|
||||||
|
"join_code_game": JoinCodeGame,
|
||||||
|
"user_ready": UserReady,
|
||||||
|
"move_request": Move_Request,
|
||||||
|
"request_resign": RequestResign,
|
||||||
|
"request_draw": RequestDraw,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _matches_annotation(value: Any, annotation: Any) -> bool:
|
||||||
|
"""Minimal runtime checker for the annotations used in WS payloads."""
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
args = get_args(annotation)
|
||||||
|
|
||||||
|
if annotation is Any:
|
||||||
|
return True
|
||||||
|
if origin is None:
|
||||||
|
return isinstance(value, annotation)
|
||||||
|
if origin is Literal:
|
||||||
|
return value in args
|
||||||
|
if origin is Optional:
|
||||||
|
inner_type = args[0]
|
||||||
|
return value is None or _matches_annotation(value, inner_type)
|
||||||
|
if origin is list:
|
||||||
|
return isinstance(value, list) and all(_matches_annotation(v, args[0]) for v in value)
|
||||||
|
if origin is dict:
|
||||||
|
key_type, val_type = args
|
||||||
|
return isinstance(value, dict) and all(
|
||||||
|
_matches_annotation(k, key_type) and _matches_annotation(v, val_type)
|
||||||
|
for k, v in value.items()
|
||||||
|
)
|
||||||
|
if origin is tuple:
|
||||||
|
if not isinstance(value, tuple):
|
||||||
|
return False
|
||||||
|
if len(args) == 2 and args[1] is Ellipsis:
|
||||||
|
return all(_matches_annotation(v, args[0]) for v in value)
|
||||||
|
return len(value) == len(args) and all(_matches_annotation(v, t) for v, t in zip(value, args))
|
||||||
|
if origin is set:
|
||||||
|
return isinstance(value, set) and all(_matches_annotation(v, args[0]) for v in value)
|
||||||
|
|
||||||
|
# covers Union and "|" types
|
||||||
|
if args:
|
||||||
|
return any(_matches_annotation(value, arg) for arg in args)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def validate_typed_dict_payload(payload: Any, schema: type[Any]) -> tuple[bool, str | None]:
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return False, "payload must be an object"
|
||||||
|
|
||||||
|
required_keys = getattr(schema, "__required_keys__", set())
|
||||||
|
optional_keys = getattr(schema, "__optional_keys__", set())
|
||||||
|
allowed_keys = required_keys | optional_keys
|
||||||
|
|
||||||
|
missing = required_keys - payload.keys()
|
||||||
|
if missing:
|
||||||
|
return False, f"missing required keys: {', '.join(sorted(missing))}"
|
||||||
|
|
||||||
|
unexpected = payload.keys() - allowed_keys
|
||||||
|
if unexpected:
|
||||||
|
return False, f"unexpected keys: {', '.join(sorted(unexpected))}"
|
||||||
|
|
||||||
|
annotations = schema.__annotations__
|
||||||
|
for key, expected_type in annotations.items():
|
||||||
|
if key in payload and not _matches_annotation(payload[key], expected_type):
|
||||||
|
return False, f"invalid type for '{key}'"
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_client_event(event: str, payload: Any) -> tuple[bool, str | None]:
|
||||||
|
schema = ClientEventSchema.get(event)
|
||||||
|
if schema is None:
|
||||||
|
return False, f"unknown event '{event}'"
|
||||||
|
return validate_typed_dict_payload(payload, schema)
|
||||||
Reference in New Issue
Block a user