Looks like it's working, at least in my test setup

This commit is contained in:
Martin Asprusten 2025-04-18 03:08:54 +02:00
parent e8e1fce791
commit d0b99c47a4
6 changed files with 388 additions and 58 deletions

View File

@ -36,6 +36,9 @@ class Message:
except ValueError: except ValueError:
return False return False
def set_signature(self, signature: bytes):
self.message_fields['signature'] = base64.b64encode(signature).decode('utf-8')
def get_name(self) -> str: def get_name(self) -> str:
# All subclasses have a get_name function, which tells who sent the message # All subclasses have a get_name function, which tells who sent the message
pass pass

View File

@ -85,15 +85,15 @@ class MessageHandler:
logger.error(f'Received data of the wrong type for ready message: {message_string}') logger.error(f'Received data of the wrong type for ready message: {message_string}')
return return
elif not all(isinstance(participant, tuple) for participant in participants): elif not all(isinstance(participant, list) for participant in participants):
logger.error(f'Not all participants in participant list are tuples: {message_string}') logger.error(f'Not all participants in participant list are lists: {message_string}')
return return
elif not all(len(participant_tuple) == 2 for participant_tuple in participants): elif not all(len(participant_tuple) == 2 for participant_tuple in participants):
logger.error(f'Not all participant tuples are of length two in {message_string}') logger.error(f'Not all participant tuples are of length two in {message_string}')
return return
elif not all(isinstance(name, str) and isinstance(key, bytes) for name, key in participants): elif not all(isinstance(name, str) and isinstance(key, str) for name, key in participants):
logger.error(f'Not all participant tuples contain a name and a key') logger.error(f'Not all participant tuples contain a name and a key')
return return
@ -112,7 +112,7 @@ class MessageHandler:
logger.error(f'Random seed {random_seed} from {name} is not a valid base64 string') logger.error(f'Random seed {random_seed} from {name} is not a valid base64 string')
return return
message = ReadyMessage(name, participants, random_seed) message = ReadyMessage(name, decoded_participants, random_seed)
elif message_type == 'shuffle': elif message_type == 'shuffle':
name = message_object.get('name') name = message_object.get('name')
@ -143,13 +143,14 @@ class MessageHandler:
elif message_type == 'announcement': elif message_type == 'announcement':
name = message_object.get('name') name = message_object.get('name')
announcement = message_object.get('announcement') announcement = message_object.get('encrypted_announcement')
announcement_hash = message_object.get('announcement_hash')
if None in [name, announcement]: if None in [name, announcement, announcement_hash]:
logger.error(f'Did not receive all expected fields for announcement message: {message_string}') logger.error(f'Did not receive all expected fields for announcement message: {message_string}')
return return
elif not isinstance(name, str) or not isinstance(announcement, str): elif not isinstance(name, str) or not isinstance(announcement, str) or not isinstance(announcement_hash, str):
logger.error(f'Received fields for announcement message are not of correct type: {message_string}') logger.error(f'Received fields for announcement message are not of correct type: {message_string}')
return return
@ -159,11 +160,19 @@ class MessageHandler:
logger.error(f'{announcement} is not a valid base64 string') logger.error(f'{announcement} is not a valid base64 string')
return return
message = AnnouncementMessage(name, announcement) try:
announcement_hash = base64.b64decode(announcement_hash, validate=True)
except binascii.Error:
logger.error(f'{announcement} is not a valid base64 string')
return
message = AnnouncementMessage(name, announcement, announcement_hash)
if message is None: if message is None:
logger.error(f'Message type {message_type} does not exist') logger.error(f'Message type {message_type} does not exist')
return return
if message_object.get('signature'):
message.set_signature(base64.b64decode(message_object.get('signature')))
for message_receiver in self.receivers: for message_receiver in self.receivers:
message_receiver(message) message_receiver(message)

View File

