170 lines
5.5 KiB
Python
170 lines
5.5 KiB
Python
import base64
|
|
import binascii
|
|
import json
|
|
import logging
|
|
from collections.abc import Callable
|
|
from json import JSONDecodeError
|
|
|
|
from Crypto.PublicKey import ECC
|
|
from Crypto.PublicKey.ECC import EccKey
|
|
|
|
from classes.Message import Message
|
|
from classes.MessageTypes.Announcement import AnnouncementMessage
|
|
from classes.MessageTypes.Introduction import IntroductionMessage
|
|
from classes.MessageTypes.Ready import ReadyMessage
|
|
from classes.MessageTypes.Shuffle import ShuffleMessage
|
|
|
|
# Module logger; the message-decoding code below reports rejected input through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class MessageHandler:
    """Decode incoming JSON messages and fan them out to registered receivers.

    Child classes provide the transport by overriding ``send_message`` and
    feed raw incoming strings to ``decode_received_message``. Each string is
    parsed as JSON, validated against the schema for its ``type`` field
    ('introduction', 'ready', 'shuffle' or 'announcement'), converted into the
    matching message object and passed to every registered receiver.
    Malformed input is logged at ERROR level and dropped.
    """

    def __init__(self):
        # Callbacks invoked, in registration order, with each decoded message.
        self.receivers: list[Callable[[Message], None]] = []

    def send_message(self, message: Message, signing_key: EccKey):
        """Send ``message``, signed with ``signing_key``.

        The base implementation is a no-op; child classes must override it.
        """
        # Must be implemented by child classes
        pass

    def add_message_receiver(self, message_receiver: Callable[[Message], None]):
        """Register ``message_receiver`` to be called with every decoded message."""
        self.receivers.append(message_receiver)

    def remove_message_receiver(self, message_receiver: Callable[[Message], None]):
        """Unregister ``message_receiver``.

        Raises:
            ValueError: If ``message_receiver`` is not currently registered.
        """
        self.receivers.remove(message_receiver)

    def decode_received_message(self, message_string: str) -> None:
        """Decode ``message_string`` and dispatch the resulting message.

        Args:
            message_string: Raw JSON text received from a peer (untrusted).

        Input that is not valid JSON, is not a JSON object, has a missing or
        unknown ``type``, or fails the per-type validation is logged and
        dropped; receivers are only called with fully decoded messages.
        """
        try:
            message_object = json.loads(message_string)
        except (JSONDecodeError, UnicodeDecodeError):
            # Must be a tuple: ``A | B`` builds a types.UnionType, which an
            # ``except`` clause rejects with TypeError when it is matched.
            logger.error('Could not decode received string %s', message_string)
            return

        # Valid JSON may still be a list, number or string, which has no .get().
        if not isinstance(message_object, dict):
            logger.error('Received message is not a JSON object: %s', message_string)
            return

        message_type = message_object.get('type')
        if message_type is None:
            logger.error('Message type not found in message %s', message_string)
            return

        parsers = {
            'introduction': self._parse_introduction,
            'ready': self._parse_ready,
            'shuffle': self._parse_shuffle,
            'announcement': self._parse_announcement,
        }
        # Only strings can name a type; this also keeps unhashable JSON values
        # (lists/objects) out of the dict lookup.
        parser = parsers.get(message_type) if isinstance(message_type, str) else None
        if parser is None:
            logger.error('Message type %s does not exist', message_type)
            return

        message = parser(message_object, message_string)
        if message is None:
            # The parser has already logged why the message was rejected.
            return

        # Iterate over a snapshot so a receiver that (un)registers receivers
        # during dispatch does not cause others to be skipped.
        for message_receiver in tuple(self.receivers):
            message_receiver(message)

    @staticmethod
    def _b64decode(value: str) -> bytes | None:
        """Strictly decode base64 ``value``; return None if it is not valid base64."""
        try:
            return base64.b64decode(value, validate=True)
        except ValueError:
            # binascii.Error is a subclass of ValueError; a str containing
            # non-ASCII characters raises a plain ValueError instead.
            return None

    def _parse_introduction(self, message_object: dict, message_string: str) -> IntroductionMessage | None:
        """Build an IntroductionMessage from ``name``, ``key`` and ``seed_commit``.

        ``key`` must be a string accepted by ``ECC.import_key`` and
        ``seed_commit`` a base64 string. Returns None (after logging) on failure.
        """
        name = message_object.get('name')
        seed_commit = message_object.get('seed_commit')
        key = message_object.get('key')

        if None in (name, seed_commit, key):
            logger.error('Did not find expected fields for introduction message: %s', message_string)
            return None
        if not all(isinstance(field, str) for field in (name, seed_commit, key)):
            logger.error('Received data of the wrong type for introduction message: %s', message_string)
            return None

        try:
            public_key = ECC.import_key(key)
        except ValueError:
            logger.error('%s is not a valid key', key)
            return None

        decoded_seed_commit = self._b64decode(seed_commit)
        if decoded_seed_commit is None:
            logger.error('Seed commit %s is not a valid base64 string', seed_commit)
            return None

        return IntroductionMessage(name, public_key, decoded_seed_commit)

    def _parse_ready(self, message_object: dict, message_string: str) -> ReadyMessage | None:
        """Build a ReadyMessage from ``name``, ``participants`` and ``random_seed``.

        ``participants`` must be a list of ``[name, key]`` pairs whose keys are
        strings accepted by ``ECC.import_key``; they are passed on as
        ``(name, EccKey)`` tuples. ``random_seed`` must be a base64 string.
        Returns None (after logging) on failure.
        """
        name = message_object.get('name')
        participants = message_object.get('participants')
        random_seed = message_object.get('random_seed')

        if None in (name, participants, random_seed):
            logger.error('Did not find expected fields for ready message %s', message_string)
            return None
        if (not isinstance(name, str)
                or not isinstance(participants, list)
                or not isinstance(random_seed, str)):
            logger.error('Received data of the wrong type for ready message: %s', message_string)
            return None
        # JSON has no tuple type, so each participant arrives as a two-element list.
        if not all(isinstance(participant, list) and len(participant) == 2
                   for participant in participants):
            logger.error('Not all participants in participant list are [name, key] pairs: %s',
                         message_string)
            return None
        # JSON cannot carry bytes, so keys arrive as strings (as in introduction messages).
        if not all(isinstance(participant_name, str) and isinstance(participant_key, str)
                   for participant_name, participant_key in participants):
            logger.error('Not all participant pairs contain a name and a key: %s', message_string)
            return None

        decoded_participants = []
        for participant_name, participant_key in participants:
            try:
                decoded_participants.append((participant_name, ECC.import_key(participant_key)))
            except ValueError:
                logger.error('Could not decode public key %s of participant %s',
                             participant_key, participant_name)
                return None

        decoded_seed = self._b64decode(random_seed)
        if decoded_seed is None:
            logger.error('Random seed %s from %s is not a valid base64 string', random_seed, name)
            return None

        return ReadyMessage(name, decoded_participants, decoded_seed)

    def _parse_shuffle(self, message_object: dict, message_string: str) -> ShuffleMessage | None:
        """Build a ShuffleMessage from ``name``, base64 ``cards`` and ``stage``.

        Returns None (after logging) on failure.
        """
        name = message_object.get('name')
        cards = message_object.get('cards')
        stage = message_object.get('stage')

        if None in (name, cards, stage):
            logger.error('Did not receive all expected fields for shuffle message: %s', message_string)
            return None
        if not isinstance(name, str) or not isinstance(cards, list) or not isinstance(stage, str):
            logger.error('Received fields were not correct type for shuffle message: %s', message_string)
            return None
        if not all(isinstance(card, str) for card in cards):
            logger.error('All received cards were not of type string: %s', message_string)
            return None

        decoded_cards = []
        for card in cards:
            decoded_card = self._b64decode(card)
            if decoded_card is None:
                logger.error('%s is not a valid base64 string', card)
                return None
            decoded_cards.append(decoded_card)

        return ShuffleMessage(name, decoded_cards, stage)

    def _parse_announcement(self, message_object: dict, message_string: str) -> AnnouncementMessage | None:
        """Build an AnnouncementMessage from ``name`` and base64 ``announcement``.

        Returns None (after logging) on failure.
        """
        name = message_object.get('name')
        announcement = message_object.get('announcement')

        if None in (name, announcement):
            logger.error('Did not receive all expected fields for announcement message: %s',
                         message_string)
            return None
        if not isinstance(name, str) or not isinstance(announcement, str):
            logger.error('Received fields for announcement message are not of correct type: %s',
                         message_string)
            return None

        decoded_announcement = self._b64decode(announcement)
        if decoded_announcement is None:
            logger.error('%s is not a valid base64 string', announcement)
            return None

        return AnnouncementMessage(name, decoded_announcement)
|