commit 1df4790fb416bbeec7c034a70a8837f0bc086501 Author: saymrwulf Date: Thu Apr 23 13:20:30 2026 +0200 Initial implementation of the Crisis protocol (Richter, 2019) Complete Python PoC of "Probabilistically Self Organizing Total Order in Unstructured P2P Networks". Implements all 10 algorithms from the paper: message generation, integrity checks, Lamport graphs, virtual synchronous rounds, safe voting patterns, virtual leader election (BA*), longest chain rule, total order via Kahn's algorithm, and push/pull gossip. Includes simulation harness, full node binary, and 72 passing tests. Co-Authored-By: Claude Opus 4.6 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d1b5921 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +build/ +*.egg + +# Virtual environment +.venv/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Testing +.pytest_cache/ +.coverage +htmlcov/ diff --git a/Crisis.mirco-richter-2019.pdf b/Crisis.mirco-richter-2019.pdf new file mode 100644 index 0000000..e24014a Binary files /dev/null and b/Crisis.mirco-richter-2019.pdf differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..37fe60f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "crisis" +version = "0.1.0" +description = "Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks" +readme = "README.md" +requires-python = ">=3.11" +license = "CC-BY-4.0" +authors = [ + { name = "Mirco Richter (paper)", email = "mirco.richter@mailbox.org" }, +] +dependencies = [ + "networkx>=3.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "rich>=13.0", +] + +[project.scripts] +crisis-node = "crisis.node:main" +crisis-demo = "crisis.demo:main" + +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/src/crisis/__init__.py b/src/crisis/__init__.py new file mode 100644 index 0000000..f56df7f --- /dev/null +++ b/src/crisis/__init__.py @@ -0,0 +1,21 @@ +""" +Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks + +A Python implementation of the Crisis protocol described by Mirco Richter (2019). + +The protocol achieves total order on messages in fully open, unstructured +Peer-to-Peer networks through virtual voting -- votes are never sent explicitly +but are deduced from the causal relationships between messages encoded in +Lamport graphs. + +Key components: + - crypto: Random oracle model (SHA-256 hash function) + - message: Message and Vertex data structures + - weight: Weight systems (PoW-based Sybil resistance) + - graph: Lamport graphs with integrity checking + - rounds: Virtual synchronous rounds + - voting: Safe voting patterns and virtual leader election (BA*) + - order: Total order via leader stream and topological sorting + - gossip: Push/pull gossip for member discovery and message dissemination + - node: Full Crisis node tying all components together +""" diff --git a/src/crisis/crypto.py b/src/crisis/crypto.py new file mode 100644 index 0000000..55897f2 --- /dev/null +++ b/src/crisis/crypto.py @@ -0,0 +1,104 @@ +""" +Random Oracle Model (Section 2.1) + +We work in the random oracle model, assuming the existence of a cryptographic +hash function that behaves like a random oracle: + + H : {0,1}* -> {0,1}^p (Eq. 1) + +We use SHA-256 as our concrete instantiation. H is assumed to be collision-, +preimage-, and second-preimage-resistant. + +We call H(b) the *digest* of the binary string b. +""" + +import hashlib +from typing import Union + +# The digest length in bytes (SHA-256 produces 32 bytes = 256 bits). +DIGEST_LENGTH = 32 + + +def digest(data: Union[bytes, bytearray]) -> bytes: + """Compute the SHA-256 digest of arbitrary binary data. + + This is the core random oracle H used throughout the protocol. + Every reference to "the digest of" a message or byte string in the + paper maps to this function. + + Returns: + 32-byte digest (256 bits). + """ + return hashlib.sha256(data).digest() + + +def digest_hex(data: Union[bytes, bytearray]) -> str: + """Convenience: return the digest as a hex string for display.""" + return digest(data).hex() + + +def verify_digest(data: bytes, expected: bytes) -> bool: + """Check that H(data) equals the expected digest.""" + return digest(data) == expected + + +# --------------------------------------------------------------------------- +# Least significant bit helper (used in the virtual coin flip, Algorithm 7) +# --------------------------------------------------------------------------- + +def least_significant_bit(h: bytes) -> int: + """Return the least significant bit of a hash value. + + Used in Algorithm 7 (virtual leader election) for the "genuine coin flip" + stage, where the LSB of H(v_hat.m) determines the binary vote. + + The paper defines: + b_coin := lsb(H(x.m)) for max weight x in S + """ + return h[-1] & 1 + + +# --------------------------------------------------------------------------- +# Proof-of-Work helpers (used by the weight system, Section 3.1.1) +# --------------------------------------------------------------------------- + +def count_leading_zero_bits(h: bytes) -> int: + """Count the number of leading zero bits in a hash value. + + This is the standard measure of proof-of-work difficulty: a hash with + k leading zero bits required roughly 2^k hash evaluations to find. + """ + count = 0 + for byte in h: + if byte == 0: + count += 8 + else: + # Count leading zeros in this byte + count += (byte ^ 0xFF).bit_length() - (255 - byte).bit_length() + # Simpler: count leading zeros via bit tricks + for bit_pos in range(7, -1, -1): + if byte & (1 << bit_pos): + return count + count += 1 + break + return count + + +def count_leading_zero_bits(h: bytes) -> int: + """Count the number of leading zero bits in a hash value. + + A hash with k leading zero bits required roughly 2^k evaluations to find. + Used by the PoW weight function to assign weight to messages. + """ + count = 0 + for byte in h: + if byte == 0: + count += 8 + continue + # Count leading zeros in this non-zero byte + for bit_pos in range(7, -1, -1): + if byte & (1 << bit_pos): + return count + count += 1 + break + return count diff --git a/src/crisis/demo.py b/src/crisis/demo.py new file mode 100644 index 0000000..66f7da1 --- /dev/null +++ b/src/crisis/demo.py @@ -0,0 +1,356 @@ +""" +Demonstration / Simulation Harness + +This module provides a deterministic, single-process simulation of the Crisis +protocol with N virtual nodes. It is designed as the foundation for a lecture +series: each phase of the protocol can be observed step by step. + +The simulation bypasses the network layer entirely -- messages are delivered +directly between in-memory Lamport graphs. This makes the consensus mechanism +visible without network noise. + +Usage: + python -m crisis.demo # Run the full demo + python -m crisis.demo --nodes 5 # 5 honest nodes + python -m crisis.demo --byzantine 1 # 1 byzantine node + python -m crisis.demo --rounds 10 # Run for 10 message rounds +""" + +from __future__ import annotations + +import os +import random +import time +from dataclasses import dataclass, field +from typing import Optional + +from crisis.crypto import digest +from crisis.graph import LamportGraph +from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH +from crisis.order import LeaderStream, compute_order +from crisis.rounds import compute_rounds, max_round, last_vertices_in_round +from crisis.voting import compute_safe_voting_pattern, compute_virtual_leader_election +from crisis.weight import ProofOfWorkWeight, DifficultyOracle + + +# --------------------------------------------------------------------------- +# Simulated Node +# --------------------------------------------------------------------------- + +@dataclass +class SimulatedNode: + """A simulated Crisis node running in-memory. + + Each node has its own Lamport graph and process id. Messages are + exchanged by directly sharing Message objects (no serialization needed). + """ + name: str + process_id: bytes + graph: LamportGraph + leader_stream: LeaderStream = field(default_factory=LeaderStream) + is_byzantine: bool = False + messages_created: int = 0 + + def generate_message(self, payload: str) -> Message: + """Generate a new message from this node.""" + self.messages_created += 1 + return self.graph.generate_message( + self.process_id, + payload.encode(), + ) + + +# --------------------------------------------------------------------------- +# Simulation Engine +# --------------------------------------------------------------------------- + +class Simulation: + """Deterministic simulation of N Crisis nodes. + + Runs the protocol in lock-step rounds: + 1. Each node generates a message + 2. Messages are gossiped (delivered to all nodes) + 3. Consensus is computed on each node + 4. State is displayed + + This allows observing how the Lamport graph grows, rounds emerge, + and total order converges. + """ + + def __init__(self, num_honest: int = 3, num_byzantine: int = 0, + pow_zeros: int = 0, difficulty: int = 2, + connectivity_k: int = 1, seed: int = 42): + self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty) + self.connectivity_k = connectivity_k + self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros) + self.seed = seed + random.seed(seed) + + # Create nodes + self.nodes: list[SimulatedNode] = [] + for i in range(num_honest): + name = f"honest-{i}" + pid = digest(name.encode())[:ID_LENGTH] + graph = LamportGraph(weight_system=self.weight_system) + self.nodes.append(SimulatedNode( + name=name, process_id=pid, graph=graph + )) + + for i in range(num_byzantine): + name = f"byzantine-{i}" + pid = digest(name.encode())[:ID_LENGTH] + graph = LamportGraph(weight_system=self.weight_system) + self.nodes.append(SimulatedNode( + name=name, process_id=pid, graph=graph, is_byzantine=True + )) + + self.step_count = 0 + self.all_messages: list[Message] = [] + + def step(self) -> dict: + """Execute one simulation step. + + Returns a dict with step results for display. + """ + self.step_count += 1 + step_results = { + "step": self.step_count, + "new_messages": [], + "node_states": [], + } + + # Phase 1: Each node generates a message + new_messages: list[tuple[SimulatedNode, Message]] = [] + for node in self.nodes: + if node.is_byzantine: + msg = self._byzantine_message(node) + else: + payload = f"step-{self.step_count}-{node.name}" + msg = node.generate_message(payload) + + if msg is not None: + new_messages.append((node, msg)) + step_results["new_messages"].append({ + "from": node.name, + "digest": msg.compute_digest().hex()[:12], + "weight": self.weight_system.weight(msg), + "payload": msg.payload.decode(errors="replace"), + }) + + # Phase 2: Gossip -- deliver all messages to all nodes + for source_node, msg in new_messages: + self.all_messages.append(msg) + for target_node in self.nodes: + # Deliver to all nodes (including source, for consistency) + target_node.graph.extend(msg) + + # Also re-deliver older messages that nodes might be missing + # (simulates pull gossip catching up) + for msg in self.all_messages: + for node in self.nodes: + node.graph.extend(msg) # extend() is idempotent (integrity check) + + # Phase 3: Compute consensus on each node + for node in self.nodes: + compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k) + + for vertex in node.graph.all_vertices(): + if vertex.is_last: + compute_safe_voting_pattern( + vertex, node.graph, self.difficulty_oracle, + self.connectivity_k + ) + + leader_dict: dict[int, list[tuple[int, Message]]] = {} + for vertex in node.graph.all_vertices(): + if vertex.svp: + compute_virtual_leader_election( + vertex, node.graph, self.difficulty_oracle, + self.connectivity_k, leader_dict + ) + + for round_num, entries in leader_dict.items(): + for deciding_round, leader_msg in entries: + node.leader_stream.update(round_num, deciding_round, leader_msg) + + ordered = compute_order(node.graph, node.leader_stream) + + mr = max_round(node.graph) + step_results["node_states"].append({ + "name": node.name, + "vertices": node.graph.vertex_count(), + "max_round": mr, + "leaders": len(node.leader_stream.leaders), + "ordered": len(ordered), + "is_byzantine": node.is_byzantine, + }) + + return step_results + + def _byzantine_message(self, node: SimulatedNode) -> Optional[Message]: + """Generate a byzantine message. + + Byzantine nodes can exhibit several faulty behaviors: + - Mutations: same id, forking the causal chain + - Strategic distribution: different messages to different peers + - Time travel: referencing old rounds + + For this demo, we generate a message with a random payload that + may not reference the latest same-id message (creating a mutation). + """ + payload = f"byz-{self.step_count}-{node.name}-{random.randint(0, 999)}" + + # 50% chance of creating a mutation (not referencing last same-id vertex) + if random.random() < 0.5 and node.graph.vertex_count() > 0: + # Pick random digests instead of following the chain + available = list(node.graph.vertices.keys()) + num_refs = min(random.randint(1, 3), len(available)) + digests = tuple(random.sample(available, num_refs)) + nonce = os.urandom(NONCE_LENGTH) + return Message( + nonce=nonce, id=node.process_id, + digests=digests, payload=payload.encode() + ) + else: + return node.generate_message(payload) + + def run(self, num_steps: int = 10, verbose: bool = True) -> list[dict]: + """Run the simulation for a number of steps.""" + results = [] + for _ in range(num_steps): + result = self.step() + results.append(result) + if verbose: + _print_step(result) + + if verbose: + _print_convergence_summary(self) + + return results + + +# --------------------------------------------------------------------------- +# Display functions +# --------------------------------------------------------------------------- + +def _print_step(result: dict) -> None: + """Print the results of a simulation step.""" + print(f"\n{'='*70}") + print(f" Step {result['step']}") + print(f"{'='*70}") + + print(f"\n New messages:") + for msg in result["new_messages"]: + print(f" {msg['from']:>15s} -> {msg['digest']} " + f"w={msg['weight']} {msg['payload'][:40]}") + + print(f"\n Node states:") + print(f" {'Name':>15s} {'Vertices':>8s} {'Round':>5s} " + f"{'Leaders':>7s} {'Ordered':>7s}") + print(f" {'-'*15} {'-'*8} {'-'*5} {'-'*7} {'-'*7}") + for ns in result["node_states"]: + byz = " [BYZ]" if ns["is_byzantine"] else "" + print(f" {ns['name']:>15s} {ns['vertices']:>8d} " + f"{ns['max_round']:>5d} {ns['leaders']:>7d} " + f"{ns['ordered']:>7d}{byz}") + + +def _print_convergence_summary(sim: Simulation) -> None: + """Print a summary showing whether honest nodes have converged.""" + print(f"\n{'='*70}") + print(f" Convergence Summary") + print(f"{'='*70}") + + honest_nodes = [n for n in sim.nodes if not n.is_byzantine] + + # Check if all honest nodes have the same total order + orders = [] + for node in honest_nodes: + ordered = compute_order(node.graph, node.leader_stream) + order_digests = [v.message_digest.hex()[:12] for v in ordered] + orders.append(order_digests) + + if len(orders) >= 2: + # Compare pairwise + all_agree = all(o == orders[0] for o in orders[1:]) + if all_agree: + print(f"\n All {len(honest_nodes)} honest nodes AGREE " + f"on total order ({len(orders[0])} messages)") + else: + print(f"\n Honest nodes have DIVERGENT total orders " + f"(convergence in progress)") + for i, (node, order) in enumerate(zip(honest_nodes, orders)): + print(f" {node.name}: {len(order)} ordered messages") + + # Show the total order from the first honest node + if orders and orders[0]: + print(f"\n Total order (from {honest_nodes[0].name}):") + first_node = honest_nodes[0] + ordered = compute_order(first_node.graph, first_node.leader_stream) + for v in ordered[:20]: # Show first 20 + print(f" pos={v.total_position:>3d} " + f"hash={v.message_digest.hex()[:12]} " + f"r={v.round} " + f"payload={v.payload.decode(errors='replace')[:40]}") + if len(ordered) > 20: + print(f" ... and {len(ordered) - 20} more") + + # Show leader stream + if honest_nodes: + ls = honest_nodes[0].leader_stream + if ls.leaders: + print(f"\n Leader stream ({honest_nodes[0].name}):") + for round_num, (dec_round, msg) in sorted(ls.leaders.items()): + print(f" round {round_num}: leader={msg.compute_digest().hex()[:12]} " + f"decided_in_round={dec_round}") + + print() + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Crisis Protocol Simulation", + epilog="Demonstrates probabilistic total order convergence" + ) + parser.add_argument("--nodes", type=int, default=3, + help="Number of honest nodes (default: 3)") + parser.add_argument("--byzantine", type=int, default=0, + help="Number of byzantine nodes (default: 0)") + parser.add_argument("--steps", type=int, default=10, + help="Number of simulation steps (default: 10)") + parser.add_argument("--pow-zeros", type=int, default=0, + help="Min PoW leading zeros (default: 0 = no PoW)") + parser.add_argument("--difficulty", type=int, default=2, + help="Difficulty oracle constant (default: 2)") + parser.add_argument("--seed", type=int, default=42, + help="Random seed for reproducibility (default: 42)") + + args = parser.parse_args() + + print(f"Crisis Protocol Simulation") + print(f" Honest nodes: {args.nodes}") + print(f" Byzantine nodes: {args.byzantine}") + print(f" Steps: {args.steps}") + print(f" PoW zeros: {args.pow_zeros}") + print(f" Difficulty: {args.difficulty}") + print(f" Seed: {args.seed}") + + sim = Simulation( + num_honest=args.nodes, + num_byzantine=args.byzantine, + pow_zeros=args.pow_zeros, + difficulty=args.difficulty, + seed=args.seed, + ) + + sim.run(num_steps=args.steps) + + +if __name__ == "__main__": + main() diff --git a/src/crisis/gossip.py b/src/crisis/gossip.py new file mode 100644 index 0000000..6b8fa58 --- /dev/null +++ b/src/crisis/gossip.py @@ -0,0 +1,414 @@ +""" +Communication (Section 4) + +Crisis is built on top of two simple push & pull gossip protocols: +1. Member discovery gossip (Algorithm 3) +2. Message gossip (Algorithm 4) + +These are well suited for communication in unstructured P2P networks. +All the system needs is a way to distribute messages in a byzantine-prone +environment. + +4.3 Member Discovery Gossip (Algorithm 3): + Each process maintains a partial view Π_j(t) of the network. + Periodically, a process pushes its neighbor list to a random peer + and pulls neighbor lists from other peers. + +4.4 Message Gossip (Algorithm 4): + Processes push unordered messages to random peers and pull missing + messages. Already ordered messages are pushed only as responses + to pull requests (stop criterion for push gossip). + +This module implements both gossip protocols using asyncio for the +"run in parallel forever" loops described in the paper. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import random +import struct +from dataclasses import dataclass, field +from typing import Optional + +from crisis.graph import LamportGraph +from crisis.message import Message, Vertex, NONCE_LENGTH, ID_LENGTH +from crisis.crypto import DIGEST_LENGTH + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Peer identity and network view +# --------------------------------------------------------------------------- + +@dataclass +class PeerInfo: + """Information about a known peer in the network.""" + host: str + port: int + process_id: bytes = b"" # The peer's virtual process id, if known + + @property + def address(self) -> tuple[str, int]: + return (self.host, self.port) + + def __hash__(self): + return hash((self.host, self.port)) + + def __eq__(self, other): + if not isinstance(other, PeerInfo): + return NotImplemented + return self.host == other.host and self.port == other.port + + +@dataclass +class NetworkView: + """Π_j(t): a process's partial view of the network at time t. + + "No process must know the entire system and each j ∈ Π(t) might + have a partial view Π_j(t) only." (Section 4.3) + """ + peers: set[PeerInfo] = field(default_factory=set) + max_peers: int = 50 # Limit to prevent unbounded growth + + def add_peer(self, peer: PeerInfo) -> None: + if len(self.peers) < self.max_peers: + self.peers.add(peer) + + def remove_peer(self, peer: PeerInfo) -> None: + self.peers.discard(peer) + + def random_peer(self) -> Optional[PeerInfo]: + if not self.peers: + return None + return random.choice(list(self.peers)) + + def random_subset(self, k: int) -> list[PeerInfo]: + peers_list = list(self.peers) + return random.sample(peers_list, min(k, len(peers_list))) + + +# --------------------------------------------------------------------------- +# Message serialization for network transport +# --------------------------------------------------------------------------- + +def serialize_message(message: Message) -> bytes: + """Serialize a Message for network transmission. + + Format: [total_length:4][nonce:8][id:32][num_digests:2][digests...][payload] + """ + body = message.serialize() + length = len(body) + return struct.pack("!I", length) + body + + +def deserialize_message(data: bytes) -> Message: + """Deserialize a Message from network bytes. + + Parses the fixed-size fields and reconstructs the Message object. + """ + offset = 0 + nonce = data[offset:offset + NONCE_LENGTH] + offset += NONCE_LENGTH + + id_bytes = data[offset:offset + ID_LENGTH] + offset += ID_LENGTH + + num_digests = int.from_bytes(data[offset:offset + 2], "big") + offset += 2 + + digests = [] + for _ in range(num_digests): + d = data[offset:offset + DIGEST_LENGTH] + digests.append(d) + offset += DIGEST_LENGTH + + payload = data[offset:] + + return Message( + nonce=nonce, + id=id_bytes, + digests=tuple(digests), + payload=payload, + ) + + +# --------------------------------------------------------------------------- +# Protocol message types +# --------------------------------------------------------------------------- + +# Simple protocol: 1-byte type prefix +MSG_TYPE_PUSH_MESSAGE = b"\x01" # Push a crisis message +MSG_TYPE_PULL_REQUEST = b"\x02" # Request missing messages +MSG_TYPE_PULL_RESPONSE = b"\x03" # Response with requested messages +MSG_TYPE_PEER_PUSH = b"\x04" # Push peer list +MSG_TYPE_PEER_PULL = b"\x05" # Request peer list +MSG_TYPE_PEER_RESPONSE = b"\x06" # Response with peer list + + +# --------------------------------------------------------------------------- +# Gossip Server +# --------------------------------------------------------------------------- + +class GossipServer: + """Asyncio-based gossip server implementing Algorithms 3 and 4. + + Runs two parallel loops: + 1. Member discovery push & pull (Algorithm 3) + 2. Message push & pull (Algorithm 4) + + Plus a listener that handles incoming connections. + """ + + def __init__(self, host: str, port: int, graph: LamportGraph, + network_view: NetworkView, + push_interval: float = 2.0, + discovery_interval: float = 5.0): + self.host = host + self.port = port + self.graph = graph + self.network_view = network_view + self.push_interval = push_interval + self.discovery_interval = discovery_interval + self._server: Optional[asyncio.Server] = None + self._running = False + + async def start(self) -> None: + """Start the gossip server and all gossip loops.""" + self._running = True + self._server = await asyncio.start_server( + self._handle_connection, self.host, self.port + ) + logger.info(f"Gossip server listening on {self.host}:{self.port}") + + # Run the gossip loops concurrently (paper: "run in parallel forever") + await asyncio.gather( + self._server.serve_forever(), + self._discovery_push_loop(), + self._message_push_loop(), + ) + + async def stop(self) -> None: + """Stop the gossip server.""" + self._running = False + if self._server: + self._server.close() + await self._server.wait_closed() + + # ------------------------------------------------------------------ + # Algorithm 3: Member discovery push & pull + # ------------------------------------------------------------------ + + async def _discovery_push_loop(self) -> None: + """Algorithm 3, lines 1-5: periodically push peer list to random peers.""" + while self._running: + await asyncio.sleep(self.discovery_interval) + + peer = self.network_view.random_peer() + if peer is None: + continue + + try: + await self._send_peer_push(peer) + await self._send_peer_pull(peer) + except (ConnectionError, OSError) as e: + logger.debug(f"Discovery push to {peer.address} failed: {e}") + self.network_view.remove_peer(peer) + + async def _send_peer_push(self, peer: PeerInfo) -> None: + """Push our peer list to a remote peer.""" + peer_data = self._encode_peer_list(list(self.network_view.peers)) + await self._send_to_peer(peer, MSG_TYPE_PEER_PUSH + peer_data) + + async def _send_peer_pull(self, peer: PeerInfo) -> None: + """Request a peer list from a remote peer.""" + response = await self._send_and_receive(peer, MSG_TYPE_PEER_PULL) + if response and response[0:1] == MSG_TYPE_PEER_RESPONSE: + new_peers = self._decode_peer_list(response[1:]) + for p in new_peers: + if p.host != self.host or p.port != self.port: + self.network_view.add_peer(p) + + # ------------------------------------------------------------------ + # Algorithm 4: Message gossip push & pull + # ------------------------------------------------------------------ + + async def _message_push_loop(self) -> None: + """Algorithm 4, lines 1-5: push unordered messages to random peers. + + "Messages are retransmitted via push gossip, only if they don't + have a total order yet." (Section 4.4) + """ + while self._running: + await asyncio.sleep(self.push_interval) + + peer = self.network_view.random_peer() + if peer is None: + continue + + # Push messages that don't have total_position yet + unordered = [ + v for v in self.graph.all_vertices() + if v.total_position is None + ] + + if not unordered: + continue + + try: + for vertex in unordered: + msg_bytes = serialize_message(vertex.m) + await self._send_to_peer( + peer, MSG_TYPE_PUSH_MESSAGE + msg_bytes + ) + except (ConnectionError, OSError) as e: + logger.debug(f"Message push to {peer.address} failed: {e}") + + # ------------------------------------------------------------------ + # Connection handler (incoming) + # ------------------------------------------------------------------ + + async def _handle_connection(self, reader: asyncio.StreamReader, + writer: asyncio.StreamWriter) -> None: + """Handle an incoming gossip connection. + + Algorithm 3, lines 6-13 (peer data) and Algorithm 4, lines 6-13 + (message data). + """ + try: + data = await asyncio.wait_for(reader.read(65536), timeout=10.0) + if not data: + return + + msg_type = data[0:1] + payload = data[1:] + + if msg_type == MSG_TYPE_PUSH_MESSAGE: + # Received a message: try to extend our Lamport graph + self._handle_push_message(payload) + + elif msg_type == MSG_TYPE_PULL_REQUEST: + # Someone wants messages: send what we have + response = self._handle_pull_request(payload) + writer.write(response) + await writer.drain() + + elif msg_type == MSG_TYPE_PEER_PUSH: + # Received a peer list: update our view + new_peers = self._decode_peer_list(payload) + for p in new_peers: + if p.host != self.host or p.port != self.port: + self.network_view.add_peer(p) + + elif msg_type == MSG_TYPE_PEER_PULL: + # Someone wants our peer list + response = MSG_TYPE_PEER_RESPONSE + self._encode_peer_list( + list(self.network_view.peers) + ) + writer.write(response) + await writer.drain() + + except (asyncio.TimeoutError, ConnectionError): + pass + finally: + writer.close() + + def _handle_push_message(self, data: bytes) -> Optional[Vertex]: + """Process a pushed message: validate and extend graph if valid. + + Algorithm 4, lines 7-8: "if MESSAGE_INTEGRITY(m, G) then + expand G with vertex v, such that v.m = m" + """ + try: + # Parse length prefix + if len(data) < 4: + return None + length = struct.unpack("!I", data[:4])[0] + msg_data = data[4:4 + length] + + message = deserialize_message(msg_data) + return self.graph.extend(message) + except Exception as e: + logger.debug(f"Failed to process pushed message: {e}") + return None + + def _handle_pull_request(self, data: bytes) -> bytes: + """Respond to a pull request with messages the requester is missing. + + Algorithm 4, lines 10-11: "respond with appropriate set of messages" + """ + # Data contains a list of digests the requester already has + known_digests = set() + offset = 0 + while offset + DIGEST_LENGTH <= len(data): + known_digests.add(data[offset:offset + DIGEST_LENGTH]) + offset += DIGEST_LENGTH + + # Send messages the requester doesn't have + response_parts = [MSG_TYPE_PULL_RESPONSE] + for d, vertex in self.graph.vertices.items(): + if d not in known_digests: + response_parts.append(serialize_message(vertex.m)) + return b"".join(response_parts) + + # ------------------------------------------------------------------ + # Network I/O helpers + # ------------------------------------------------------------------ + + async def _send_to_peer(self, peer: PeerInfo, data: bytes) -> None: + """Send data to a peer (fire-and-forget).""" + reader, writer = await asyncio.open_connection(peer.host, peer.port) + writer.write(data) + await writer.drain() + writer.close() + + async def _send_and_receive(self, peer: PeerInfo, data: bytes) -> Optional[bytes]: + """Send data and wait for a response.""" + try: + reader, writer = await asyncio.open_connection(peer.host, peer.port) + writer.write(data) + await writer.drain() + response = await asyncio.wait_for(reader.read(65536), timeout=5.0) + writer.close() + return response + except Exception: + return None + + # ------------------------------------------------------------------ + # Peer list encoding + # ------------------------------------------------------------------ + + @staticmethod + def _encode_peer_list(peers: list[PeerInfo]) -> bytes: + """Encode a list of peers as bytes: [count:2][host_len:1][host][port:2]...""" + parts = [struct.pack("!H", len(peers))] + for peer in peers: + host_bytes = peer.host.encode("utf-8") + parts.append(struct.pack("!B", len(host_bytes))) + parts.append(host_bytes) + parts.append(struct.pack("!H", peer.port)) + return b"".join(parts) + + @staticmethod + def _decode_peer_list(data: bytes) -> list[PeerInfo]: + """Decode a peer list from bytes.""" + if len(data) < 2: + return [] + count = struct.unpack("!H", data[:2])[0] + offset = 2 + peers = [] + for _ in range(count): + if offset >= len(data): + break + host_len = data[offset] + offset += 1 + host = data[offset:offset + host_len].decode("utf-8") + offset += host_len + port = struct.unpack("!H", data[offset:offset + 2])[0] + offset += 2 + peers.append(PeerInfo(host=host, port=port)) + return peers diff --git a/src/crisis/graph.py b/src/crisis/graph.py new file mode 100644 index 0000000..f32682b --- /dev/null +++ b/src/crisis/graph.py @@ -0,0 +1,479 @@ +""" +Lamport Graphs (Section 3.2) + +Lamport graphs represent the causal partial order between messages as a +directed acyclic graph. They are the central data structure of the Crisis +protocol -- all consensus state is derived from the graph structure. + +Definition 3.5 (Lamport Graph): + Let V ⊂ VERTEX be a finite set of vertices, such that all vertices + v_hat with v_hat ≤ v for all v ∈ V are in V, but no two vertices in V + are equivalent. Then the graph G = (V, A) with (v, v_hat) ∈ A if and + only if v -> v_hat is called a *Lamport graph*. + +Key properties: + - Directed and acyclic (Proposition 3.6) + - The past of a vertex is invariant across Lamport graphs (Theorem 3.7) + - No two equivalent vertices exist in the same graph + +This module implements: + - Algorithm 1: Message generation + - Algorithm 2: Message integrity checking and graph extension + - Causality queries (past, future, timelike, spacelike) +""" + +from __future__ import annotations + +import os +from typing import Optional + +from crisis.crypto import digest +from crisis.message import Message, Vertex, Vote, ID_LENGTH, NONCE_LENGTH +from crisis.weight import ProofOfWorkWeight, WeightSystem + + +class LamportGraph: + """A Lamport graph: a DAG of vertices connected by causal acknowledgement. + + The graph is stored as: + - vertices: dict mapping message digest -> Vertex + - edges: dict mapping digest -> set of digests it references + (i.e. v -> v_hat means v acknowledges v_hat) + + Invariants maintained: + - No two vertices have the same underlying message (no equivalence) + - All referenced digests either exist in the graph or are the empty digest + - The graph is acyclic (guaranteed by hash function properties) + """ + + def __init__(self, weight_system: WeightSystem | None = None): + self.weight_system: WeightSystem = weight_system or ProofOfWorkWeight(min_leading_zeros=0) + + # digest -> Vertex + self.vertices: dict[bytes, Vertex] = {} + + # digest -> set of digests this vertex references (outgoing causal edges) + # An edge v -> v_hat means "v acknowledges v_hat" i.e. H(v_hat.m) ∈ v.m.digests + self.edges: dict[bytes, set[bytes]] = {} + + # Reverse edges for efficient "future" queries + # digest -> set of digests that reference this vertex + self.reverse_edges: dict[bytes, set[bytes]] = {} + + # ------------------------------------------------------------------ + # Graph queries + # ------------------------------------------------------------------ + + def __len__(self) -> int: + return len(self.vertices) + + def __contains__(self, digest_or_vertex) -> bool: + if isinstance(digest_or_vertex, Vertex): + return digest_or_vertex.message_digest in self.vertices + return digest_or_vertex in self.vertices + + def get_vertex(self, msg_digest: bytes) -> Optional[Vertex]: + return self.vertices.get(msg_digest) + + def all_vertices(self) -> list[Vertex]: + return list(self.vertices.values()) + + def vertex_count(self) -> int: + return len(self.vertices) + + # ------------------------------------------------------------------ + # Causality (Definition 3.2) + # ------------------------------------------------------------------ + # m -> m_hat (m happens before m_hat) iff: + # - m = m_hat, OR + # - there is a chain m -> m1 -> ... -> mk -> m_hat + # In our DAG: v has an edge to v_hat means v acknowledges v_hat. + # So v is in the *future* of v_hat, and v_hat is in the *past* of v. + + def direct_causes(self, v: Vertex) -> list[Vertex]: + """Return the direct causes of v (vertices that v acknowledges). + + These are the vertices whose digests appear in v.m.digests. + In graph terms: the outgoing neighbors of v. + """ + result = [] + for d in self.edges.get(v.message_digest, set()): + vertex = self.vertices.get(d) + if vertex is not None: + result.append(vertex) + return result + + def direct_effects(self, v: Vertex) -> list[Vertex]: + """Return the direct effects of v (vertices that acknowledge v). + + In graph terms: the incoming neighbors of v (who references v). + """ + result = [] + for d in self.reverse_edges.get(v.message_digest, set()): + vertex = self.vertices.get(d) + if vertex is not None: + result.append(vertex) + return result + + def past(self, v: Vertex) -> set[Vertex]: + """G_v: the subgraph of G containing all causes of v. + + Definition 3.5: "the subgraph G_v of G that contains all causes + of v is called the *past* of v". + + Returns the set of all vertices that are causally before v + (including v itself -- reflexivity). + """ + visited: set[bytes] = set() + stack = [v.message_digest] + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + for neighbor in self.edges.get(current, set()): + if neighbor in self.vertices and neighbor not in visited: + stack.append(neighbor) + + return {self.vertices[d] for d in visited if d in self.vertices} + + def future(self, v: Vertex) -> set[Vertex]: + """All vertices that are causally after v (including v itself).""" + visited: set[bytes] = set() + stack = [v.message_digest] + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + for neighbor in self.reverse_edges.get(current, set()): + if neighbor in self.vertices and neighbor not in visited: + stack.append(neighbor) + + return {self.vertices[d] for d in visited if d in self.vertices} + + def is_cause_of(self, v: Vertex, v_hat: Vertex) -> bool: + """Check if v ≤ v_hat (v is in the past of v_hat). + + Definition 3.4: v is said to happen before v_hat (v ≤ v_hat) + if there is a causality chain from v to v_hat. + """ + if v == v_hat: + return True + return v in self.past(v_hat) + + def are_timelike(self, v: Vertex, v_hat: Vertex) -> bool: + """Check if v and v_hat are timelike (comparable / causally related).""" + return self.is_cause_of(v, v_hat) or self.is_cause_of(v_hat, v) + + def are_spacelike(self, v: Vertex, v_hat: Vertex) -> bool: + """Check if v and v_hat are spacelike (incomparable / no causal relation). + + Spacelike vertices are the ones that need the total order protocol + to become comparable. The protocol extends the timelike partial + order to cover spacelike vertices as well. + """ + return not self.are_timelike(v, v_hat) + + # ------------------------------------------------------------------ + # Mutations (Definition 4.2) + # ------------------------------------------------------------------ + + def find_mutations(self, vertex_id: bytes) -> list[list[Vertex]]: + """Find mutations: vertices with the same id that are spacelike. + + Definition 4.2: Two vertices v and v_hat in G are called a *mutation* + of a virtual process if they have the same id and are spacelike, + i.e. neither v ≤ v_hat nor v_hat ≤ v holds. + + Mutations are the virtual voting equivalent of equivocation -- a + byzantine actor sending different votes to different processes. + + Returns a list of groups of mutually spacelike same-id vertices. + """ + # Group vertices by id + by_id: dict[bytes, list[Vertex]] = {} + for v in self.vertices.values(): + by_id.setdefault(v.id, []).append(v) + + mutations = [] + for vid, group in by_id.items(): + if vid != vertex_id: + continue + # Find spacelike pairs within the group + spacelike_group = [] + for i, v1 in enumerate(group): + for v2 in group[i + 1:]: + if self.are_spacelike(v1, v2): + if v1 not in spacelike_group: + spacelike_group.append(v1) + if v2 not in spacelike_group: + spacelike_group.append(v2) + if spacelike_group: + mutations.append(spacelike_group) + return mutations + + # ------------------------------------------------------------------ + # Byte-level correctness (part of Algorithm 2) + # ------------------------------------------------------------------ + + def _bytelevel_correctness(self, message: Message) -> bool: + """BYTELEVEL_CORRECTNESS: basic structural validation of a message. + + Checks that the message has valid field lengths and is well-formed. + """ + if len(message.nonce) != NONCE_LENGTH: + return False + if len(message.id) != ID_LENGTH: + return False + for d in message.digests: + if len(d) != 32: # DIGEST_LENGTH + return False + return True + + def _payload_correctness(self, message: Message) -> bool: + """PAYLOAD_CORRECTNESS: validate the payload against system rules. + + In this PoC, any payload is accepted. A real system would enforce + application-specific validation here. + """ + return True + + # ------------------------------------------------------------------ + # Algorithm 2: Message integrity (Section 4.2) + # ------------------------------------------------------------------ + + def message_integrity(self, message: Message) -> bool: + """Check whether a message can be validly added to this Lamport graph. + + Algorithm 2 from the paper: + + 1. Check BYTELEVEL_CORRECTNESS(m) + 2. Check w(m) > c_min (weight threshold) + 3. Check PAYLOAD_CORRECTNESS(m.payload) + 4. Check no equivalent vertex exists (no vertex with same digest) + 5. For each digest in m.digests: + - It must reference a vertex in G + - All referenced vertices must have different id's + 6. If there is a vertex v in G with v.id = m.id: + - One of m.digests must reference v (or a vertex in v's past) + Ensures the virtual process forms a chain, not a tree. + + Returns True if the message passes integrity checks. + """ + # Step 1: byte-level structure + if not self._bytelevel_correctness(message): + return False + + # Step 2: weight threshold + if not self.weight_system.is_valid_weight(message): + return False + + # Step 3: payload rules + if not self._payload_correctness(message): + return False + + msg_digest = message.compute_digest() + + # Step 4: no duplicate (no equivalent vertex) + if msg_digest in self.vertices: + return False + + # Step 5: all referenced digests must exist in G + # and all referenced vertices must have different ids + referenced_ids: set[bytes] = set() + for ref_digest in message.digests: + if ref_digest not in self.vertices: + return False + ref_vertex = self.vertices[ref_digest] + if ref_vertex.id in referenced_ids: + return False # Two references to same id + referenced_ids.add(ref_vertex.id) + + # Step 6: if same id exists, must reference it (chain constraint) + # Find the "last vertex" with this id (not referenced by any other + # vertex with the same id) + same_id_vertices = [v for v in self.vertices.values() if v.id == message.id] + if same_id_vertices: + # Check that at least one digest references a same-id vertex + referenced_digests = set(message.digests) + found_chain_link = False + for v in same_id_vertices: + if v.message_digest in referenced_digests: + found_chain_link = True + break + if not found_chain_link: + return False + + return True + + # ------------------------------------------------------------------ + # Lamport graph extension (Section 4.2) + # ------------------------------------------------------------------ + + def extend(self, message: Message) -> Optional[Vertex]: + """Attempt to extend the Lamport graph with a new message. + + If the message passes integrity checks (Algorithm 2), create a new + vertex and add it to the graph with appropriate edges. + + Proposition 4.1 guarantees that the extension of a Lamport graph + by a valid message is itself a Lamport graph. + + Returns the new Vertex if successful, None if integrity check fails. + """ + if not self.message_integrity(message): + return None + + vertex = Vertex(m=message) + msg_digest = message.compute_digest() + + # Add vertex + self.vertices[msg_digest] = vertex + + # Add edges: this vertex -> each referenced vertex + self.edges[msg_digest] = set() + for ref_digest in message.digests: + self.edges[msg_digest].add(ref_digest) + # Reverse edge + if ref_digest not in self.reverse_edges: + self.reverse_edges[ref_digest] = set() + self.reverse_edges[ref_digest].add(msg_digest) + + # Initialize reverse_edges entry for this vertex + if msg_digest not in self.reverse_edges: + self.reverse_edges[msg_digest] = set() + + return vertex + + # ------------------------------------------------------------------ + # Algorithm 1: Message generation (Section 4.1) + # ------------------------------------------------------------------ + + def generate_message(self, process_id: bytes, payload: bytes, + weight_system: WeightSystem | None = None) -> Message: + """Generate a valid message for a given virtual process id. + + Algorithm 1 from the paper: + 1. Find the last vertex v with v.id = id in G + 2. Choose S ⊂ {v.m | v ∈ G ∧ v ∉ G_v} such that all have different ids + 3. Return message with digests = {H(v.m)} ∪ {H(m) | m ∈ S ∪ {v.m}} + + The nonce is chosen so that w(m) > c_min (via mining if PoW). + """ + ws = weight_system or self.weight_system + + # Find the last vertex with this process id + last_vertex = self._find_last_vertex(process_id) + + # Collect digests: last same-id vertex + a sample of other vertices + digests_list: list[bytes] = [] + + if last_vertex is not None: + # Must reference the last vertex with same id + digests_list.append(last_vertex.message_digest) + + # Add cross-references to vertices NOT in last_vertex's past + past_digests = {v.message_digest for v in self.past(last_vertex)} + candidates = [ + v for d, v in self.vertices.items() + if d not in past_digests + and v.id != process_id + and d != last_vertex.message_digest + ] + + # Include candidates with different ids + seen_ids: set[bytes] = {process_id} + for candidate in candidates: + if candidate.id not in seen_ids: + digests_list.append(candidate.message_digest) + seen_ids.add(candidate.id) + else: + # First message for this id: reference a sample of existing vertices + seen_ids = {process_id} + for v in self.vertices.values(): + if v.id not in seen_ids: + digests_list.append(v.message_digest) + seen_ids.add(v.id) + + digests_tuple = tuple(digests_list) + + # Mine a valid nonce (or just find one that meets threshold) + if isinstance(ws, ProofOfWorkWeight): + message = ws.mine_nonce(process_id, digests_tuple, payload) + else: + # For non-PoW systems, use a random nonce + nonce = os.urandom(NONCE_LENGTH) + message = Message(nonce=nonce, id=process_id, digests=digests_tuple, payload=payload) + + return message + + def _find_last_vertex(self, process_id: bytes) -> Optional[Vertex]: + """Find the last vertex with a given process id. + + A vertex is "last" for an id if no other vertex with the same id + references it (i.e. it has no same-id successor). + """ + same_id = [v for v in self.vertices.values() if v.id == process_id] + if not same_id: + return None + + # Find the one that is not referenced by any other same-id vertex + referenced_by_same_id: set[bytes] = set() + for v in same_id: + for d in v.digests: + ref = self.vertices.get(d) + if ref is not None and ref.id == process_id: + referenced_by_same_id.add(d) + + for v in same_id: + if v.message_digest not in referenced_by_same_id: + return v + + # Fallback: return the one added most recently (by convention) + return same_id[-1] + + # ------------------------------------------------------------------ + # Vertices by id (for virtual process queries) + # ------------------------------------------------------------------ + + def vertices_by_id(self, process_id: bytes) -> list[Vertex]: + """Return all vertices belonging to a given virtual process id.""" + return [v for v in self.vertices.values() if v.id == process_id] + + def all_process_ids(self) -> set[bytes]: + """Return all unique virtual process ids in this graph.""" + return {v.id for v in self.vertices.values()} + + def last_vertices_by_id(self) -> dict[bytes, Vertex]: + """Return the last vertex for each virtual process id.""" + result = {} + for pid in self.all_process_ids(): + last = self._find_last_vertex(pid) + if last is not None: + result[pid] = last + return result + + # ------------------------------------------------------------------ + # Weight queries + # ------------------------------------------------------------------ + + def vertex_weight(self, v: Vertex) -> int: + """w(v) = w(v.m): the weight of a vertex is the weight of its message.""" + return self.weight_system.weight(v.m) + + def set_weight(self, vertices: set[Vertex] | list[Vertex]) -> int: + """w(M) := ⊕_{m ∈ M} w(m): the combined weight of a set of vertices.""" + total = 0 + for v in vertices: + total = self.weight_system.weight_sum(total, self.vertex_weight(v)) + return total + + # ------------------------------------------------------------------ + # Display + # ------------------------------------------------------------------ + + def __repr__(self) -> str: + return f"LamportGraph(vertices={len(self.vertices)}, ids={len(self.all_process_ids())})" diff --git a/src/crisis/message.py b/src/crisis/message.py new file mode 100644 index 0000000..e2423f5 --- /dev/null +++ b/src/crisis/message.py @@ -0,0 +1,250 @@ +""" +Data Structures (Section 3) + +3.1 Messages +------------- +Messages distribute payload across the network. The purpose of the protocol +is to establish a total order on those messages that respects causality. + +A message is a byte string of variable length with the following structure +(paper, page 3): + + struct Message { + byte[c1] nonce, + byte[c2] id, + byte[c3] num_digests, + byte[p * num_digests] digests, + byte[] payload + } + +Where c1, c2, c3 are fixed protocol constants and p is the digest length. + +The *nonce* is used by the weight function (e.g. PoW grinding). +The *id* groups messages into virtual processes. +The *digests* field encodes causal acknowledgement of other messages. + +Key insight: a message that acknowledges other messages defines an inherent +natural causality -- this is the Lamport "happens-before" relation (1978). + + m -> m_hat iff H(m_hat) is contained in m.digests (Eq. 2) + +3.1.3 Vertices +--------------- +To establish total order, messages are extended by local voting data that is +NOT transmitted. Votes are deduced from the causal relation between messages. +This is the key characteristic of virtual voting (Moser & Melliar-Smith). + + struct Vertex { + Message m, + Option round, + Option is_last, + Option> svp, # safe voting pattern + Option<(Message, Option)> vote, + Option total_position + } +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +from crisis.crypto import digest, DIGEST_LENGTH + + +# --------------------------------------------------------------------------- +# Protocol constants (c1, c2, c3 from the paper) +# --------------------------------------------------------------------------- +# These define the byte-lengths of the fixed-size fields in a message. +# Chosen for a practical PoC: generous enough for real use, compact enough +# for clarity. + +NONCE_LENGTH = 8 # c1: 8 bytes of nonce (plenty for PoW search space) +ID_LENGTH = 32 # c2: 32 bytes for virtual process id (a hash) +NUM_DIGESTS_LENGTH = 2 # c3: 2 bytes => up to 65535 referenced digests + + +# --------------------------------------------------------------------------- +# Message +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class Message: + """An immutable Crisis message as defined in Section 3.1. + + A message is the atomic unit of communication in the Crisis protocol. + It carries a payload and encodes causal history through its digests field. + + Attributes: + nonce: Used by the weight function (e.g. PoW nonce grinding). + id: Groups this message into a virtual process. + digests: Tuple of digests of causally prior messages (H values). + payload: The actual application data being ordered. + """ + nonce: bytes + id: bytes + digests: tuple[bytes, ...] = () + payload: bytes = b"" + + def __post_init__(self): + if len(self.nonce) != NONCE_LENGTH: + raise ValueError(f"nonce must be {NONCE_LENGTH} bytes, got {len(self.nonce)}") + if len(self.id) != ID_LENGTH: + raise ValueError(f"id must be {ID_LENGTH} bytes, got {len(self.id)}") + for i, d in enumerate(self.digests): + if len(d) != DIGEST_LENGTH: + raise ValueError(f"digest[{i}] must be {DIGEST_LENGTH} bytes") + + def serialize(self) -> bytes: + """Serialize this message to a canonical byte string. + + The serialized form is what gets hashed to produce the message's digest. + Format: nonce | id | num_digests (2 bytes big-endian) | digests... | payload + """ + num = len(self.digests) + parts = [ + self.nonce, + self.id, + num.to_bytes(NUM_DIGESTS_LENGTH, "big"), + ] + for d in self.digests: + parts.append(d) + parts.append(self.payload) + return b"".join(parts) + + def compute_digest(self) -> bytes: + """Compute H(m) -- the digest of this message. + + This is the value other messages include in their digests field + to acknowledge this message (establishing causality, Eq. 2). + """ + return digest(self.serialize()) + + @property + def num_digests(self) -> int: + return len(self.digests) + + def __repr__(self) -> str: + h = self.compute_digest().hex()[:12] + return f"Message(id={self.id.hex()[:8]}..., digests={self.num_digests}, hash={h}...)" + + +# --------------------------------------------------------------------------- +# The empty message (paper: ∅ ∈ MESSAGE) +# --------------------------------------------------------------------------- +# "We postulate a special non-message ∅ ∈ MESSAGE" (Section 3.1) +# Acknowledgement of ∅ is defined as H(empty string). + +EMPTY_MESSAGE_DIGEST = digest(b"") + + +# --------------------------------------------------------------------------- +# Vote +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class Vote: + """A virtual vote as computed locally by each vertex. + + From the paper (Algorithm 7): v.vote(r) = (l, b) describes v's vote + on some message l, together with a possibly undecided binary value + b ∈ {⊥, 0, 1} in a round r. + + Attributes: + message: The message l being voted on (None = ∅, the non-leader). + binary: The binary part of the vote: None=⊥ (undecided), 0, or 1. + """ + message: Optional[Message] = None + binary: Optional[int] = None # None = ⊥, 0, or 1 + + def __repr__(self) -> str: + msg_str = "∅" if self.message is None else self.message.compute_digest().hex()[:8] + bin_str = "⊥" if self.binary is None else str(self.binary) + return f"Vote({msg_str}, {bin_str})" + + +# --------------------------------------------------------------------------- +# Vertex +# --------------------------------------------------------------------------- + +@dataclass +class Vertex: + """A vertex in a Lamport graph (Section 3.1.3). + + A vertex wraps a message and adds locally-computed consensus state. + The additional fields (round, is_last, svp, vote, total_position) are + never transmitted -- they are deduced from the causal structure. + + From the paper (page 5, Eq. 6): + w(v) <- w(v.m) + v.nonce <- v.m.nonce + v.id <- v.m.id + v.num_digests <- v.m.num_digests + v.digests <- v.m.digests + v.payload <- v.m.payload + + Attributes: + m: The underlying message. + round: The virtual round number (Algorithm 5). + is_last: Whether this is a "last vertex" of its round (Alg 5). + svp: Safe voting pattern -- ordered set of round numbers. + vote: Per-round votes: round -> Vote. + total_position: Final position in the total order (Algorithm 9/10). + """ + m: Message + + # Locally computed consensus state (initialized to None / ⊥) + round: Optional[int] = None + is_last: Optional[bool] = None + svp: list[int] = field(default_factory=list) + vote: dict[int, Vote] = field(default_factory=dict) + total_position: Optional[int] = None + + # ------------------------------------------------------------------ + # Convenience accessors that delegate to the underlying message + # ------------------------------------------------------------------ + + @property + def nonce(self) -> bytes: + return self.m.nonce + + @property + def id(self) -> bytes: + return self.m.id + + @property + def digests(self) -> tuple[bytes, ...]: + return self.m.digests + + @property + def payload(self) -> bytes: + return self.m.payload + + @property + def message_digest(self) -> bytes: + """H(v.m) -- the digest that uniquely identifies this vertex's message.""" + return self.m.compute_digest() + + # ------------------------------------------------------------------ + # Equivalence (Definition 3.3) + # ------------------------------------------------------------------ + # "Two vertices v and v_hat are equivalent if v.m = v_hat.m" + # i.e. they wrap the same underlying message. + + def equivalent_to(self, other: Vertex) -> bool: + """Check vertex equivalence: same underlying message.""" + return self.message_digest == other.message_digest + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Vertex): + return NotImplemented + return self.message_digest == other.message_digest + + def __hash__(self) -> int: + return hash(self.message_digest) + + def __repr__(self) -> str: + h = self.message_digest.hex()[:12] + round_str = str(self.round) if self.round is not None else "?" + last_str = "*" if self.is_last else "" + return f"Vertex({h}..., r={round_str}{last_str})" diff --git a/src/crisis/node.py b/src/crisis/node.py new file mode 100644 index 0000000..fcfcd4d --- /dev/null +++ b/src/crisis/node.py @@ -0,0 +1,336 @@ +""" +Crisis Node (Section 5.9 -- The Crisis Protocol) + +This module ties all components together into a full Crisis node. + +From the paper (Section 5.9): + "The overall algorithm works as follows: Member discovery (3) and + message gossip (4) are executed in infinite loops, concurrently to + the rest of the system. Ideally the message sending loop is executed + on as many parallel threads as possible. This implies that an overall + unbounded amount of new messages arrive over time due to our liveness + assumption. In addition each process may generate messages and write + them into its own Lamport graph." + +The full node runs these concurrent loops: + 1. Gossip: member discovery + message dissemination + 2. Message generation: create new messages with PoW + 3. Consensus: compute rounds, voting patterns, leader elections, order + +Each loop runs independently and they communicate through the shared +Lamport graph. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from typing import Optional + +from crisis.crypto import digest +from crisis.graph import LamportGraph +from crisis.gossip import GossipServer, NetworkView, PeerInfo +from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH +from crisis.order import LeaderStream, compute_order +from crisis.rounds import compute_rounds +from crisis.voting import compute_virtual_leader_election, compute_safe_voting_pattern +from crisis.weight import ProofOfWorkWeight, DifficultyOracle + +logger = logging.getLogger(__name__) + + +class CrisisNode: + """A full Crisis protocol node. + + Combines all protocol components into a single running process: + - Lamport graph (the shared DAG) + - Weight system (PoW) + - Difficulty oracle + - Gossip server (member discovery + message dissemination) + - Consensus engine (rounds, voting, ordering) + + Attributes: + process_id: This node's virtual process identity. + graph: The local Lamport graph. + leader_stream: The evolving total order leader stream. + network_view: Known peers in the network. + """ + + def __init__(self, host: str = "127.0.0.1", port: int = 9000, + min_pow_zeros: int = 1, + difficulty_constant: int = 4, + connectivity_k: int = 2, + message_interval: float = 3.0, + consensus_interval: float = 5.0, + seed_peers: list[tuple[str, int]] | None = None): + # Identity: use a hash of host:port as this node's virtual process id + self.process_id = digest(f"{host}:{port}".encode())[:ID_LENGTH] + self.host = host + self.port = port + + # Protocol components + self.weight_system = ProofOfWorkWeight(min_leading_zeros=min_pow_zeros) + self.difficulty = DifficultyOracle(constant_difficulty=difficulty_constant) + self.connectivity_k = connectivity_k + self.graph = LamportGraph(weight_system=self.weight_system) + self.leader_stream = LeaderStream() + + # Timing + self.message_interval = message_interval + self.consensus_interval = consensus_interval + + # Network + self.network_view = NetworkView() + if seed_peers: + for h, p in seed_peers: + self.network_view.add_peer(PeerInfo(host=h, port=p)) + + self.gossip = GossipServer( + host=host, port=port, + graph=self.graph, + network_view=self.network_view, + ) + + # State + self._running = False + self._message_count = 0 + + # Callbacks for monitoring + self.on_new_vertex: Optional[callable] = None + self.on_round_update: Optional[callable] = None + self.on_order_update: Optional[callable] = None + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + async def run(self) -> None: + """Start all protocol loops concurrently. + + This is the Crisis protocol (Section 5.9): three concurrent loops. + """ + self._running = True + logger.info( + f"Crisis node starting on {self.host}:{self.port} " + f"(id={self.process_id.hex()[:16]}...)" + ) + + try: + await asyncio.gather( + self._gossip_loop(), + self._message_generation_loop(), + self._consensus_loop(), + ) + except asyncio.CancelledError: + logger.info("Crisis node shutting down") + finally: + self._running = False + + async def stop(self) -> None: + self._running = False + await self.gossip.stop() + + # ------------------------------------------------------------------ + # Loop 1: Gossip (Algorithms 3 & 4) + # ------------------------------------------------------------------ + + async def _gossip_loop(self) -> None: + """Run the gossip server (member discovery + message dissemination).""" + try: + await self.gossip.start() + except Exception as e: + logger.error(f"Gossip loop error: {e}") + + # ------------------------------------------------------------------ + # Loop 2: Message generation (Algorithm 1) + # ------------------------------------------------------------------ + + async def _message_generation_loop(self) -> None: + """Periodically generate new messages and add them to the graph. + + Each message: + 1. References the last same-id message (chain constraint) + 2. References a sample of other vertices (cross-links for connectivity) + 3. Has a PoW nonce meeting the weight threshold + 4. Carries an application payload + """ + while self._running: + await asyncio.sleep(self.message_interval) + + try: + payload = self._generate_payload() + message = self.graph.generate_message( + self.process_id, payload, self.weight_system + ) + vertex = self.graph.extend(message) + + if vertex is not None: + self._message_count += 1 + logger.debug( + f"Generated message #{self._message_count}: {vertex}" + ) + if self.on_new_vertex: + self.on_new_vertex(vertex) + + except Exception as e: + logger.error(f"Message generation error: {e}") + + def _generate_payload(self) -> bytes: + """Generate a payload for a new message. + + In this PoC, payloads are simple timestamped entries. + A real application would put actual data here. + """ + self._message_count += 1 + return f"msg-{self._message_count}-{time.time():.3f}".encode() + + # ------------------------------------------------------------------ + # Loop 3: Consensus (Algorithms 5, 6, 7, 9, 10) + # ------------------------------------------------------------------ + + async def _consensus_loop(self) -> None: + """Periodically recompute consensus state. + + From Section 5.9 and the proof section (Section 6): + "algorithms (5), (6) and (7) are executed in that order concurrently + on each vertex from V... the total order loop (9) runs concurrently + and waits for updates of the leader stream." + """ + while self._running: + await asyncio.sleep(self.consensus_interval) + + if self.graph.vertex_count() == 0: + continue + + try: + # Step 1: Compute rounds (Algorithm 5) + compute_rounds(self.graph, self.difficulty, self.connectivity_k) + + if self.on_round_update: + self.on_round_update(self.graph) + + # Step 2: Compute safe voting patterns (Algorithm 6) + for vertex in self.graph.all_vertices(): + if vertex.is_last: + compute_safe_voting_pattern( + vertex, self.graph, self.difficulty, + self.connectivity_k + ) + + # Step 3: Virtual leader election (Algorithm 7) + leader_dict: dict[int, list[tuple[int, Message]]] = {} + for vertex in self.graph.all_vertices(): + if vertex.svp: + compute_virtual_leader_election( + vertex, self.graph, self.difficulty, + self.connectivity_k, leader_dict + ) + + # Update leader stream from election results + for round_num, entries in leader_dict.items(): + for deciding_round, leader_msg in entries: + self.leader_stream.update( + round_num, deciding_round, leader_msg + ) + + # Step 4: Compute total order (Algorithms 9 & 10) + ordered = compute_order(self.graph, self.leader_stream) + + if ordered and self.on_order_update: + self.on_order_update(ordered) + + logger.debug( + f"Consensus: {self.graph.vertex_count()} vertices, " + f"{len(self.leader_stream.leaders)} leaders, " + f"{len(ordered)} ordered" + ) + + except Exception as e: + logger.error(f"Consensus loop error: {e}") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def submit_message(self, payload: bytes) -> Optional[Vertex]: + """Submit an application message to be ordered by the protocol.""" + message = self.graph.generate_message( + self.process_id, payload, self.weight_system + ) + return self.graph.extend(message) + + def get_total_order(self) -> list[tuple[int, bytes]]: + """Get the current total order as (position, payload) pairs.""" + ordered = compute_order(self.graph, self.leader_stream) + return [ + (v.total_position, v.payload) + for v in ordered + if v.total_position is not None + ] + + def status(self) -> dict: + """Return a summary of this node's current state.""" + from crisis.rounds import max_round as get_max_round + return { + "process_id": self.process_id.hex()[:16], + "address": f"{self.host}:{self.port}", + "vertices": self.graph.vertex_count(), + "process_ids": len(self.graph.all_process_ids()), + "max_round": get_max_round(self.graph), + "leaders": len(self.leader_stream.leaders), + "peers": len(self.network_view.peers), + "messages_generated": self._message_count, + } + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + +def main(): + """Run a Crisis node from the command line.""" + import argparse + + parser = argparse.ArgumentParser( + description="Crisis Protocol Node", + epilog="Probabilistically self-organizing total order in P2P networks" + ) + parser.add_argument("--host", default="127.0.0.1", help="Listen address") + parser.add_argument("--port", type=int, default=9000, help="Listen port") + parser.add_argument("--pow-zeros", type=int, default=1, + help="Min PoW leading zeros (weight threshold)") + parser.add_argument("--difficulty", type=int, default=4, + help="Difficulty oracle constant") + parser.add_argument("--msg-interval", type=float, default=3.0, + help="Seconds between message generation") + parser.add_argument("--peers", nargs="*", default=[], + help="Seed peers as host:port") + + args = parser.parse_args() + + seed_peers = [] + for peer_str in args.peers: + h, p = peer_str.rsplit(":", 1) + seed_peers.append((h, int(p))) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s" + ) + + node = CrisisNode( + host=args.host, + port=args.port, + min_pow_zeros=args.pow_zeros, + difficulty_constant=args.difficulty, + seed_peers=seed_peers, + message_interval=args.msg_interval, + ) + + asyncio.run(node.run()) + + +if __name__ == "__main__": + main() diff --git a/src/crisis/order.py b/src/crisis/order.py new file mode 100644 index 0000000..3c7e264 --- /dev/null +++ b/src/crisis/order.py @@ -0,0 +1,254 @@ +""" +Total Order (Section 5.8) + +As time goes by and the Lamport graph grows, more and more round leaders +are computed and incorporated into the global leader stream LEADER_G(·). + +Algorithm 9 (Order loop): watches for leader stream updates and recomputes +total order. Total order is achieved by topological sorting on the past +of appropriate vertices. + +Algorithm 10 (Total order using Kahn's algorithm): generates total order +in linear runtime by processing vertices without outgoing causal edges first, +using voting weight to break ties among spacelike vertices. + +The total order converges probabilistically: any two non-byzantine processes +will eventually compute the same total order (Proposition 6.21). + +Definition 5.17 (Leader Stream): + LEADER_G : N -> Option<(uint, MESSAGE)> + is called the *global leader stream* of the Lamport graph. + +Corollary 6.19 (Leader stream convergence): + If the probability for new rounds and safe voting pattern is not zero, + the leader streams of any two honest processes will converge. +""" + +from __future__ import annotations + +from typing import Optional + +from crisis.graph import LamportGraph +from crisis.message import Message, Vertex + + +# --------------------------------------------------------------------------- +# Leader Stream (Definition 5.17) +# --------------------------------------------------------------------------- + +class LeaderStream: + """The global leader stream of a Lamport graph. + + Maps round numbers to (deciding_round, leader_message) pairs. + Uses the Nakamoto longest chain rule: when a new leader is decided + in a later round, it may replace leaders decided in earlier rounds. + + The leader stream converges to contain a single element per round + (Theorem 6.18), and honest processes' leader streams converge to + the same values (Corollary 6.19). + """ + + def __init__(self): + # round_number -> (deciding_round, leader_message) + self.leaders: dict[int, tuple[int, Message]] = {} + + def update(self, round_number: int, deciding_round: int, + leader_message: Message) -> bool: + """Update the leader for a round via the Nakamoto longest chain rule. + + Algorithm 8 (LONG_CHAIN): keep only the leader decided in the + highest round. Delete leaders from previous rounds that have + lower deciding rounds. + + Returns True if the leader stream was modified. + """ + current = self.leaders.get(round_number) + + if current is not None: + existing_deciding_round, _ = current + if existing_deciding_round >= deciding_round: + return False # Already have a leader from a higher round + + self.leaders[round_number] = (deciding_round, leader_message) + + # Prune: remove leaders with lower deciding rounds + # (longest chain rule -- keep only the longest) + max_deciding = max(dr for dr, _ in self.leaders.values()) + to_remove = [] + for r, (dr, _) in self.leaders.items(): + if dr < max_deciding and r < round_number: + to_remove.append(r) + for r in to_remove: + del self.leaders[r] + + return True + + def get_leader(self, round_number: int) -> Optional[Message]: + """Get the current leader message for a round, if any.""" + entry = self.leaders.get(round_number) + return entry[1] if entry else None + + def max_round(self) -> int: + """Highest round with a decided leader.""" + return max(self.leaders.keys()) if self.leaders else -1 + + def all_leaders(self) -> list[tuple[int, Message]]: + """Return all leaders ordered by round number.""" + return [(r, msg) for r, (_, msg) in sorted(self.leaders.items())] + + def __repr__(self) -> str: + rounds = sorted(self.leaders.keys()) + return f"LeaderStream(rounds={rounds})" + + +# --------------------------------------------------------------------------- +# Algorithm 9: Order Loop +# --------------------------------------------------------------------------- + +def compute_order(graph: LamportGraph, leader_stream: LeaderStream) -> list[Vertex]: + """Algorithm 9: compute total order from the leader stream. + + Pseudocode: + 1: loop order update loop + 2: wait for LEADER_G(·) to change + 3: s <- min round of all changed LEADER_G(t) + 4: r <- max round of all LEADER_G(t) ≠ ∅ + 5: v_{l_r} <- leader in highest round, smallest s in G + 6: n <- max(v.total_position | v ∈ Ord_G(v_{l_{r-1}})) + 7: for x ≤ t ≤ r do + 8: randomly choose (p, l_t) ∈ LEADER_G(t) + 9: if l_t ≠ ∅ then + 10: ORDER(Ord_G(v_t), n) ▷ v_t.m = l_t + 11: end if + 12: end for + 13: end loop + + For this PoC, we compute the order in a single pass over the current + leader stream state. + """ + if not leader_stream.leaders: + return [] + + ordered: list[Vertex] = [] + position = 0 + + # Process leaders in round order + for round_number, leader_message in leader_stream.all_leaders(): + # Find the vertex corresponding to this leader message + leader_digest = leader_message.compute_digest() + leader_vertex = graph.get_vertex(leader_digest) + + if leader_vertex is None: + continue + + # Order the past of this leader vertex (excluding already-ordered) + past_vertices = graph.past(leader_vertex) + already_ordered = {v.message_digest for v in ordered} + new_vertices = [ + v for v in past_vertices + if v.message_digest not in already_ordered + ] + + # Sort new vertices using Kahn's algorithm (Algorithm 10) + sorted_new = _kahns_total_order(new_vertices, graph) + + for v in sorted_new: + v.total_position = position + ordered.append(v) + position += 1 + + return ordered + + +# --------------------------------------------------------------------------- +# Algorithm 10: Total Order using Kahn's Algorithm +# --------------------------------------------------------------------------- + +def _kahns_total_order(vertices: list[Vertex], graph: LamportGraph) -> list[Vertex]: + """Algorithm 10: generate total order using Kahn's algorithm. + + Kahn's algorithm in its "arrow reversed" incarnation: we want to order + the past before the future in our Lamport graph. + + Pseudocode from the paper: + 1: procedure ORDER(dag:Ord(v), uint:last) + 2: n <- last + 1 + 3: S <- set of all elements of Ord(v) with no outgoing edges + 4: while S ≠ ∅ do + 5: remove x with highest weight w(x) from S + 6: x.total_position <- n + 7: n <- n + 1 + 8: for each vertex y ∈ Ord(v) with edge e : y -> x do + 9: remove edge e from Ord(v) + 10: if y has no other outgoing edge then + 11: S <- S ∪ {y} + 12: end if + 13: end for + 14: end while + 15: end procedure + + Tie-breaking by voting weight ensures that all honest processes produce + the same total order from equivalent Lamport graphs. + """ + if not vertices: + return [] + + # Build a local subgraph for just these vertices + vertex_set = {v.message_digest for v in vertices} + + # out_degree: for each vertex, count edges to other vertices in this set + out_edges: dict[bytes, set[bytes]] = {} + in_edges: dict[bytes, set[bytes]] = {} + + for v in vertices: + d = v.message_digest + out_edges[d] = set() + in_edges[d] = set() + + for v in vertices: + d = v.message_digest + for cause_d in graph.edges.get(d, set()): + if cause_d in vertex_set: + out_edges[d].add(cause_d) + in_edges[cause_d].add(d) + + # Start with vertices that have no outgoing edges (sinks = earliest causes) + result: list[Vertex] = [] + available = [ + v for v in vertices + if len(out_edges[v.message_digest]) == 0 + ] + + while available: + # Remove the vertex with highest weight (deterministic tie-breaking) + available.sort(key=lambda v: graph.vertex_weight(v), reverse=True) + chosen = available.pop(0) + result.append(chosen) + + # Remove edges pointing to chosen + chosen_d = chosen.message_digest + for referrer_d in list(in_edges.get(chosen_d, set())): + out_edges[referrer_d].discard(chosen_d) + if len(out_edges[referrer_d]) == 0: + referrer_vertex = graph.get_vertex(referrer_d) + if referrer_vertex is not None and referrer_vertex not in result: + available.append(referrer_vertex) + + return result + + +# --------------------------------------------------------------------------- +# Convenience: full pipeline +# --------------------------------------------------------------------------- + +def total_order_positions(graph: LamportGraph, + leader_stream: LeaderStream) -> dict[bytes, int]: + """Return a mapping of message digest -> total order position. + + This is the final output of the Crisis protocol: a total order on + messages that respects causality and is probabilistically invariant + among all honest participants. + """ + ordered = compute_order(graph, leader_stream) + return {v.message_digest: v.total_position for v in ordered + if v.total_position is not None} diff --git a/src/crisis/rounds.py b/src/crisis/rounds.py new file mode 100644 index 0000000..227991a --- /dev/null +++ b/src/crisis/rounds.py @@ -0,0 +1,231 @@ +""" +Virtual Synchronous Rounds (Section 5.3) + +Lamport graphs represent a timelike order between vertices that we interpret +as virtual communication channels. Going one step further, we can think from +inside the Lamport graph to define a virtual clock tick as a transition from +one vertex to another. + +This simple idea allows for internal synchronism that enables us to execute +strongly synchronous agreement protocols like Feldman & Micali's BA* +virtually, but without any compromise in external asynchronism. + +Algorithm 5 (Virtual synchronous rounds): + The algorithm computes *round numbers* and the *is_last* property + of any vertex. + + - The round number is computed by taking the largest round of all + direct causes. + - If the vertex is a direct effect of a current round vertex with + the is_last property, a new round begins. + - If the vertex has enough last vertices of the previous round in its + past and it is k-reachable from all of them, the vertex becomes a + last vertex in its own round. + +Definition 5.1 (k-reachability): + v_hat is said to be k-reachable from v, if the overall weight of all + vertices in all paths from v to v_hat is greater than k. + +Proposition 5.3 (Round invariance): + The round number and is_last property do not depend on the actual + Lamport graph, but are the same for equivalent vertices. +""" + +from __future__ import annotations + +from crisis.graph import LamportGraph +from crisis.message import Vertex +from crisis.weight import DifficultyOracle + + +def compute_rounds(graph: LamportGraph, difficulty: DifficultyOracle, + connectivity_k: int = 2) -> None: + """Execute Algorithm 5 on all vertices in the graph. + + This computes v.round and v.is_last for every vertex v in the graph. + The algorithm processes vertices in causal order (causes before effects) + to ensure dependencies are resolved before they are needed. + + Args: + graph: The Lamport graph to process. + difficulty: The difficulty oracle d : N -> W. + connectivity_k: The connectivity parameter k for k-reachability. + """ + # Process vertices in topological order (causes first) + ordered = _topological_sort(graph) + + for vertex in ordered: + _compute_round_for_vertex(vertex, graph, difficulty, connectivity_k) + + +def _compute_round_for_vertex(vertex: Vertex, graph: LamportGraph, + difficulty: DifficultyOracle, + connectivity_k: int) -> None: + """Algorithm 5: compute round number and is_last for a single vertex. + + Pseudocode from the paper: + 1: procedure ROUND(vertex:v, lamport_graph:G) + 2: N_v <- {v_hat ∈ G | v -> v_hat} # direct causes + 3: r <- max({v_hat.round | v_hat ∈ N_v} ∪ {0}) + 4: if there is a v_hat ∈ N_v with v_hat.is_last and v_hat.round = r then + 5: v.round <- r + 1 + 6: else + 7: v.round <- r + 8: end if + 9: S_r <- {v_hat ∈ G | v_hat.round = v.round - 1, v_hat.is_last, v_hat ≤_k v} + 10: if w(S_r) > 3 * d_r then + 11: v.is_last <- true + 12: else + 13: v.is_last <- (r = 0) + 14: end if + 15: end procedure + """ + # Step 2: direct causes + direct_causes = graph.direct_causes(vertex) + + # Step 3: max round of direct causes (default 0 if no causes) + if direct_causes: + max_round = max( + (dc.round if dc.round is not None else 0) for dc in direct_causes + ) + else: + max_round = 0 + + # Steps 4-8: determine this vertex's round + # If any direct cause is a "last vertex" of the current max round, + # this vertex starts a new round. + has_last_cause_in_max_round = any( + dc.is_last and dc.round == max_round + for dc in direct_causes + if dc.round is not None and dc.is_last is not None + ) + + if has_last_cause_in_max_round: + vertex.round = max_round + 1 + else: + vertex.round = max_round + + # Steps 9-14: determine is_last + r = vertex.round + if r == 0: + # All round-0 vertices are "last" (bootstrapping) + vertex.is_last = True + return + + # Find last vertices of the previous round that are k-reachable from v + d_r = difficulty.difficulty(r) + + previous_round_lasts = [ + v_hat for v_hat in graph.all_vertices() + if v_hat.round == r - 1 + and v_hat.is_last + and _is_k_reachable(v_hat, vertex, graph, connectivity_k) + ] + + # Weight of k-reachable last vertices from previous round + weight_of_previous_lasts = graph.set_weight(previous_round_lasts) + + if weight_of_previous_lasts > 3 * d_r: + vertex.is_last = True + else: + vertex.is_last = False + + +def _is_k_reachable(v_from: Vertex, v_to: Vertex, + graph: LamportGraph, k: int) -> bool: + """Check k-reachability (Definition 5.1). + + v_hat is k-reachable from v if the overall weight of all vertices in + all paths from v to v_hat is greater than k. + + For simplicity in this PoC, we approximate this by checking if v_from + is in the past of v_to and the total weight along the path exceeds k. + + The paper notes (page 11): "counting disjoint paths is computationally + expensive and not really necessary in our setting... all we need is some + insurance that information flows through enough real world processes." + We use total path weight as a simpler proxy. + """ + if v_from not in graph.past(v_to): + return False + + # Compute the weight of all vertices in the path from v_from to v_to + # (all vertices that are in both the future of v_from and the past of v_to) + past_of_to = graph.past(v_to) + future_of_from = graph.future(v_from) + + path_vertices = past_of_to & future_of_from + total_weight = graph.set_weight(path_vertices) + + return total_weight > k + + +def _topological_sort(graph: LamportGraph) -> list[Vertex]: + """Sort vertices in causal order: causes come before their effects. + + Uses Kahn's algorithm. Vertices with no causes (sources) come first. + This ensures that when we process a vertex, all its causes already + have their round numbers computed. + """ + # Compute in-degree (number of causes each vertex has within the graph) + in_degree: dict[bytes, int] = {} + for d, v in graph.vertices.items(): + in_degree[d] = 0 + + for d, v in graph.vertices.items(): + for ref_d in graph.edges.get(d, set()): + if ref_d in graph.vertices: + # ref_d is a cause of d, so d has an additional in-edge + # But we want causal order: causes first + # edges go from effect -> cause, so we need reverse + pass + + # Actually: edges[d] contains the causes of d (d -> cause). + # For topological sort where causes come first, we need: + # in_degree[d] = number of digests in edges[d] that are in the graph + for d in graph.vertices: + count = 0 + for cause_d in graph.edges.get(d, set()): + if cause_d in graph.vertices: + count += 1 + in_degree[d] = count + + # Start with vertices that have no causes (in_degree = 0) + queue = [d for d, deg in in_degree.items() if deg == 0] + result = [] + + while queue: + current = queue.pop(0) + result.append(graph.vertices[current]) + + # For each vertex that current is a cause of (reverse edges) + for effect_d in graph.reverse_edges.get(current, set()): + if effect_d in in_degree: + in_degree[effect_d] -= 1 + if in_degree[effect_d] == 0: + queue.append(effect_d) + + return result + + +# --------------------------------------------------------------------------- +# Queries on computed rounds +# --------------------------------------------------------------------------- + +def last_vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]: + """Return all last vertices in a given round.""" + return [ + v for v in graph.all_vertices() + if v.round == round_number and v.is_last + ] + + +def max_round(graph: LamportGraph) -> int: + """Return the highest round number in the graph.""" + rounds = [v.round for v in graph.all_vertices() if v.round is not None] + return max(rounds) if rounds else 0 + + +def vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]: + """Return all vertices in a given round.""" + return [v for v in graph.all_vertices() if v.round == round_number] diff --git a/src/crisis/voting.py b/src/crisis/voting.py new file mode 100644 index 0000000..02e08cd --- /dev/null +++ b/src/crisis/voting.py @@ -0,0 +1,527 @@ +""" +Virtual Voting, Safe Voting Patterns, and Leader Election (Section 5) + +This module implements the heart of the Crisis protocol: the virtual voting +mechanism that achieves total order without ever sending explicit vote messages. + +Key concepts: + +5.5 Virtual Process Sortition & Knowledge Graphs + - Knowledge graph (Def 5.8): quotient graph projecting vertices to virtual + processes, representing what each process "knows" about others. + - Quorum selector (Def 5.11): deterministically chooses a subset of virtual + processes for each round -- the quorum that participates in agreement. + +5.6 Safe Voting Pattern + - Voting sets (Def 5.12): the set of vertices participating in round s + agreement, reachable with connectivity k from vertex v. + - Algorithm 6: computes the safe voting pattern -- a nested sequence of + rounds where voting took place with appropriately bounded byzantine weight. + +5.7 Local Leader Election + - Algorithm 7: virtual leader elections -- an adaptation of Chen, Feldman + & Micali's BA* to virtual voting on Lamport graphs. + - Three stage types: initial proposal (δ=0), presorting/gradecast (δ∈{1,2}), + and BBA* binary agreement (δ≥3) with "coin fixed to 0/1" and "genuine + coin flip" sub-stages. + +5.8 Longest Chain Rule + - Algorithm 8: maintains the leader stream by keeping only the longest + chain of round leaders (similar to Nakamoto's longest chain rule). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +from crisis.crypto import digest, least_significant_bit +from crisis.graph import LamportGraph +from crisis.message import Message, Vertex, Vote, EMPTY_MESSAGE_DIGEST +from crisis.rounds import last_vertices_in_round, max_round +from crisis.weight import DifficultyOracle + + +# --------------------------------------------------------------------------- +# Knowledge Graph (Definition 5.8) +# --------------------------------------------------------------------------- + +@dataclass +class KnowledgeGraph: + """The round s knowledge graph of vertex v (Definition 5.8). + + Given rounds s < r, a Lamport graph G, and v a last message in round r, + the knowledge graph Π^s_v is the quotient graph G^s_v / ≃_id. + + Each node in the knowledge graph represents a virtual process (identified + by its id). An edge from process id to id' means that some vertex with + v.id = id in round s has a vertex with v_hat.id = id' in its past. + + This represents what each virtual process "knows" about others. + """ + # id -> set of ids that this process has edges to + edges: dict[bytes, set[bytes]] = field(default_factory=dict) + # id -> total weight of vertices in this equivalence class + weights: dict[bytes, int] = field(default_factory=dict) + + +def build_knowledge_graph(vertex: Vertex, round_s: int, + graph: LamportGraph) -> KnowledgeGraph: + """Build the round s knowledge graph for vertex v. + + Collects all round-s vertices in v's past, groups them by id, + and builds the quotient graph. + """ + kg = KnowledgeGraph() + past = graph.past(vertex) + + # Find all round-s vertices in v's past + round_s_vertices = [v for v in past if v.round == round_s] + + # Group by id and compute edges + for v_s in round_s_vertices: + vid = v_s.id + if vid not in kg.edges: + kg.edges[vid] = set() + if vid not in kg.weights: + kg.weights[vid] = 0 + + kg.weights[vid] = graph.weight_system.weight_sum( + kg.weights[vid], graph.vertex_weight(v_s) + ) + + # Add edges based on what this vertex references + for cause in graph.direct_causes(v_s): + if cause.round is not None and cause.round == round_s: + kg.edges[vid].add(cause.id) + + return kg + + +# --------------------------------------------------------------------------- +# Quorum Selector (Definition 5.11) +# --------------------------------------------------------------------------- + +def select_quorum(knowledge_graph: KnowledgeGraph, n: int = 3) -> set[bytes]: + """Select a quorum from a knowledge graph (Definition 5.11). + + Example 3 (Highest voting weight quorum): + Choose the weakly connected component with the highest combined voting + weight, then take the heaviest n virtual processes from it. + + The quorum selector serves as a filter to reduce byzantine noise that + might appear in the voting process. By restricting to a heavily + connected component, faulty behavior based on graph partition is reduced. + """ + if not knowledge_graph.edges: + return set() + + # Find weakly connected components using simple BFS + all_ids = set(knowledge_graph.edges.keys()) + visited: set[bytes] = set() + components: list[set[bytes]] = [] + + for start_id in all_ids: + if start_id in visited: + continue + component: set[bytes] = set() + queue = [start_id] + while queue: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + component.add(current) + # Follow edges in both directions (weakly connected) + for neighbor in knowledge_graph.edges.get(current, set()): + if neighbor not in visited and neighbor in all_ids: + queue.append(neighbor) + # Reverse edges + for other_id, neighbors in knowledge_graph.edges.items(): + if current in neighbors and other_id not in visited: + queue.append(other_id) + components.append(component) + + # Choose the component with highest total weight + def component_weight(comp: set[bytes]) -> int: + return sum(knowledge_graph.weights.get(pid, 0) for pid in comp) + + best_component = max(components, key=component_weight) + + # Take the n heaviest processes from this component + sorted_by_weight = sorted( + best_component, + key=lambda pid: knowledge_graph.weights.get(pid, 0), + reverse=True + ) + + return set(sorted_by_weight[:n]) + + +# --------------------------------------------------------------------------- +# Voting Sets (Definition 5.12) +# --------------------------------------------------------------------------- + +def voting_set(vertex: Vertex, round_s: int, connectivity_k: int, + graph: LamportGraph) -> set[Vertex]: + """Compute S_v(s,k): v's round s voting set (Definition 5.12). + + S_v(s,k) := { x | x.id ∈ Q(v,s) ∧ x ≤_{(r-s)*k} v + ∧ x.round = s ∧ x.is_last = true } + + The voting set consists of all last vertices in round s that: + 1. Belong to a quorum-selected virtual process + 2. Are k-reachable from v (with distance scaled by round gap) + 3. Are in v's past + """ + if vertex.round is None: + return set() + + r = vertex.round + if round_s >= r: + return set() + + # Build knowledge graph and select quorum + kg = build_knowledge_graph(vertex, round_s, graph) + quorum = select_quorum(kg) + + past_of_v = graph.past(vertex) + + result = set() + for v_hat in past_of_v: + if (v_hat.round == round_s + and v_hat.is_last + and v_hat.id in quorum): + result.add(v_hat) + + return result + + +# --------------------------------------------------------------------------- +# Algorithm 6: Safe Voting Pattern (Section 5.6) +# --------------------------------------------------------------------------- + +def compute_safe_voting_pattern(vertex: Vertex, graph: LamportGraph, + difficulty: DifficultyOracle, + connectivity_k: int = 2) -> None: + """Algorithm 6: compute the safe voting pattern for a vertex. + + The safe voting pattern v.svp is a totally ordered set of round numbers + where "safe" voting took place. Safe means: + - The voting set has enough overall weight + - The svp of all members agree + - Byzantine weight is bounded + + Pseudocode from the paper: + 1: procedure SVP(vertex:v, lamport_graph:G) + 2: v.svp <- ∅ + 3: if v.is_last and [safe voting pattern conditions are met] then + 4: s <- maximum of all such k + 5: v.svp <- v.svp ∪ {s} for all t ≤ s + 6: end if + 7: end procedure + + The procedure checks if the current vertex's round qualifies as a new + entry in the safe voting pattern by verifying weight and agreement + conditions from its voting set. + """ + vertex.svp = [] + + if not vertex.is_last or vertex.round is None or vertex.round == 0: + return + + r = vertex.round + + # Check each previous round for safe voting pattern membership + for s in range(r): + d_s = difficulty.difficulty(s) + + # Get voting set for round s + vs = voting_set(vertex, s, connectivity_k, graph) + if not vs: + continue + + total_weight = graph.set_weight(vs) + + # Check if voting weight exceeds threshold (6 * d_s from Eq. 8) + if total_weight <= 6 * d_s: + continue + + # Check that all members of the voting set have compatible svp + svps_agree = True + for x in vs: + for y in vs: + if x.svp != y.svp: + # Allow prefix agreement + min_len = min(len(x.svp), len(y.svp)) + if x.svp[:min_len] != y.svp[:min_len]: + svps_agree = False + break + if not svps_agree: + break + + if svps_agree: + vertex.svp.append(s) + + # svp is a nested sequence: add current round + if vertex.svp: + vertex.svp.append(r) + + +# --------------------------------------------------------------------------- +# Initial Vote Function (Definition 5.16, Example 4) +# --------------------------------------------------------------------------- + +def initial_vote(vertices: set[Vertex], graph: LamportGraph) -> Optional[Message]: + """INITIAL_VOTE: deterministically choose a leader proposal (Def 5.16). + + Example 4 (Highest weight): Choose the underlying message of the highest + voting weight vertex. Since we assume it is infeasible to have different + vertices of equal weight, this is practically deterministic. + + The initial vote function is a system parameter. Different choices lead + to different long-term behavior. Ideally all members of a safe voting + pattern would compute the same initial vote. + """ + if not vertices: + return None + + best_vertex = max(vertices, key=lambda v: graph.vertex_weight(v)) + return best_vertex.m + + +# --------------------------------------------------------------------------- +# Algorithm 7: Virtual Leader Elections (Section 5.7) +# --------------------------------------------------------------------------- + +def compute_virtual_leader_election(vertex: Vertex, graph: LamportGraph, + difficulty: DifficultyOracle, + connectivity_k: int, + leader_stream: dict[int, list[tuple[int, Message]]]) -> None: + """Algorithm 7: compute votes for all rounds in v's safe voting pattern. + + This is the core virtual BA* protocol. For each element t in v.svp, + the vertex computes a vote v.vote(t) = (l, b) based on the stage δ + (the position of that round in the svp). + + Stage types (determined by δ = d_{v.svp}(s, t)): + δ = 0: Initial leader proposal + δ = 1: Leader presorting (gradecast step) + δ = 2: BBA* initialization (gradecast step) + δ ≥ 3: Binary agreement rounds + δ mod 3 = 0: Coin fixed to 0 + δ mod 3 = 1: Coin fixed to 1 + δ mod 3 = 2: Genuine coin flip + + The paper notes: "every step is entirely virtual and no votes are + actually sent to other real world processes." + """ + if not vertex.svp: + return + + s = max(vertex.svp) if vertex.svp else None + if s is None: + return + + for t_idx, t in enumerate(vertex.svp): + delta = t_idx # stage = position in svp + _compute_vote_for_stage(vertex, t, delta, s, graph, difficulty, + connectivity_k, leader_stream) + + +def _compute_vote_for_stage(vertex: Vertex, t: int, delta: int, s: int, + graph: LamportGraph, difficulty: DifficultyOracle, + connectivity_k: int, + leader_stream: dict[int, list[tuple[int, Message]]]) -> None: + """Compute vertex's vote for a specific stage of the virtual leader election. + + Implements the branching logic of Algorithm 7 (pages 19-20 of the paper). + """ + d_s = difficulty.difficulty(s) + vs = voting_set(vertex, t, connectivity_k, graph) + n = graph.set_weight(vs) + + NON_LEADER = None # ∅ in the paper + + if delta == 0: + # Stage 0: Initial leader proposal + l = initial_vote(vs, graph) + vertex.vote[t] = Vote(message=l, binary=None) # (INITIAL_VOTE(S), ⊥) + + elif delta == 1: + # Stage 1: Leader presorting + # Find message with highest round-t voting weight in S + l = _highest_weight_message(vs, graph) + + if l is not None: + # Check if l has super majority weight + l_weight = _vote_weight_for(vs, t, l, None, graph) # votes for (l, ⊥) + if l_weight > n - d_s: + vertex.vote[t] = Vote(message=l, binary=None) # (l, ⊥) + else: + vertex.vote[t] = Vote(message=NON_LEADER, binary=None) # (∅, ⊥) + else: + vertex.vote[t] = Vote(message=NON_LEADER, binary=None) + + elif delta == 2: + # Stage 2: BBA* initialization (gradecast) + l = _highest_weight_message(vs, graph) + + if l is not None: + l_weight_undecided = _vote_weight_for(vs, t, l, None, graph) + if l_weight_undecided > n - d_s: + vertex.vote[t] = Vote(message=l, binary=0) + else: + l_weight_1 = _vote_weight_for(vs, t, l, 1, graph) + if l_weight_1 > d_s: + vertex.vote[t] = Vote(message=l, binary=1) + else: + vertex.vote[t] = Vote(message=NON_LEADER, binary=1) + else: + vertex.vote[t] = Vote(message=NON_LEADER, binary=1) + + else: + # Stage δ ≥ 3: Binary agreement (BBA*) + coin_stage = delta % 3 + l = _highest_weight_message(vs, graph) + + if coin_stage == 0: + # Coin fixed to 0 + _bba_coin_fixed(vertex, t, vs, l, n, d_s, graph, + leader_stream, s, fixed_value=0) + elif coin_stage == 1: + # Coin fixed to 1 + _bba_coin_fixed(vertex, t, vs, l, n, d_s, graph, + leader_stream, s, fixed_value=1) + else: + # Genuine coin flip (coin_stage == 2) + _bba_genuine_coin(vertex, t, vs, l, n, d_s, graph) + + +def _bba_coin_fixed(vertex: Vertex, t: int, vs: set[Vertex], + l: Optional[Message], n: int, d_s: int, + graph: LamportGraph, + leader_stream: dict[int, list[tuple[int, Message]]], + s: int, fixed_value: int) -> None: + """BBA* stage with coin fixed to 0 or 1.""" + other_value = 1 - fixed_value + + if l is not None: + weight_for_fixed = _vote_weight_for_binary(vs, t, fixed_value, graph) + if weight_for_fixed > n - d_s: + vertex.vote[t] = Vote(message=l, binary=fixed_value) + # If weight = n, we have agreement: update leader stream + if weight_for_fixed == n: + _update_leader_stream(leader_stream, l, s) + return + + weight_for_other = _vote_weight_for_binary(vs, t, other_value, graph) + if weight_for_other > n - d_s: + vertex.vote[t] = Vote(message=l, binary=other_value) + return + + vertex.vote[t] = Vote(message=l, binary=fixed_value) + + +def _bba_genuine_coin(vertex: Vertex, t: int, vs: set[Vertex], + l: Optional[Message], n: int, d_s: int, + graph: LamportGraph) -> None: + """BBA* stage with genuine coin flip.""" + if l is not None: + weight_0 = _vote_weight_for_binary(vs, t, 0, graph) + if weight_0 > n - d_s: + vertex.vote[t] = Vote(message=l, binary=0) + return + + weight_1 = _vote_weight_for_binary(vs, t, 1, graph) + if weight_1 > n - d_s: + vertex.vote[t] = Vote(message=l, binary=1) + return + + # Genuine coin flip: use LSB of hash of heaviest vertex's message + if vs: + heaviest = max(vs, key=lambda v: graph.vertex_weight(v)) + h = heaviest.m.compute_digest() + b_coin = least_significant_bit(h) + vertex.vote[t] = Vote(message=l, binary=b_coin) + else: + vertex.vote[t] = Vote(message=l, binary=0) + + +# --------------------------------------------------------------------------- +# Helper functions for vote weight computation +# --------------------------------------------------------------------------- + +def _highest_weight_message(vs: set[Vertex], graph: LamportGraph) -> Optional[Message]: + """Find the message with the highest voting weight in a set.""" + if not vs: + return None + best = max(vs, key=lambda v: graph.vertex_weight(v)) + return best.m + + +def _vote_weight_for(vs: set[Vertex], round_t: int, + target_msg: Optional[Message], target_binary: Optional[int], + graph: LamportGraph) -> int: + """Compute total voting weight for a specific vote (l, b) in a voting set.""" + total = 0 + for v in vs: + vote = v.vote.get(round_t) + if vote is None: + continue + msg_match = (vote.message is None and target_msg is None) or \ + (vote.message is not None and target_msg is not None and + vote.message.compute_digest() == target_msg.compute_digest()) + bin_match = vote.binary == target_binary + if msg_match and bin_match: + total = graph.weight_system.weight_sum(total, graph.vertex_weight(v)) + return total + + +def _vote_weight_for_binary(vs: set[Vertex], round_t: int, + target_binary: int, + graph: LamportGraph) -> int: + """Compute total voting weight for a specific binary value in a voting set.""" + total = 0 + for v in vs: + vote = v.vote.get(round_t) + if vote is not None and vote.binary == target_binary: + total = graph.weight_system.weight_sum(total, graph.vertex_weight(v)) + return total + + +# --------------------------------------------------------------------------- +# Algorithm 8: Longest Chain Rule (Section 5.8) +# --------------------------------------------------------------------------- + +def _update_leader_stream(leader_stream: dict[int, list[tuple[int, Message]]], + message: Message, round_number: int) -> None: + """Algorithm 8: update the leader stream with a new leader candidate. + + The longest chain rule keeps only the chain with the highest deciding + round for each round leader. When a new round leader is decided at a + higher deciding round, previous entries with lower deciding rounds are + replaced. + + Pseudocode: + 1: procedure LONG_CHAIN(set{(uint,MESSAGE)}:S, MESSAGE:m, uint:s) + 2: if there is no (l, t) ∈ S with t > s then + 3: S <- {(l, t) ∈ S | t < s} ∪ (s, m) + 4: end if + 5: return S + 6: end procedure + """ + if round_number not in leader_stream: + leader_stream[round_number] = [] + + entries = leader_stream[round_number] + + # Check if there's already an entry with a higher deciding round + has_higher = any(t > round_number for (t, _) in entries) + if has_higher: + return + + # Remove entries with lower deciding rounds, add new one + leader_stream[round_number] = [ + (t, m) for (t, m) in entries if t < round_number + ] + [(round_number, message)] diff --git a/src/crisis/weight.py b/src/crisis/weight.py new file mode 100644 index 0000000..b81ea41 --- /dev/null +++ b/src/crisis/weight.py @@ -0,0 +1,190 @@ +""" +Weight Systems (Section 3.1.1) + +Definition 3.1 (Weight system): Let MESSAGE be the metric space of all +messages and (W, ≤) a totally ordered set. Then the tuple (W, w, ⊕, c_min) +is a *weight system* if w is a function + + w : MESSAGE -> W (Eq. 3) + +that assigns an element of W to any message, c_min ∈ W is a constant called +the *weight threshold*, and ⊕ is a function + + ⊕ : W × W -> W (Eq. 4) + +called the *weight sum*, such that: + + - Tamper proof: w(m) >= c_min and m_hat ≠ m implies w(m_hat) < c_min + with high probability. + - Uniqueness: m ≠ m_hat implies w(m) ≠ w(m_hat) with high probability. + - Summability: (W, ⊕) is a totally ordered, abelian group. + +The weight w(m) is interpreted as the amount of voting power m holds to +influence total order generation. + +This module provides: + 1. An abstract WeightSystem protocol + 2. A concrete Proof-of-Work implementation (Hashcash-style) + +The PoW weight function counts leading zero bits of H(m), similar to Bitcoin's +difficulty mechanism (Nakamoto, 2009; Beck, 2002). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + +from crisis.crypto import digest, count_leading_zero_bits +from crisis.message import Message + + +# --------------------------------------------------------------------------- +# Abstract weight system +# --------------------------------------------------------------------------- + +class WeightSystem(Protocol): + """Protocol defining the weight system interface (Definition 3.1). + + Any concrete weight system must provide: + - weight(): Compute w(m) for a message + - threshold: The minimum weight c_min + - weight_sum(): Compute ⊕ for two weights + """ + + @property + def threshold(self) -> int: + """c_min: the minimum weight threshold. + + Messages with weight below this are rejected. This prevents Sybil + attacks by ensuring every message requires a minimum investment. + """ + ... + + def weight(self, message: Message) -> int: + """w(m): compute the weight of a message. + + The weight represents the voting power of this message in the + consensus protocol. + """ + ... + + def weight_sum(self, a: int, b: int) -> int: + """⊕: combine two weights. + + Must form a totally ordered abelian group. + For our purposes, ordinary integer addition suffices. + """ + ... + + def is_valid_weight(self, message: Message) -> bool: + """Check whether w(m) >= c_min.""" + ... + + +# --------------------------------------------------------------------------- +# Proof-of-Work weight system +# --------------------------------------------------------------------------- + +@dataclass +class ProofOfWorkWeight: + """A Hashcash-style Proof-of-Work weight system. + + The weight of a message is the number of leading zero bits in H(m). + This is similar to Bitcoin's mining: finding a message whose hash starts + with k zero bits requires approximately 2^k hash evaluations on average. + + The nonce field of the message is used to search for a valid hash, + analogous to Bitcoin's block header nonce. + + Attributes: + min_leading_zeros: c_min -- minimum leading zero bits required. + A value of 1 means every message needs at least + 1 leading zero bit (50% of hashes qualify). + """ + min_leading_zeros: int = 1 + + @property + def threshold(self) -> int: + return self.min_leading_zeros + + def weight(self, message: Message) -> int: + """Count leading zero bits in H(m). + + More leading zeros = more work performed = higher voting weight. + """ + h = message.compute_digest() + return count_leading_zero_bits(h) + + def weight_sum(self, a: int, b: int) -> int: + """Simple integer addition for combining weights. + + This satisfies the abelian group requirement: (Z, +) is a totally + ordered abelian group with identity 0. + """ + return a + b + + def is_valid_weight(self, message: Message) -> bool: + """Check w(m) >= c_min.""" + return self.weight(message) >= self.threshold + + def mine_nonce(self, id_bytes: bytes, digests: tuple[bytes, ...], + payload: bytes, target_weight: int | None = None) -> Message: + """Search for a nonce that produces a message meeting the weight target. + + This is the "nonce grinding" step: try successive nonce values until + H(m) has enough leading zero bits. + + Args: + id_bytes: The virtual process id for this message. + digests: Causal acknowledgements (digests of prior messages). + payload: The application payload. + target_weight: Minimum weight to achieve. Defaults to c_min. + + Returns: + A Message with a valid nonce. + """ + if target_weight is None: + target_weight = self.threshold + + nonce_int = 0 + while True: + nonce = nonce_int.to_bytes(8, "big") + msg = Message(nonce=nonce, id=id_bytes, digests=digests, payload=payload) + if self.weight(msg) >= target_weight: + return msg + nonce_int += 1 + + +# --------------------------------------------------------------------------- +# Difficulty Oracle (Section 5.4, Definition 5.2) +# --------------------------------------------------------------------------- + +@dataclass +class DifficultyOracle: + """Maps round numbers to difficulty values (Definition 5.2). + + The difficulty oracle d : N -> W maps natural numbers (rounds) onto + weights. The value d_r := d(r) is called the *round r difficulty*. + + The difficulty is designed so that the overall voting weight per round + is bounded: + + lim sum(w_s^G / d_s) <= 6 (Eq. 8) + + for all time parameters t, where w_s^G is the overall voting weight of + last vertices in round s. + + Example 1 (paper): A fixed constant that does not change over time. + This is the simplest starting point for a PoC. + """ + constant_difficulty: int = 4 + + def difficulty(self, round_number: int) -> int: + """d(r): return the difficulty for round r. + + For this PoC we use a fixed constant (paper Example 1). + A production system might adapt this based on observed voting + weight, similar to Bitcoin's difficulty adjustment. + """ + return self.constant_difficulty diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_crypto.py b/tests/test_crypto.py new file mode 100644 index 0000000..375d9e8 --- /dev/null +++ b/tests/test_crypto.py @@ -0,0 +1,57 @@ +"""Tests for the crypto module (random oracle model).""" + +from crisis.crypto import ( + digest, digest_hex, verify_digest, + least_significant_bit, count_leading_zero_bits, + DIGEST_LENGTH, +) + + +def test_digest_returns_32_bytes(): + h = digest(b"hello") + assert len(h) == DIGEST_LENGTH == 32 + + +def test_digest_is_deterministic(): + assert digest(b"test") == digest(b"test") + + +def test_digest_different_inputs_different_outputs(): + assert digest(b"a") != digest(b"b") + + +def test_digest_hex_matches(): + h = digest(b"hello") + assert digest_hex(b"hello") == h.hex() + + +def test_verify_digest(): + h = digest(b"data") + assert verify_digest(b"data", h) + assert not verify_digest(b"other", h) + + +def test_least_significant_bit(): + # 0x00 -> LSB = 0, 0x01 -> LSB = 1 + assert least_significant_bit(b"\x00") == 0 + assert least_significant_bit(b"\x01") == 1 + assert least_significant_bit(b"\x02") == 0 + assert least_significant_bit(b"\x03") == 1 + assert least_significant_bit(b"\xff") == 1 + assert least_significant_bit(b"\xfe") == 0 + + +def test_count_leading_zero_bits(): + assert count_leading_zero_bits(b"\xff") == 0 + assert count_leading_zero_bits(b"\x7f") == 1 + assert count_leading_zero_bits(b"\x3f") == 2 + assert count_leading_zero_bits(b"\x00\xff") == 8 + assert count_leading_zero_bits(b"\x00\x00\x01") == 23 + assert count_leading_zero_bits(b"\x00") == 8 + + +def test_empty_digest_is_well_defined(): + """Paper: 'Acknowledgement of the empty string is defined as H(∅)'.""" + h = digest(b"") + assert len(h) == 32 + assert h == digest(b"") # deterministic diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000..cb1676d --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,218 @@ +"""Tests for the Lamport graph with integrity checks.""" + +import os +import pytest +from crisis.crypto import digest +from crisis.graph import LamportGraph +from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH +from crisis.weight import ProofOfWorkWeight + + +def make_id(name: str) -> bytes: + return digest(name.encode())[:ID_LENGTH] + + +def make_nonce(n: int = 0) -> bytes: + return n.to_bytes(NONCE_LENGTH, "big") + + +def make_graph(pow_zeros: int = 0) -> LamportGraph: + return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=pow_zeros)) + + +class TestLamportGraphExtension: + + def test_extend_single_message(self): + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"hello") + v = g.extend(msg) + assert v is not None + assert g.vertex_count() == 1 + + def test_extend_chain(self): + """Messages from the same id must form a chain.""" + g = make_graph() + m1 = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"first") + v1 = g.extend(m1) + assert v1 is not None + + m2 = Message( + nonce=make_nonce(1), id=make_id("alice"), + digests=(m1.compute_digest(),), + payload=b"second" + ) + v2 = g.extend(m2) + assert v2 is not None + assert g.vertex_count() == 2 + + def test_reject_duplicate(self): + """No two equivalent vertices in the same graph.""" + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"x") + g.extend(msg) + v2 = g.extend(msg) + assert v2 is None # Rejected: duplicate + assert g.vertex_count() == 1 + + def test_reject_missing_reference(self): + """Digests must reference existing vertices.""" + g = make_graph() + fake_digest = digest(b"nonexistent") + msg = Message( + nonce=make_nonce(), id=make_id("alice"), + digests=(fake_digest,), payload=b"orphan" + ) + v = g.extend(msg) + assert v is None # Rejected + + def test_reject_broken_chain(self): + """Second message from same id must reference a same-id vertex.""" + g = make_graph() + id_a = make_id("alice") + id_b = make_id("bob") + + m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"first") + g.extend(m1) + + m_bob = Message(nonce=make_nonce(1), id=id_b, payload=b"bob's msg") + g.extend(m_bob) + + # Alice's second message references bob but not herself -> rejected + m2 = Message( + nonce=make_nonce(2), id=id_a, + digests=(m_bob.compute_digest(),), + payload=b"broken chain" + ) + v = g.extend(m2) + assert v is None + + +class TestCausality: + + def _build_chain(self): + """Build a simple 3-message chain: m1 <- m2 <- m3.""" + g = make_graph() + id_a = make_id("alice") + m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"m1") + v1 = g.extend(m1) + + m2 = Message(nonce=make_nonce(1), id=id_a, + digests=(m1.compute_digest(),), payload=b"m2") + v2 = g.extend(m2) + + m3 = Message(nonce=make_nonce(2), id=id_a, + digests=(m2.compute_digest(),), payload=b"m3") + v3 = g.extend(m3) + + return g, v1, v2, v3 + + def test_direct_causes(self): + g, v1, v2, v3 = self._build_chain() + causes_of_v3 = g.direct_causes(v3) + assert v2 in causes_of_v3 + assert v1 not in causes_of_v3 + + def test_direct_effects(self): + g, v1, v2, v3 = self._build_chain() + effects_of_v1 = g.direct_effects(v1) + assert v2 in effects_of_v1 + assert v3 not in effects_of_v1 # v3 is indirect + + def test_past(self): + """G_v: the past of v contains all its causes.""" + g, v1, v2, v3 = self._build_chain() + past_of_v3 = g.past(v3) + assert v1 in past_of_v3 + assert v2 in past_of_v3 + assert v3 in past_of_v3 # reflexive + + def test_future(self): + g, v1, v2, v3 = self._build_chain() + future_of_v1 = g.future(v1) + assert v2 in future_of_v1 + assert v3 in future_of_v1 + assert v1 in future_of_v1 # reflexive + + def test_is_cause_of(self): + g, v1, v2, v3 = self._build_chain() + assert g.is_cause_of(v1, v3) + assert g.is_cause_of(v1, v2) + assert not g.is_cause_of(v3, v1) + + def test_timelike(self): + g, v1, v2, v3 = self._build_chain() + assert g.are_timelike(v1, v3) + assert g.are_timelike(v3, v1) + + def test_spacelike(self): + """Two independent vertices are spacelike.""" + g = make_graph() + m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a") + m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b") + va = g.extend(m_a) + vb = g.extend(m_b) + assert g.are_spacelike(va, vb) + assert not g.are_timelike(va, vb) + + +class TestInvarianceOfThePast: + """Theorem 3.7: The past of equivalent vertices in two Lamport graphs + have the same cardinality.""" + + def test_past_invariance_simple(self): + """Same message in two different graphs has same-size past.""" + g1 = make_graph() + g2 = make_graph() + id_a = make_id("alice") + + m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"genesis") + m2 = Message(nonce=make_nonce(1), id=id_a, + digests=(m1.compute_digest(),), payload=b"second") + + # Add to both graphs + g1.extend(m1) + v1_in_g1 = g1.extend(m2) + + g2.extend(m1) + v1_in_g2 = g2.extend(m2) + + # Past should be the same size + assert len(g1.past(v1_in_g1)) == len(g2.past(v1_in_g2)) + + +class TestMessageGeneration: + + def test_generate_first_message(self): + g = make_graph() + msg = g.generate_message(make_id("alice"), b"hello") + v = g.extend(msg) + assert v is not None + assert v.payload == b"hello" + + def test_generate_chain(self): + g = make_graph() + pid = make_id("alice") + m1 = g.generate_message(pid, b"first") + g.extend(m1) + + m2 = g.generate_message(pid, b"second") + v2 = g.extend(m2) + assert v2 is not None + # Second message should reference the first + assert m1.compute_digest() in m2.digests + + def test_generate_cross_references(self): + """Messages should reference vertices from other process ids.""" + g = make_graph() + pid_a = make_id("alice") + pid_b = make_id("bob") + + m_a = g.generate_message(pid_a, b"alice's msg") + g.extend(m_a) + + m_b = g.generate_message(pid_b, b"bob's msg") + g.extend(m_b) + + # Alice's second message should reference bob's message + m_a2 = g.generate_message(pid_a, b"alice second") + assert m_b.compute_digest() in m_a2.digests or m_a.compute_digest() in m_a2.digests diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..dee516e --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,125 @@ +"""Tests for the message and vertex data structures.""" + +import pytest +from crisis.crypto import digest, DIGEST_LENGTH +from crisis.message import ( + Message, Vertex, Vote, + NONCE_LENGTH, ID_LENGTH, NUM_DIGESTS_LENGTH, + EMPTY_MESSAGE_DIGEST, +) + + +def make_id(name: str) -> bytes: + return digest(name.encode())[:ID_LENGTH] + + +def make_nonce(n: int = 0) -> bytes: + return n.to_bytes(NONCE_LENGTH, "big") + + +class TestMessage: + + def test_create_minimal_message(self): + msg = Message(nonce=make_nonce(), id=make_id("test"), digests=(), payload=b"") + assert msg.num_digests == 0 + + def test_nonce_length_validation(self): + with pytest.raises(ValueError, match="nonce"): + Message(nonce=b"\x00", id=make_id("x")) + + def test_id_length_validation(self): + with pytest.raises(ValueError, match="id"): + Message(nonce=make_nonce(), id=b"\x00") + + def test_digest_length_validation(self): + with pytest.raises(ValueError, match="digest"): + Message(nonce=make_nonce(), id=make_id("x"), + digests=(b"\x00",)) + + def test_serialize_roundtrip_deterministic(self): + msg = Message(nonce=make_nonce(42), id=make_id("proc1"), + digests=(), payload=b"hello world") + serialized = msg.serialize() + assert isinstance(serialized, bytes) + # Same message serializes the same way + assert msg.serialize() == serialized + + def test_compute_digest_deterministic(self): + msg = Message(nonce=make_nonce(), id=make_id("test"), payload=b"data") + d1 = msg.compute_digest() + d2 = msg.compute_digest() + assert d1 == d2 + assert len(d1) == DIGEST_LENGTH + + def test_different_messages_different_digests(self): + m1 = Message(nonce=make_nonce(1), id=make_id("a"), payload=b"x") + m2 = Message(nonce=make_nonce(2), id=make_id("a"), payload=b"x") + assert m1.compute_digest() != m2.compute_digest() + + def test_message_with_digests(self): + parent = Message(nonce=make_nonce(), id=make_id("a"), payload=b"parent") + child = Message( + nonce=make_nonce(1), id=make_id("a"), + digests=(parent.compute_digest(),), + payload=b"child" + ) + assert child.num_digests == 1 + assert child.digests[0] == parent.compute_digest() + + def test_message_is_immutable(self): + msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"y") + with pytest.raises(AttributeError): + msg.nonce = b"\x00" * NONCE_LENGTH + + +class TestVertex: + + def test_vertex_wraps_message(self): + msg = Message(nonce=make_nonce(), id=make_id("proc"), payload=b"data") + v = Vertex(m=msg) + assert v.nonce == msg.nonce + assert v.id == msg.id + assert v.payload == msg.payload + assert v.digests == msg.digests + + def test_vertex_default_state(self): + msg = Message(nonce=make_nonce(), id=make_id("x")) + v = Vertex(m=msg) + assert v.round is None + assert v.is_last is None + assert v.svp == [] + assert v.vote == {} + assert v.total_position is None + + def test_vertex_equivalence(self): + """Definition 3.3: two vertices are equivalent if v.m = v_hat.m""" + msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"same") + v1 = Vertex(m=msg) + v2 = Vertex(m=msg) + assert v1.equivalent_to(v2) + assert v1 == v2 + assert hash(v1) == hash(v2) + + def test_vertex_non_equivalence(self): + m1 = Message(nonce=make_nonce(1), id=make_id("x")) + m2 = Message(nonce=make_nonce(2), id=make_id("x")) + v1 = Vertex(m=m1) + v2 = Vertex(m=m2) + assert not v1.equivalent_to(v2) + assert v1 != v2 + + +class TestVote: + + def test_vote_undecided(self): + v = Vote(message=None, binary=None) + assert "∅" in repr(v) + assert "⊥" in repr(v) + + def test_vote_with_message(self): + msg = Message(nonce=make_nonce(), id=make_id("x")) + v = Vote(message=msg, binary=1) + assert v.binary == 1 + + def test_empty_message_digest(self): + assert EMPTY_MESSAGE_DIGEST == digest(b"") diff --git a/tests/test_order.py b/tests/test_order.py new file mode 100644 index 0000000..239b0e6 --- /dev/null +++ b/tests/test_order.py @@ -0,0 +1,126 @@ +"""Tests for total order computation.""" + +from crisis.crypto import digest +from crisis.graph import LamportGraph +from crisis.message import Message, ID_LENGTH, NONCE_LENGTH +from crisis.order import LeaderStream, compute_order, _kahns_total_order +from crisis.weight import ProofOfWorkWeight + + +def make_id(name: str) -> bytes: + return digest(name.encode())[:ID_LENGTH] + + +def make_nonce(n: int = 0) -> bytes: + return n.to_bytes(NONCE_LENGTH, "big") + + +def make_graph() -> LamportGraph: + return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=0)) + + +class TestLeaderStream: + + def test_empty_stream(self): + ls = LeaderStream() + assert ls.max_round() == -1 + assert ls.get_leader(0) is None + + def test_add_leader(self): + ls = LeaderStream() + msg = Message(nonce=make_nonce(), id=make_id("leader"), payload=b"L") + updated = ls.update(0, 1, msg) + assert updated is True + assert ls.get_leader(0) is msg + + def test_higher_deciding_round_replaces(self): + ls = LeaderStream() + m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"old") + m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"new") + + ls.update(0, 1, m1) + ls.update(0, 2, m2) + + assert ls.get_leader(0) is m2 + + def test_lower_deciding_round_rejected(self): + ls = LeaderStream() + m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"first") + m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"late") + + ls.update(0, 5, m1) + updated = ls.update(0, 3, m2) + + assert updated is False + assert ls.get_leader(0) is m1 + + def test_all_leaders_sorted(self): + ls = LeaderStream() + m0 = Message(nonce=make_nonce(0), id=make_id("l0"), payload=b"r0") + m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"r1") + m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"r2") + + ls.update(2, 3, m2) + ls.update(0, 1, m0) + ls.update(1, 2, m1) + + leaders = ls.all_leaders() + rounds = [r for r, _ in leaders] + assert rounds == sorted(rounds) + + +class TestKahnsAlgorithm: + + def test_empty_input(self): + g = make_graph() + result = _kahns_total_order([], g) + assert result == [] + + def test_single_vertex(self): + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"only") + v = g.extend(msg) + result = _kahns_total_order([v], g) + assert result == [v] + + def test_chain_order(self): + """A chain should be ordered causes-first.""" + g = make_graph() + pid = make_id("alice") + m1 = Message(nonce=make_nonce(0), id=pid, payload=b"first") + v1 = g.extend(m1) + + m2 = Message(nonce=make_nonce(1), id=pid, + digests=(m1.compute_digest(),), payload=b"second") + v2 = g.extend(m2) + + m3 = Message(nonce=make_nonce(2), id=pid, + digests=(m2.compute_digest(),), payload=b"third") + v3 = g.extend(m3) + + result = _kahns_total_order([v1, v2, v3], g) + # Causes come first + assert result.index(v1) < result.index(v2) + assert result.index(v2) < result.index(v3) + + def test_respects_causality(self): + """Total order must be consistent with causal order.""" + g = make_graph() + m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a") + va = g.extend(m_a) + + m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b") + vb = g.extend(m_b) + + # Carol references both alice and bob + m_c = Message( + nonce=make_nonce(1), id=make_id("carol"), + digests=(m_a.compute_digest(), m_b.compute_digest()), + payload=b"c" + ) + vc = g.extend(m_c) + + result = _kahns_total_order([va, vb, vc], g) + # Carol must come after both alice and bob + assert result.index(va) < result.index(vc) + assert result.index(vb) < result.index(vc) diff --git a/tests/test_rounds.py b/tests/test_rounds.py new file mode 100644 index 0000000..96feb6f --- /dev/null +++ b/tests/test_rounds.py @@ -0,0 +1,110 @@ +"""Tests for virtual synchronous rounds.""" + +from crisis.crypto import digest +from crisis.graph import LamportGraph +from crisis.message import Message, ID_LENGTH, NONCE_LENGTH +from crisis.rounds import compute_rounds, max_round, last_vertices_in_round, vertices_in_round +from crisis.weight import ProofOfWorkWeight, DifficultyOracle + + +def make_id(name: str) -> bytes: + return digest(name.encode())[:ID_LENGTH] + + +def make_nonce(n: int = 0) -> bytes: + return n.to_bytes(NONCE_LENGTH, "big") + + +def make_graph() -> LamportGraph: + return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=0)) + + +class TestRoundComputation: + + def test_single_vertex_round_zero(self): + """A single vertex with no causes is in round 0.""" + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis") + v = g.extend(msg) + compute_rounds(g, DifficultyOracle(constant_difficulty=1)) + assert v.round == 0 + + def test_single_vertex_is_last(self): + """Round 0 vertices are always 'last' (bootstrapping).""" + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("alice")) + v = g.extend(msg) + compute_rounds(g, DifficultyOracle(constant_difficulty=1)) + assert v.is_last is True + + def test_chain_grows_rounds(self): + """A chain of messages should produce increasing round numbers.""" + g = make_graph() + pid = make_id("alice") + difficulty = DifficultyOracle(constant_difficulty=0) # Low difficulty + + # Create a chain + prev_msg = None + vertices = [] + for i in range(5): + digests = (prev_msg.compute_digest(),) if prev_msg else () + msg = Message(nonce=make_nonce(i), id=pid, digests=digests, payload=f"msg{i}".encode()) + v = g.extend(msg) + vertices.append(v) + prev_msg = msg + + compute_rounds(g, difficulty, connectivity_k=0) + + # All should have round numbers assigned + for v in vertices: + assert v.round is not None + + # First vertex is round 0 + assert vertices[0].round == 0 + + def test_max_round_empty_graph(self): + g = make_graph() + assert max_round(g) == 0 + + def test_max_round_with_vertices(self): + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("x")) + g.extend(msg) + compute_rounds(g, DifficultyOracle(constant_difficulty=1)) + assert max_round(g) == 0 + + def test_last_vertices_in_round(self): + g = make_graph() + msg = Message(nonce=make_nonce(), id=make_id("alice")) + g.extend(msg) + compute_rounds(g, DifficultyOracle(constant_difficulty=1)) + lasts = last_vertices_in_round(g, 0) + assert len(lasts) == 1 + + def test_multiple_ids_same_round(self): + """Multiple independent vertices are all in round 0.""" + g = make_graph() + for name in ["alice", "bob", "carol"]: + msg = Message(nonce=make_nonce(), id=make_id(name), payload=name.encode()) + g.extend(msg) + + compute_rounds(g, DifficultyOracle(constant_difficulty=1)) + + r0 = vertices_in_round(g, 0) + assert len(r0) == 3 + + def test_round_invariance(self): + """Proposition 5.3: equivalent vertices in different graphs have same round.""" + g1 = make_graph() + g2 = make_graph() + difficulty = DifficultyOracle(constant_difficulty=1) + + msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis") + v1 = g1.extend(msg) + v2 = g2.extend(msg) + + compute_rounds(g1, difficulty) + compute_rounds(g2, difficulty) + + assert v1.round == v2.round + assert v1.is_last == v2.is_last diff --git a/tests/test_simulation.py b/tests/test_simulation.py new file mode 100644 index 0000000..a19705f --- /dev/null +++ b/tests/test_simulation.py @@ -0,0 +1,55 @@ +"""Integration test: run the full simulation and verify basic properties.""" + +from crisis.demo import Simulation + + +class TestSimulation: + + def test_simulation_runs(self): + """The simulation should complete without errors.""" + sim = Simulation(num_honest=3, num_byzantine=0, seed=42) + results = sim.run(num_steps=5, verbose=False) + assert len(results) == 5 + + def test_graphs_grow(self): + """Each step should add messages to the graphs.""" + sim = Simulation(num_honest=2, seed=42) + sim.run(num_steps=3, verbose=False) + for node in sim.nodes: + assert node.graph.vertex_count() > 0 + + def test_honest_nodes_same_graph_size(self): + """All honest nodes should have the same number of vertices + (since all messages are delivered to all nodes).""" + sim = Simulation(num_honest=3, seed=42) + sim.run(num_steps=5, verbose=False) + sizes = [n.graph.vertex_count() for n in sim.nodes] + assert all(s == sizes[0] for s in sizes) + + def test_rounds_are_computed(self): + """After running, vertices should have round numbers.""" + sim = Simulation(num_honest=3, seed=42) + sim.run(num_steps=5, verbose=False) + for node in sim.nodes: + for v in node.graph.all_vertices(): + assert v.round is not None + + def test_with_byzantine_node(self): + """Simulation should handle byzantine nodes without crashing.""" + sim = Simulation(num_honest=3, num_byzantine=1, seed=42) + results = sim.run(num_steps=5, verbose=False) + assert len(results) == 5 + + def test_deterministic_with_seed(self): + """Same seed should produce the same results.""" + sim1 = Simulation(num_honest=3, seed=123) + r1 = sim1.run(num_steps=3, verbose=False) + + sim2 = Simulation(num_honest=3, seed=123) + r2 = sim2.run(num_steps=3, verbose=False) + + # Same number of messages at each step + for s1, s2 in zip(r1, r2): + assert len(s1["new_messages"]) == len(s2["new_messages"]) + for ns1, ns2 in zip(s1["node_states"], s2["node_states"]): + assert ns1["vertices"] == ns2["vertices"] diff --git a/tests/test_weight.py b/tests/test_weight.py new file mode 100644 index 0000000..e318312 --- /dev/null +++ b/tests/test_weight.py @@ -0,0 +1,78 @@ +"""Tests for the weight system and difficulty oracle.""" + +from crisis.crypto import digest +from crisis.message import Message, ID_LENGTH, NONCE_LENGTH +from crisis.weight import ProofOfWorkWeight, DifficultyOracle + + +def make_id(name: str) -> bytes: + return digest(name.encode())[:ID_LENGTH] + + +def make_nonce(n: int = 0) -> bytes: + return n.to_bytes(NONCE_LENGTH, "big") + + +class TestProofOfWorkWeight: + + def test_weight_is_non_negative(self): + ws = ProofOfWorkWeight(min_leading_zeros=0) + msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"test") + assert ws.weight(msg) >= 0 + + def test_weight_sum_is_additive(self): + ws = ProofOfWorkWeight() + assert ws.weight_sum(3, 5) == 8 + assert ws.weight_sum(0, 0) == 0 + + def test_threshold(self): + ws = ProofOfWorkWeight(min_leading_zeros=2) + assert ws.threshold == 2 + + def test_is_valid_weight_with_zero_threshold(self): + ws = ProofOfWorkWeight(min_leading_zeros=0) + msg = Message(nonce=make_nonce(), id=make_id("x")) + assert ws.is_valid_weight(msg) # Everything passes with 0 + + def test_mine_nonce_finds_valid_message(self): + ws = ProofOfWorkWeight(min_leading_zeros=1) + msg = ws.mine_nonce( + id_bytes=make_id("miner"), + digests=(), + payload=b"test payload", + target_weight=1 + ) + assert ws.weight(msg) >= 1 + assert ws.is_valid_weight(msg) + + def test_different_nonces_different_weights(self): + """Uniqueness property: different messages have different weights (w.h.p.).""" + ws = ProofOfWorkWeight() + weights = set() + for i in range(20): + msg = Message(nonce=make_nonce(i), id=make_id("x"), payload=b"same") + weights.add(ws.weight(msg)) + # Not all the same (with overwhelming probability) + assert len(weights) > 1 + + def test_tamper_proof(self): + """Changing a message should change its weight (w.h.p.).""" + ws = ProofOfWorkWeight() + msg1 = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"original") + msg2 = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"tampered") + # Weights differ because digests differ + # (this is probabilistic, but extremely likely) + assert msg1.compute_digest() != msg2.compute_digest() + + +class TestDifficultyOracle: + + def test_constant_difficulty(self): + d = DifficultyOracle(constant_difficulty=5) + assert d.difficulty(0) == 5 + assert d.difficulty(100) == 5 + assert d.difficulty(999) == 5 + + def test_default_difficulty(self): + d = DifficultyOracle() + assert d.difficulty(0) == 4