Slight code cleanup

This commit is contained in:
Martin Asprusten 2025-04-18 03:11:34 +02:00
parent d0b99c47a4
commit 8652ab2eec

View File

@ -2,12 +2,10 @@ import base64
import hashlib import hashlib
import json import json
import logging import logging
import math
import random
import secrets import secrets
import threading import threading
from collections import defaultdict from collections import defaultdict
from typing import Optional, Callable from typing import Optional
import Crypto.Util.number import Crypto.Util.number
from Crypto.Cipher import AES from Crypto.Cipher import AES
@ -15,7 +13,6 @@ from Crypto.Hash import SHAKE128
from Crypto.Protocol.DH import key_agreement from Crypto.Protocol.DH import key_agreement
from Crypto.PublicKey import ECC from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey from Crypto.PublicKey.ECC import EccKey
from Crypto.Util.number import getPrime
from classes import Message from classes import Message
from classes.Crypto.CSPRNG import CSPRNG from classes.Crypto.CSPRNG import CSPRNG
@ -36,6 +33,7 @@ ENCRYPT_ANONYMOUS_STAGE = 'encrypt_announcement'
SHUFFLE_ANONYMOUS_STAGE = 'shuffle_announcement' SHUFFLE_ANONYMOUS_STAGE = 'shuffle_announcement'
DECRYPT_ANONYMOUS_STAGE = 'decrypt_announcement' DECRYPT_ANONYMOUS_STAGE = 'decrypt_announcement'
class Brain: class Brain:
def __init__(self, message_handler: MessageHandler, user_interface: UserInterface): def __init__(self, message_handler: MessageHandler, user_interface: UserInterface):
self.thread_lock = threading.Lock() self.thread_lock = threading.Lock()
@ -85,11 +83,11 @@ class Brain:
def receive_user_info(self, name: str, info_for_santa: str): def receive_user_info(self, name: str, info_for_santa: str):
# We will only receive this once # We will only receive this once
if ( if (
self.own_name is not None self.own_name is not None
or self.signing_key is not None or self.signing_key is not None
or self.introduction_message is not None or self.introduction_message is not None
or self.info_for_santa is not None or self.info_for_santa is not None
or self.random_seed is not None or self.random_seed is not None
): ):
return return
@ -145,7 +143,8 @@ class Brain:
return return
key, _ = self.known_participants[name] key, _ = self.known_participants[name]
if not message.check_signature(key): if not message.check_signature(key):
logger.warning(f'Received message that purports to be from {name} with invalid signature. Ignoring.') logger.warning(
f'Received message that purports to be from {name} with invalid signature. Ignoring.')
return return
self.received_messages[name].append(message) self.received_messages[name].append(message)
@ -181,7 +180,6 @@ class Brain:
self.user_interface.receive_user(name) self.user_interface.receive_user(name)
self.message_handler.send_message(self.introduction_message, self.signing_key) self.message_handler.send_message(self.introduction_message, self.signing_key)
# Santa's brain will be driven by receiving messages. We'll call this main method each time we receive a message. # Santa's brain will be driven by receiving messages. We'll call this main method each time we receive a message.
# This is probably inefficient, but it makes it easier to follow what the code is doing # This is probably inefficient, but it makes it easier to follow what the code is doing
def santa_loop(self) -> list[Message]: def santa_loop(self) -> list[Message]:
@ -200,10 +198,10 @@ class Brain:
# Next, if we haven't built our commutative ciphers and card values yet, attempt to do that first # Next, if we haven't built our commutative ciphers and card values yet, attempt to do that first
if ( if (
self.card_values is None self.card_values is None
or self.card_exchange_cipher is None or self.card_exchange_cipher is None
or self.announcement_build_cipher is None or self.announcement_build_cipher is None
or self.announcement_shuffle_cipher is None or self.announcement_shuffle_cipher is None
): ):
should_continue = self.build_ciphers() should_continue = self.build_ciphers()
if not should_continue: if not should_continue:
@ -290,12 +288,13 @@ class Brain:
announcements.append(message) announcements.append(message)
receiver_key = self.anonymous_keys[receiver_card_no] receiver_key = self.anonymous_keys[receiver_card_no]
def kdf(x): def kdf(x):
return SHAKE128.new(x).read(32) return SHAKE128.new(x).read(32)
session_key = key_agreement(eph_priv=self.secret_key, eph_pub=receiver_key, kdf=kdf) session_key = key_agreement(eph_priv=self.secret_key, eph_pub=receiver_key, kdf=kdf)
for announcement in announcements: for announcement in announcements:
announcement_hash = announcement.get_announcement_hash()
encrypted_announcement = announcement.get_encrypted_announcement().decode('utf-8') encrypted_announcement = announcement.get_encrypted_announcement().decode('utf-8')
try: try:
encrypted_announcement = json.loads(encrypted_announcement) encrypted_announcement = json.loads(encrypted_announcement)
@ -317,7 +316,6 @@ class Brain:
continue continue
return messages_to_send return messages_to_send
def build_ciphers(self) -> bool: def build_ciphers(self) -> bool:
received_seeds = [self.random_seed] received_seeds = [self.random_seed]
for name in self.chosen_participants: for name in self.chosen_participants:
@ -364,7 +362,7 @@ class Brain:
self.announcement_build_cipher = CommutativeCipher(p, q) self.announcement_build_cipher = CommutativeCipher(p, q)
self.announcement_shuffle_cipher = CommutativeCipher(p, q) self.announcement_shuffle_cipher = CommutativeCipher(p, q)
self.card_values = [random_generator.get_random_bytes(8) for i in range(len(self.chosen_participants) + 1)] self.card_values = [random_generator.get_random_bytes(8) for _ in range(len(self.chosen_participants) + 1)]
return True return True
def build_shuffle_message(self, all_participants: list[str]) -> Optional[ShuffleMessage]: def build_shuffle_message(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
@ -398,7 +396,6 @@ class Brain:
def build_decrypt_cards_message(self, all_participants: list[str]) -> Optional[ShuffleMessage]: def build_decrypt_cards_message(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name) own_index = all_participants.index(self.own_name)
# Again, special case if we are the first participant # Again, special case if we are the first participant
message = None
previous_participant = all_participants[(own_index - 1) % len(all_participants)] previous_participant = all_participants[(own_index - 1) % len(all_participants)]
if own_index == 0: if own_index == 0:
message = next( message = next(
@ -446,11 +443,10 @@ class Brain:
def build_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]: def build_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name) own_index = all_participants.index(self.own_name)
card_deck = None
if own_index == 0: if own_index == 0:
card_deck = [] card_deck = []
else: else:
previous_participant = all_participants[own_index-1] previous_participant = all_participants[own_index - 1]
message = next( message = next(
( (
message for message in self.received_messages[previous_participant] message for message in self.received_messages[previous_participant]
@ -526,7 +522,7 @@ class Brain:
if message is not None: if message is not None:
card_deck = [card_value for card_value in message.get_cards()] card_deck = [card_value for card_value in message.get_cards()]
else: else:
previous_participant = all_participants[own_index-1] previous_participant = all_participants[own_index - 1]
message = next( message = next(
( (
message for message in self.received_messages[previous_participant] message for message in self.received_messages[previous_participant]
@ -551,7 +547,6 @@ class Brain:
return ShuffleMessage(self.own_name, shuffled_cards, SHUFFLE_ANONYMOUS_STAGE) return ShuffleMessage(self.own_name, shuffled_cards, SHUFFLE_ANONYMOUS_STAGE)
def decrypt_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]: def decrypt_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name) own_index = all_participants.index(self.own_name)
card_deck = None card_deck = None
@ -592,7 +587,6 @@ class Brain:
return shuffle_message return shuffle_message
def get_anonymous_keys(self, message: ShuffleMessage): def get_anonymous_keys(self, message: ShuffleMessage):
anonymous_keys = {} anonymous_keys = {}
cards = message.get_cards() cards = message.get_cards()
@ -625,10 +619,10 @@ class Brain:
if self.anonymous_keys is None: if self.anonymous_keys is None:
return None return None
santa_card = (self.card_drawn - 1) % len(all_participants) santa_card = (self.card_drawn - 1) % len(all_participants)
santa_key = self.anonymous_keys[santa_card] santa_key = self.anonymous_keys[santa_card]
def kdf(x): def kdf(x):
return SHAKE128.new(x).read(32) return SHAKE128.new(x).read(32)
@ -649,4 +643,4 @@ class Brain:
} }
encrypted_announcement = json.dumps(encrypted_announcement).encode('utf-8') encrypted_announcement = json.dumps(encrypted_announcement).encode('utf-8')
return AnnouncementMessage(self.own_name, encrypted_announcement, hashed_announcement) return AnnouncementMessage(self.own_name, encrypted_announcement, hashed_announcement)