@ -4,11 +4,12 @@ from classes.Message import Message
class AnnouncementMessage(Message): class AnnouncementMessage(Message):
def __init__(self, name: str, announcement: bytes): def __init__(self, name: str, encrypted_announcement: bytes, announcement_hash: bytes):
super().__init__() super().__init__()
self.message_fields['type'] = 'announcement' self.message_fields['type'] = 'announcement'
self.message_fields['name'] = name self.set_name(name)
self.message_fields['announcement'] = announcement self.set_encrypted_announcement(encrypted_announcement)
self.set_announcement_hash(announcement_hash)
def set_name(self, name: str): def set_name(self, name: str):
self.message_fields['name'] = name self.message_fields['name'] = name
@ -16,8 +17,14 @@ class AnnouncementMessage(Message):
def get_name(self) -> str: def get_name(self) -> str:
return self.message_fields['name'] return self.message_fields['name']
def set_announcement(self, announcement: bytes): def set_encrypted_announcement(self, announcement: bytes):
self.message_fields['announcement'] = base64.b64encode(announcement).decode('utf-8') self.message_fields['encrypted_announcement'] = base64.b64encode(announcement).decode('utf-8')
def get_announcement(self) -> bytes: def get_encrypted_announcement(self) -> bytes:
return base64.b64decode(self.message_fields['announcement']) return base64.b64decode(self.message_fields['encrypted_announcement'])
def set_announcement_hash(self, announcement_hash: bytes):
self.message_fields['announcement_hash'] = base64.b64encode(announcement_hash).decode('utf-8')
def get_announcement_hash(self) -> bytes:
return base64.b64decode(self.message_fields['announcement_hash'])

View File

