# DecentraSanta/classes/SantasBrain.py
# 2025-04-17 19:25:23 +02:00
#
# 214 lines
# 8.6 KiB
# Python

import hashlib
import logging
import secrets
import threading
from typing import Optional, Callable
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from Crypto.Util.number import getPrime
from classes import Message
from classes.Crypto.CSPRNG import CSPRNG
from classes.Crypto.CommutativeCipher import CommutativeCipher
from classes.MessageHandler import MessageHandler
from classes.MessageTypes.Introduction import IntroductionMessage
from classes.MessageTypes.Ready import ReadyMessage
from classes.UserInterface import UserInterface
# Module-level logger named after this module; used by Brain for all diagnostics.
logger = logging.getLogger(__name__)
class Brain:
def __init__(self, message_handler: MessageHandler, user_interface: UserInterface):
# We're going to need to do some
self.thread_lock = threading.Lock()
self.message_handler = message_handler
self.user_interface = user_interface
self.other_possible_participants: dict[str, tuple[EccKey, bytes]] = {}
self.name = None
self.extra_info_for_santa = None
self.key = ECC.generate(curve='p256')
self.random_seed = secrets.token_bytes(32)
self.message_handler.add_message_receiver(self.receive_message)
self.user_interface.add_user_info_listener(self.set_user_data)
self.user_interface.add_start_listener(self.receive_user_start_command)
# The following fields are used during the exchange itself
self.other_participants: dict[str, tuple[EccKey, bytes]] = {}
self.other_ready_participants: dict[str, tuple[list[str], bytes]] = {}
self.first_shuffle_key: CommutativeCipher = None
self.second_shuffle_key: CommutativeCipher = None
self.third_shuffle_key: CommutativeCipher = None
self.card_values: list[bytes] = []
def set_user_data(self, name: str, extra_info_for_santa: str):
with self.thread_lock:
if name in self.other_possible_participants:
logger.error(f'Participant already exists with username {name}. Please choose another.')
return
self.name = name
self.extra_info_for_santa = extra_info_for_santa
# Send our introductions out on the network
self.send_introduction_message()
def receive_user_start_command(self, other_participant_names: list[str]):
with self.thread_lock:
other_participants = {}
for other_participant_name in other_participant_names:
if other_participant_name not in self.other_possible_participants:
logger.error(f'Tried to start an exchange containing unknown participant {other_participant_name}')
return
other_participants[other_participant_name] = self.other_possible_participants[other_participant_name]
self.other_participants = other_participants
self.send_ready_message()
after_lock_command = None
with self.thread_lock:
if self.check_if_ready():
after_lock_command = self.start_exchange_process()
if after_lock_command is not None:
after_lock_command()
def receive_message(self, message: Message):
after_lock_call: Optional[Callable[[], None]] = None
with self.thread_lock:
# First, check if message signature is correct
sender_name = message.get_name()
key = None
if sender_name in self.other_possible_participants:
key, _ = self.other_possible_participants[sender_name]
elif isinstance(message, IntroductionMessage):
key = message.get_key()
if key is None:
logger.warning(f'Received message from participant {sender_name}, but there is no validation key.')
return
if not message.check_signature(key):
logger.warning(f'Received message from participant {sender_name} with invalid signature. Ignoring.')
return
if isinstance(message, IntroductionMessage):
after_lock_call = self.handle_introduction(message)
elif isinstance(message, ReadyMessage):
after_lock_call = self.receive_ready_message(message)
if after_lock_call is not None:
after_lock_call()
def handle_introduction(self, introduction: IntroductionMessage) -> Optional[Callable[[], None]]:
name = introduction.get_name()
key = introduction.get_key()
commit = introduction.get_seed_commit()
# Check if it's a participant we already know about
if name in self.other_possible_participants:
previous_key, previous_commit = self.other_possible_participants[name]
# Either it's a participant we already know about, or it's someone trying to use the same name
# as a previous participant. Either way, we don't really do anything
if previous_key != key or previous_commit != commit:
logger.warning(f'A second participant tried to register with the already used name {name}. Ignoring.')
return None
self.other_possible_participants[name] = (key, commit)
# Since this participant is new, they might not know about us yet. Send an introduction
# Also, tell the user interface about this new user
def post_introduction():
self.send_introduction_message()
self.user_interface.add_user(name)
return post_introduction
def send_introduction_message(self):
if self.name is None or self.key is None or self.random_seed is None:
return
# We need to commit to our random seed by hashing it, but we don't actually want to send the seed itself yet,
# to prevent others from crafting their seeds based on the value of ours
hasher = hashlib.sha512()
hasher.update(self.random_seed)
introduction_message = IntroductionMessage(self.name, self.key.public_key(), hasher.digest())
introduction_message.generate_and_sign(self.key)
self.message_handler.send_message(introduction_message)
def send_ready_message(self):
other_participants = [(name, self.other_possible_participants[name][0]) for name in self.other_possible_participants]
ready_message = ReadyMessage(self.name, other_participants, self.random_seed)
ready_message.generate_and_sign(self.key)
self.message_handler.send_message(ready_message)
def receive_ready_message(self, ready_message: ReadyMessage) -> Optional[Callable[[], None]]:
sender_name = ready_message.get_name()
sender_expected_participants = ready_message.get_participants()
sender_random_seed = ready_message.get_random_seed()
if sender_name not in self.other_possible_participants:
logger.warning(f'Received ready message from unknown participant {sender_name}')
return None
_, sender_commit = self.other_possible_participants[sender_name]
hasher = hashlib.sha512()
hasher.update(sender_random_seed)
seed_hash = hasher.digest()
if seed_hash != sender_commit:
logger.error(f'Participant {sender_name} sent random seed that did not match their initial commit!')
return None
for expected_participant_name, expected_participant_key in sender_expected_participants:
if expected_participant_name not in self.other_possible_participants:
logger.warning(f'Participant {sender_name} expects exchange with unknown participant {expected_participant_name}')
return None
our_key, _ = self.other_possible_participants[expected_participant_name]
if our_key != expected_participant_key:
logger.error(f'Participant {sender_name} has different public key for participant {expected_participant_name} than we have.')
return None
self.other_ready_participants[sender_name] = ([name for name, _ in sender_expected_participants], sender_random_seed)
if self.check_if_ready():
return self.start_exchange_process()
def check_if_ready(self):
list_of_names = set([name for name in self.other_participants] + [self.name])
for name in self.other_participants:
if name not in self.other_ready_participants:
return False
other_list_of_names = set(self.other_ready_participants[name][0] + [name])
if list_of_names != other_list_of_names:
logger.critical(f'Participant {name} does not have the same list of participants as us!')
return False
return True
def start_exchange_process(self) -> Optional[Callable[[], None]]:
# XOR all the seeds together
all_seeds = [self.random_seed]
for name in self.other_participants:
all_seeds.append(self.other_ready_participants[name][1])
longest_seed_length = max(len(seed) for seed in all_seeds)
total_seed = b'0' * longest_seed_length
for seed in all_seeds:
seed = b'0' * (longest_seed_length - len(seed)) + seed
total_seed = bytes(a ^ b for a, b in zip(total_seed, seed))
# Use the total random seed to initialize a cryptographically secure pseudo-random number generator
csprng = CSPRNG(total_seed)
# Since everyone is using the same seed, everyone should get same values for p and q
p = getPrime(1200, randfunc=csprng.get_random_bytes)
q = getPrime(800, randfunc=csprng.get_random_bytes)
self.first_shuffle_key = CommutativeCipher(p, q)
self.second_shuffle_key = CommutativeCipher(p, q)
self.third_shuffle_key = CommutativeCipher(p, q)
# In the same way, everyone should get the same values for each card
for i in range(len(self.other_participants) + 1):
self.card_values.append(csprng.get_random_bytes(16))