@ -1,4 +1,6 @@
import base64
import hashlib import hashlib
import json
import logging import logging
import math import math
import random import random
@ -8,6 +10,9 @@ from collections import defaultdict
from typing import Optional, Callable from typing import Optional, Callable
import Crypto.Util.number import Crypto.Util.number
from Crypto.Cipher import AES
from Crypto.Hash import SHAKE128
from Crypto.Protocol.DH import key_agreement
from Crypto.PublicKey import ECC from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey from Crypto.PublicKey.ECC import EccKey
from Crypto.Util.number import getPrime from Crypto.Util.number import getPrime
@ -16,6 +21,7 @@ from classes import Message
from classes.Crypto.CSPRNG import CSPRNG from classes.Crypto.CSPRNG import CSPRNG
from classes.Crypto.CommutativeCipher import CommutativeCipher from classes.Crypto.CommutativeCipher import CommutativeCipher
from classes.MessageHandler import MessageHandler from classes.MessageHandler import MessageHandler
from classes.MessageTypes.Announcement import AnnouncementMessage
from classes.MessageTypes.Introduction import IntroductionMessage from classes.MessageTypes.Introduction import IntroductionMessage
from classes.MessageTypes.Ready import ReadyMessage from classes.MessageTypes.Ready import ReadyMessage
from classes.MessageTypes.Shuffle import ShuffleMessage from classes.MessageTypes.Shuffle import ShuffleMessage
@ -26,6 +32,7 @@ logger = logging.getLogger(__name__)
SHUFFLE_CARDS_STAGE = 'shuffle_cards' SHUFFLE_CARDS_STAGE = 'shuffle_cards'
DECRYPT_CARDS_STAGE = 'decrypt_cards' DECRYPT_CARDS_STAGE = 'decrypt_cards'
BUILD_ANONYMOUS_STAGE = 'build_announcement' BUILD_ANONYMOUS_STAGE = 'build_announcement'
ENCRYPT_ANONYMOUS_STAGE = 'encrypt_announcement'
SHUFFLE_ANONYMOUS_STAGE = 'shuffle_announcement' SHUFFLE_ANONYMOUS_STAGE = 'shuffle_announcement'
DECRYPT_ANONYMOUS_STAGE = 'decrypt_announcement' DECRYPT_ANONYMOUS_STAGE = 'decrypt_announcement'
@ -51,7 +58,7 @@ class Brain:
self.introduction_message: Optional[IntroductionMessage] = None self.introduction_message: Optional[IntroductionMessage] = None
self.info_for_santa: Optional[str] = None self.info_for_santa: Optional[str] = None
# Secret key, used for receiving information about who your secret santa is # Secret key, used for Diffie-Hellman exchange later
self.secret_key = ECC.generate(curve='p256') self.secret_key = ECC.generate(curve='p256')
self.user_interface.add_user_info_listener(self.receive_user_info) self.user_interface.add_user_info_listener(self.receive_user_info)
@ -67,6 +74,13 @@ class Brain:
self.sent_card_shuffling: bool = False self.sent_card_shuffling: bool = False
self.sent_card_decryption: bool = False self.sent_card_decryption: bool = False
self.card_drawn: Optional[int] = None self.card_drawn: Optional[int] = None
self.built_anonymous_deck = False
self.encrypted_anonymous_deck = False
self.shuffled_anonymous_deck = False
self.decrypted_anonymous_deck = False
self.anonymous_keys: Optional[dict[int, EccKey]] = None
self.sent_announcement = False
self.received_announcement = False
def receive_user_info(self, name: str, info_for_santa: str): def receive_user_info(self, name: str, info_for_santa: str):
# We will only receive this once # We will only receive this once
@ -164,6 +178,7 @@ class Brain:
and self.introduction_message is not None and self.introduction_message is not None
and self.signing_key is not None and self.signing_key is not None
): ):
self.user_interface.receive_user(name)
self.message_handler.send_message(self.introduction_message, self.signing_key) self.message_handler.send_message(self.introduction_message, self.signing_key)
@ -228,9 +243,78 @@ class Brain:
return messages_to_send return messages_to_send
self.decrypt_card_value(message.get_cards()[own_index]) self.decrypt_card_value(message.get_cards()[own_index])
# Next, anonymously publish a key someone else can use to encrypt a message telling you that you're
# their secret santa
if not self.built_anonymous_deck:
build_deck_message = self.build_anonymous_deck(all_participants)
if build_deck_message is None:
return messages_to_send
messages_to_send.append(build_deck_message)
self.built_anonymous_deck = True
if not self.encrypted_anonymous_deck:
encrypt_deck_message = self.encrypt_anonymous_deck(all_participants)
if encrypt_deck_message is None:
return messages_to_send
messages_to_send.append(encrypt_deck_message)
self.encrypted_anonymous_deck = True
if not self.shuffled_anonymous_deck:
shuffle_anonymous_deck_message = self.shuffle_anonymous_deck(all_participants)
if shuffle_anonymous_deck_message is None:
return messages_to_send
messages_to_send.append(shuffle_anonymous_deck_message)
self.shuffled_anonymous_deck = True
if not self.decrypted_anonymous_deck:
decrypted_anonymous_deck_message = self.decrypt_anonymous_deck(all_participants)
if decrypted_anonymous_deck_message is None:
return messages_to_send
messages_to_send.append(decrypted_anonymous_deck_message)
self.decrypted_anonymous_deck = True
if not self.sent_announcement:
announcement_message = self.build_announcement(all_participants)
if announcement_message is None:
return messages_to_send
messages_to_send.append(announcement_message)
self.sent_announcement = True
# Look through all announcements to find our secret santa receiver
if self.sent_announcement and self.anonymous_keys is not None and not self.received_announcement:
receiver_card_no = (self.card_drawn + 1) % len(all_participants)
announcements = []
for participant in self.received_messages:
for message in self.received_messages[participant]:
if isinstance(message, AnnouncementMessage):
announcements.append(message)
receiver_key = self.anonymous_keys[receiver_card_no]
def kdf(x):
return SHAKE128.new(x).read(32)
session_key = key_agreement(eph_priv=self.secret_key, eph_pub=receiver_key, kdf=kdf)
for announcement in announcements:
announcement_hash = announcement.get_announcement_hash()
encrypted_announcement = announcement.get_encrypted_announcement().decode('utf-8')
try:
encrypted_announcement = json.loads(encrypted_announcement)
ciphertext = base64.b64decode(encrypted_announcement['ciphertext'])
tag = base64.b64decode(encrypted_announcement['tag'])
nonce = base64.b64decode(encrypted_announcement['nonce'])
cipher = AES.new(session_key, AES.MODE_EAX, nonce=nonce)
plaintext = cipher.decrypt(ciphertext)
cipher.verify(tag)
plaintext = json.loads(plaintext)
name = plaintext['name']
extra = plaintext['extra']
self.user_interface.announce_recipient(name, extra)
self.received_announcement = True
except:
continue
return messages_to_send return messages_to_send
@ -300,13 +384,13 @@ class Brain:
if message is None: if message is None:
return None return None
card_deck = message.get_cards() card_deck = [card for card in message.get_cards()]
# Shuffle by drawing random numbers from secret # Shuffle by drawing random numbers from secret
shuffled_deck = [] shuffled_deck = []
while len(card_deck) > 0: while len(card_deck) > 0:
drawn_card = secrets.randbelow(len(card_deck)) drawn_card = secrets.randbelow(len(card_deck))
shuffled_deck.append(card_deck[drawn_card]) shuffled_deck.append(self.card_exchange_cipher.encode(card_deck[drawn_card]))
del card_deck[drawn_card] del card_deck[drawn_card]
return ShuffleMessage(self.own_name, shuffled_deck, SHUFFLE_CARDS_STAGE) return ShuffleMessage(self.own_name, shuffled_deck, SHUFFLE_CARDS_STAGE)
@ -357,4 +441,212 @@ class Brain:
self.process_failed = True self.process_failed = True
return return
self.card_drawn = self.card_values.index(decrypted_card_bytes) self.card_drawn = self.card_values.index(decrypted_card_bytes)
def build_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name)
card_deck = None
if own_index == 0:
card_deck = []
else:
previous_participant = all_participants[own_index-1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == BUILD_ANONYMOUS_STAGE
),
None
)
if message is None:
return None
card_deck = [card for card in message.get_cards()]
anonymous_message = {
'card_no': self.card_drawn,
'key': self.secret_key.public_key().export_key(format='OpenSSH')
}
anonymous_message = json.dumps(anonymous_message).encode('utf-8')
card_deck.append(self.announcement_build_cipher.encode(anonymous_message))
return ShuffleMessage(self.own_name, card_deck, BUILD_ANONYMOUS_STAGE)
def encrypt_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name)
card_deck = None
if own_index == 0:
previous_participant = all_participants[-1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == BUILD_ANONYMOUS_STAGE
),
None
)
if message is not None:
card_deck = message.get_cards()
else:
previous_participant = all_participants[own_index - 1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == ENCRYPT_ANONYMOUS_STAGE
),
None
)
if message is not None:
card_deck = message.get_cards()
if card_deck is None:
return None
encrypted_deck = []
for index, card_value in enumerate(card_deck):
if index == own_index:
encrypted_deck.append(card_value)
else:
encrypted_deck.append(self.announcement_build_cipher.encode(card_value))
return ShuffleMessage(self.own_name, encrypted_deck, ENCRYPT_ANONYMOUS_STAGE)
def shuffle_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name)
card_deck = None
if own_index == 0:
previous_participant = all_participants[-1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == ENCRYPT_ANONYMOUS_STAGE
),
None
)
if message is not None:
card_deck = [card_value for card_value in message.get_cards()]
else:
previous_participant = all_participants[own_index-1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == SHUFFLE_ANONYMOUS_STAGE
),
None
)
if message is not None:
card_deck = [card_value for card_value in message.get_cards()]
if card_deck is None:
return None
shuffled_cards = []
while len(card_deck) > 0:
draw_number = secrets.randbelow(len(card_deck))
card_value = card_deck[draw_number]
del card_deck[draw_number]
decrypted_previous = self.announcement_build_cipher.decode(card_value)
shuffled_cards.append(self.announcement_shuffle_cipher.encode(decrypted_previous))
return ShuffleMessage(self.own_name, shuffled_cards, SHUFFLE_ANONYMOUS_STAGE)
def decrypt_anonymous_deck(self, all_participants: list[str]) -> Optional[ShuffleMessage]:
own_index = all_participants.index(self.own_name)
card_deck = None
if own_index == 0:
previous_participant = all_participants[-1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == SHUFFLE_ANONYMOUS_STAGE
),
None
)
if message is not None:
card_deck = [card_value for card_value in message.get_cards()]
else:
previous_participant = all_participants[own_index - 1]
message = next(
(
message for message in self.received_messages[previous_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == DECRYPT_ANONYMOUS_STAGE
),
None
)
if message is not None:
card_deck = [card_value for card_value in message.get_cards()]
if card_deck is None:
return None
decrypted_deck = [self.announcement_shuffle_cipher.decode(card_value) for card_value in card_deck]
shuffle_message = ShuffleMessage(self.own_name, decrypted_deck, DECRYPT_ANONYMOUS_STAGE)
# If we are the last participant, everything should now be decrypted
if own_index == len(all_participants) - 1:
self.get_anonymous_keys(shuffle_message)
return shuffle_message
def get_anonymous_keys(self, message: ShuffleMessage):
anonymous_keys = {}
cards = message.get_cards()
for card in cards:
try:
decoded_card = json.loads(card.decode('utf-8'))
anonymous_keys[decoded_card['card_no']] = ECC.import_key(decoded_card['key'])
except:
logger.critical(f'Received card {card} could not be decoded as JSON. Secret santa process failed.')
self.process_failed = True
return
self.anonymous_keys = anonymous_keys
def build_announcement(self, all_participants: list[str]) -> Optional[AnnouncementMessage]:
if self.anonymous_keys is None:
last_participant = all_participants[-1]
message = next(
(
message for message in self.received_messages[last_participant]
if isinstance(message, ShuffleMessage) and message.get_stage() == DECRYPT_ANONYMOUS_STAGE
),
None
)
if message is None:
return None
self.get_anonymous_keys(message)
if self.anonymous_keys is None:
return None
santa_card = (self.card_drawn - 1) % len(all_participants)
santa_key = self.anonymous_keys[santa_card]
def kdf(x):
return SHAKE128.new(x).read(32)
session_key = key_agreement(eph_priv=self.secret_key, eph_pub=santa_key, kdf=kdf)
session_cipher = AES.new(session_key, AES.MODE_EAX)
message_string = {'name': self.own_name, 'extra': self.info_for_santa}
ciphertext, tag = session_cipher.encrypt_and_digest(json.dumps(message_string).encode('utf-8'))
hasher = hashlib.sha512()
hasher.update(self.info_for_santa.encode('utf-8'))
hashed_announcement = hasher.digest()
encrypted_announcement = {
'ciphertext': base64.b64encode(ciphertext).decode('utf-8'),
'tag': base64.b64encode(tag).decode('utf-8'),
'nonce': base64.b64encode(session_cipher.nonce).decode('utf-8')
}
encrypted_announcement = json.dumps(encrypted_announcement).encode('utf-8')
return AnnouncementMessage(self.own_name, encrypted_announcement, hashed_announcement)

View File

@ -2,14 +2,26 @@ from collections.abc import Callable
class UserInterface: class UserInterface:
def add_user(self, name: str): def __init__(self):
pass self.user_info_listeners: list[Callable[[str, str], None]] = []
self.start_listeners: list[Callable[[list[str]], None]] = []
def add_user_info_listener(self, callback: Callable[[str, str], None]): def receive_user(self, name: str):
pass
def add_start_listener(self, callback: Callable[[list[str]], None]):
pass pass
def announce_recipient(self, name: str, other_info: str): def announce_recipient(self, name: str, other_info: str):
pass pass
def set_user_info(self, name: str, info_for_santa: str):
for listener in self.user_info_listeners:
listener(name, info_for_santa)
def start_exchange(self, list_of_participants: list[str]):
for listener in self.start_listeners:
listener(list_of_participants)
def add_user_info_listener(self, callback: Callable[[str, str], None]):
self.user_info_listeners.append(callback)
def add_start_listener(self, callback: Callable[[list[str]], None]):
self.start_listeners.append(callback)

73
main.py
View File

@ -1,45 +1,52 @@
import base64 from Crypto.PublicKey.ECC import EccKey
import secrets
from base64 import b64encode
import Crypto.Cipher.PKCS1_OAEP from classes.Message import Message
from Crypto.Util.number import getPrime from classes.MessageHandler import MessageHandler
from Crypto.PublicKey import ECC, RSA from classes.SantasBrain import Brain
from classes.UserInterface import UserInterface
from classes.Crypto.CommutativeCipher import CommutativeCipher participants: list[tuple[MessageHandler, UserInterface]] = []
from classes.MessageTypes.Introduction import IntroductionMessage
p = getPrime(1500) class TestMessageHandler(MessageHandler):
q = getPrime(1000) def __init__(self):
cipher1 = CommutativeCipher(p, q) super().__init__()
cipher2 = CommutativeCipher(p, q)
message = 'Hei på deg'.encode('utf-8 ')
c1 = cipher1.encode(message)
c2 = cipher2.encode(c1)
d1 = cipher1.decode(c2) def send_message(self, message: Message, signing_key: EccKey):
print(cipher2.decode(d1)) message_string = message.generate_and_sign(signing_key)
for message_handler, _ in participants:
if message_handler != self:
message_handler.decode_received_message(message_string)
key = ECC.generate(curve='p256') class TestUserInterface(UserInterface):
def __init__(self, own_name, extra_info):
super().__init__()
self.received_names = []
self.own_name = own_name
self.extra_info = extra_info
test = key.public_key().export_key(format='OpenSSH') def tell(self):
self.set_user_info(self.own_name, self.extra_info)
seed = secrets.randbits(256).to_bytes(32) def lets_go(self):
self.start_exchange(self.received_names)
key1 = ECC.import_key(test) def receive_user(self, name: str):
key2 = ECC.import_key(test) self.received_names.append(name)
print(f'Keys are equal: {key1 == key2}') def announce_recipient(self, name: str, other_info: str):
print(f'{self.own_name}: Received {name}, {other_info}')
test = IntroductionMessage('Martin', key.public_key(), seed) number_of_participants = 20
print(test.generate_and_sign(key))
print(test.check_signature(key.public_key()))
rsa_key = RSA.generate(2048) for i in range(number_of_participants):
message = 'Keep it secret! Keep it safe!'.encode('utf-8') test_message_handler = TestMessageHandler()
cipher = Crypto.Cipher.PKCS1_OAEP.new(rsa_key.public_key()) test_user_interface = TestUserInterface(f'User {i}', f'Lives at {i} street')
ciphertext = cipher.encrypt(message) brain = Brain(test_message_handler, test_user_interface)
decoder = Crypto.Cipher.PKCS1_OAEP.new(rsa_key)
print(decoder.decrypt(ciphertext)) participants.append((test_message_handler, test_user_interface))
test_user_interface.tell()
for i in range(number_of_participants):
handler, interface = participants[i]
interface.lets_go()
print(len(b64encode(rsa_key.public_key().export_key(format='DER'))))