Initial implementation of the Crisis protocol (Richter, 2019)

Complete Python PoC of "Probabilistically Self Organizing Total Order
in Unstructured P2P Networks". Implements all 10 algorithms from the paper:
message generation, integrity checks, Lamport graphs, virtual synchronous
rounds, safe voting patterns, virtual leader election (BA*), longest chain
rule, total order via Kahn's algorithm, and push/pull gossip.

Includes simulation harness, full node binary, and 72 passing tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
saymrwulf 2026-04-23 13:20:30 +02:00
commit 1df4790fb4
22 changed files with 3987 additions and 0 deletions

26
.gitignore vendored Normal file
View file

@ -0,0 +1,26 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
dist/
build/
*.egg
# Virtual environment
.venv/
# IDE
.idea/
.vscode/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
# Testing
.pytest_cache/
.coverage
htmlcov/

Binary file not shown.

30
pyproject.toml Normal file
View file

@ -0,0 +1,30 @@
[project]
name = "crisis"
version = "0.1.0"
description = "Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks"
readme = "README.md"
requires-python = ">=3.11"
license = "CC-BY-4.0"
authors = [
{ name = "Mirco Richter (paper)", email = "mirco.richter@mailbox.org" },
]
dependencies = [
"networkx>=3.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"rich>=13.0",
]
[project.scripts]
crisis-node = "crisis.node:main"
crisis-demo = "crisis.demo:main"
[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"
[tool.pytest.ini_options]
testpaths = ["tests"]

21
src/crisis/__init__.py Normal file
View file

@ -0,0 +1,21 @@
"""
Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks
A Python implementation of the Crisis protocol described by Mirco Richter (2019).
The protocol achieves total order on messages in fully open, unstructured
Peer-to-Peer networks through virtual voting -- votes are never sent explicitly
but are deduced from the causal relationships between messages encoded in
Lamport graphs.
Key components:
- crypto: Random oracle model (SHA-256 hash function)
- message: Message and Vertex data structures
- weight: Weight systems (PoW-based Sybil resistance)
- graph: Lamport graphs with integrity checking
- rounds: Virtual synchronous rounds
- voting: Safe voting patterns and virtual leader election (BA*)
- order: Total order via leader stream and topological sorting
- gossip: Push/pull gossip for member discovery and message dissemination
- node: Full Crisis node tying all components together
"""

104
src/crisis/crypto.py Normal file
View file

@ -0,0 +1,104 @@
"""
Random Oracle Model (Section 2.1)
We work in the random oracle model, assuming the existence of a cryptographic
hash function that behaves like a random oracle:
H : {0,1}* -> {0,1}^p (Eq. 1)
We use SHA-256 as our concrete instantiation. H is assumed to be collision-,
preimage-, and second-preimage-resistant.
We call H(b) the *digest* of the binary string b.
"""
import hashlib
from typing import Union
# The digest length in bytes (SHA-256 produces 32 bytes = 256 bits).
DIGEST_LENGTH = 32
def digest(data: Union[bytes, bytearray]) -> bytes:
"""Compute the SHA-256 digest of arbitrary binary data.
This is the core random oracle H used throughout the protocol.
Every reference to "the digest of" a message or byte string in the
paper maps to this function.
Returns:
32-byte digest (256 bits).
"""
return hashlib.sha256(data).digest()
def digest_hex(data: Union[bytes, bytearray]) -> str:
"""Convenience: return the digest as a hex string for display."""
return digest(data).hex()
def verify_digest(data: bytes, expected: bytes) -> bool:
"""Check that H(data) equals the expected digest."""
return digest(data) == expected
# ---------------------------------------------------------------------------
# Least significant bit helper (used in the virtual coin flip, Algorithm 7)
# ---------------------------------------------------------------------------
def least_significant_bit(h: bytes) -> int:
"""Return the least significant bit of a hash value.
Used in Algorithm 7 (virtual leader election) for the "genuine coin flip"
stage, where the LSB of H(v_hat.m) determines the binary vote.
The paper defines:
b_coin := lsb(H(x.m)) for max weight x in S
"""
return h[-1] & 1
# ---------------------------------------------------------------------------
# Proof-of-Work helpers (used by the weight system, Section 3.1.1)
# ---------------------------------------------------------------------------
def count_leading_zero_bits(h: bytes) -> int:
"""Count the number of leading zero bits in a hash value.
This is the standard measure of proof-of-work difficulty: a hash with
k leading zero bits required roughly 2^k hash evaluations to find.
"""
count = 0
for byte in h:
if byte == 0:
count += 8
else:
# Count leading zeros in this byte
count += (byte ^ 0xFF).bit_length() - (255 - byte).bit_length()
# Simpler: count leading zeros via bit tricks
for bit_pos in range(7, -1, -1):
if byte & (1 << bit_pos):
return count
count += 1
break
return count
def count_leading_zero_bits(h: bytes) -> int:
"""Count the number of leading zero bits in a hash value.
A hash with k leading zero bits required roughly 2^k evaluations to find.
Used by the PoW weight function to assign weight to messages.
"""
count = 0
for byte in h:
if byte == 0:
count += 8
continue
# Count leading zeros in this non-zero byte
for bit_pos in range(7, -1, -1):
if byte & (1 << bit_pos):
return count
count += 1
break
return count

356
src/crisis/demo.py Normal file
View file

@ -0,0 +1,356 @@
"""
Demonstration / Simulation Harness
This module provides a deterministic, single-process simulation of the Crisis
protocol with N virtual nodes. It is designed as the foundation for a lecture
series: each phase of the protocol can be observed step by step.
The simulation bypasses the network layer entirely -- messages are delivered
directly between in-memory Lamport graphs. This makes the consensus mechanism
visible without network noise.
Usage:
python -m crisis.demo # Run the full demo
python -m crisis.demo --nodes 5 # 5 honest nodes
python -m crisis.demo --byzantine 1 # 1 byzantine node
python -m crisis.demo --rounds 10 # Run for 10 message rounds
"""
from __future__ import annotations
import os
import random
import time
from dataclasses import dataclass, field
from typing import Optional
from crisis.crypto import digest
from crisis.graph import LamportGraph
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
from crisis.order import LeaderStream, compute_order
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round
from crisis.voting import compute_safe_voting_pattern, compute_virtual_leader_election
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
# ---------------------------------------------------------------------------
# Simulated Node
# ---------------------------------------------------------------------------
@dataclass
class SimulatedNode:
"""A simulated Crisis node running in-memory.
Each node has its own Lamport graph and process id. Messages are
exchanged by directly sharing Message objects (no serialization needed).
"""
name: str
process_id: bytes
graph: LamportGraph
leader_stream: LeaderStream = field(default_factory=LeaderStream)
is_byzantine: bool = False
messages_created: int = 0
def generate_message(self, payload: str) -> Message:
"""Generate a new message from this node."""
self.messages_created += 1
return self.graph.generate_message(
self.process_id,
payload.encode(),
)
# ---------------------------------------------------------------------------
# Simulation Engine
# ---------------------------------------------------------------------------
class Simulation:
"""Deterministic simulation of N Crisis nodes.
Runs the protocol in lock-step rounds:
1. Each node generates a message
2. Messages are gossiped (delivered to all nodes)
3. Consensus is computed on each node
4. State is displayed
This allows observing how the Lamport graph grows, rounds emerge,
and total order converges.
"""
def __init__(self, num_honest: int = 3, num_byzantine: int = 0,
pow_zeros: int = 0, difficulty: int = 2,
connectivity_k: int = 1, seed: int = 42):
self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty)
self.connectivity_k = connectivity_k
self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros)
self.seed = seed
random.seed(seed)
# Create nodes
self.nodes: list[SimulatedNode] = []
for i in range(num_honest):
name = f"honest-{i}"
pid = digest(name.encode())[:ID_LENGTH]
graph = LamportGraph(weight_system=self.weight_system)
self.nodes.append(SimulatedNode(
name=name, process_id=pid, graph=graph
))
for i in range(num_byzantine):
name = f"byzantine-{i}"
pid = digest(name.encode())[:ID_LENGTH]
graph = LamportGraph(weight_system=self.weight_system)
self.nodes.append(SimulatedNode(
name=name, process_id=pid, graph=graph, is_byzantine=True
))
self.step_count = 0
self.all_messages: list[Message] = []
def step(self) -> dict:
"""Execute one simulation step.
Returns a dict with step results for display.
"""
self.step_count += 1
step_results = {
"step": self.step_count,
"new_messages": [],
"node_states": [],
}
# Phase 1: Each node generates a message
new_messages: list[tuple[SimulatedNode, Message]] = []
for node in self.nodes:
if node.is_byzantine:
msg = self._byzantine_message(node)
else:
payload = f"step-{self.step_count}-{node.name}"
msg = node.generate_message(payload)
if msg is not None:
new_messages.append((node, msg))
step_results["new_messages"].append({
"from": node.name,
"digest": msg.compute_digest().hex()[:12],
"weight": self.weight_system.weight(msg),
"payload": msg.payload.decode(errors="replace"),
})
# Phase 2: Gossip -- deliver all messages to all nodes
for source_node, msg in new_messages:
self.all_messages.append(msg)
for target_node in self.nodes:
# Deliver to all nodes (including source, for consistency)
target_node.graph.extend(msg)
# Also re-deliver older messages that nodes might be missing
# (simulates pull gossip catching up)
for msg in self.all_messages:
for node in self.nodes:
node.graph.extend(msg) # extend() is idempotent (integrity check)
# Phase 3: Compute consensus on each node
for node in self.nodes:
compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k)
for vertex in node.graph.all_vertices():
if vertex.is_last:
compute_safe_voting_pattern(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k
)
leader_dict: dict[int, list[tuple[int, Message]]] = {}
for vertex in node.graph.all_vertices():
if vertex.svp:
compute_virtual_leader_election(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k, leader_dict
)
for round_num, entries in leader_dict.items():
for deciding_round, leader_msg in entries:
node.leader_stream.update(round_num, deciding_round, leader_msg)
ordered = compute_order(node.graph, node.leader_stream)
mr = max_round(node.graph)
step_results["node_states"].append({
"name": node.name,
"vertices": node.graph.vertex_count(),
"max_round": mr,
"leaders": len(node.leader_stream.leaders),
"ordered": len(ordered),
"is_byzantine": node.is_byzantine,
})
return step_results
def _byzantine_message(self, node: SimulatedNode) -> Optional[Message]:
"""Generate a byzantine message.
Byzantine nodes can exhibit several faulty behaviors:
- Mutations: same id, forking the causal chain
- Strategic distribution: different messages to different peers
- Time travel: referencing old rounds
For this demo, we generate a message with a random payload that
may not reference the latest same-id message (creating a mutation).
"""
payload = f"byz-{self.step_count}-{node.name}-{random.randint(0, 999)}"
# 50% chance of creating a mutation (not referencing last same-id vertex)
if random.random() < 0.5 and node.graph.vertex_count() > 0:
# Pick random digests instead of following the chain
available = list(node.graph.vertices.keys())
num_refs = min(random.randint(1, 3), len(available))
digests = tuple(random.sample(available, num_refs))
nonce = os.urandom(NONCE_LENGTH)
return Message(
nonce=nonce, id=node.process_id,
digests=digests, payload=payload.encode()
)
else:
return node.generate_message(payload)
def run(self, num_steps: int = 10, verbose: bool = True) -> list[dict]:
"""Run the simulation for a number of steps."""
results = []
for _ in range(num_steps):
result = self.step()
results.append(result)
if verbose:
_print_step(result)
if verbose:
_print_convergence_summary(self)
return results
# ---------------------------------------------------------------------------
# Display functions
# ---------------------------------------------------------------------------
def _print_step(result: dict) -> None:
"""Print the results of a simulation step."""
print(f"\n{'='*70}")
print(f" Step {result['step']}")
print(f"{'='*70}")
print(f"\n New messages:")
for msg in result["new_messages"]:
print(f" {msg['from']:>15s} -> {msg['digest']} "
f"w={msg['weight']} {msg['payload'][:40]}")
print(f"\n Node states:")
print(f" {'Name':>15s} {'Vertices':>8s} {'Round':>5s} "
f"{'Leaders':>7s} {'Ordered':>7s}")
print(f" {'-'*15} {'-'*8} {'-'*5} {'-'*7} {'-'*7}")
for ns in result["node_states"]:
byz = " [BYZ]" if ns["is_byzantine"] else ""
print(f" {ns['name']:>15s} {ns['vertices']:>8d} "
f"{ns['max_round']:>5d} {ns['leaders']:>7d} "
f"{ns['ordered']:>7d}{byz}")
def _print_convergence_summary(sim: Simulation) -> None:
"""Print a summary showing whether honest nodes have converged."""
print(f"\n{'='*70}")
print(f" Convergence Summary")
print(f"{'='*70}")
honest_nodes = [n for n in sim.nodes if not n.is_byzantine]
# Check if all honest nodes have the same total order
orders = []
for node in honest_nodes:
ordered = compute_order(node.graph, node.leader_stream)
order_digests = [v.message_digest.hex()[:12] for v in ordered]
orders.append(order_digests)
if len(orders) >= 2:
# Compare pairwise
all_agree = all(o == orders[0] for o in orders[1:])
if all_agree:
print(f"\n All {len(honest_nodes)} honest nodes AGREE "
f"on total order ({len(orders[0])} messages)")
else:
print(f"\n Honest nodes have DIVERGENT total orders "
f"(convergence in progress)")
for i, (node, order) in enumerate(zip(honest_nodes, orders)):
print(f" {node.name}: {len(order)} ordered messages")
# Show the total order from the first honest node
if orders and orders[0]:
print(f"\n Total order (from {honest_nodes[0].name}):")
first_node = honest_nodes[0]
ordered = compute_order(first_node.graph, first_node.leader_stream)
for v in ordered[:20]: # Show first 20
print(f" pos={v.total_position:>3d} "
f"hash={v.message_digest.hex()[:12]} "
f"r={v.round} "
f"payload={v.payload.decode(errors='replace')[:40]}")
if len(ordered) > 20:
print(f" ... and {len(ordered) - 20} more")
# Show leader stream
if honest_nodes:
ls = honest_nodes[0].leader_stream
if ls.leaders:
print(f"\n Leader stream ({honest_nodes[0].name}):")
for round_num, (dec_round, msg) in sorted(ls.leaders.items()):
print(f" round {round_num}: leader={msg.compute_digest().hex()[:12]} "
f"decided_in_round={dec_round}")
print()
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser(
description="Crisis Protocol Simulation",
epilog="Demonstrates probabilistic total order convergence"
)
parser.add_argument("--nodes", type=int, default=3,
help="Number of honest nodes (default: 3)")
parser.add_argument("--byzantine", type=int, default=0,
help="Number of byzantine nodes (default: 0)")
parser.add_argument("--steps", type=int, default=10,
help="Number of simulation steps (default: 10)")
parser.add_argument("--pow-zeros", type=int, default=0,
help="Min PoW leading zeros (default: 0 = no PoW)")
parser.add_argument("--difficulty", type=int, default=2,
help="Difficulty oracle constant (default: 2)")
parser.add_argument("--seed", type=int, default=42,
help="Random seed for reproducibility (default: 42)")
args = parser.parse_args()
print(f"Crisis Protocol Simulation")
print(f" Honest nodes: {args.nodes}")
print(f" Byzantine nodes: {args.byzantine}")
print(f" Steps: {args.steps}")
print(f" PoW zeros: {args.pow_zeros}")
print(f" Difficulty: {args.difficulty}")
print(f" Seed: {args.seed}")
sim = Simulation(
num_honest=args.nodes,
num_byzantine=args.byzantine,
pow_zeros=args.pow_zeros,
difficulty=args.difficulty,
seed=args.seed,
)
sim.run(num_steps=args.steps)
if __name__ == "__main__":
main()

414
src/crisis/gossip.py Normal file
View file

@ -0,0 +1,414 @@
"""
Communication (Section 4)
Crisis is built on top of two simple push & pull gossip protocols:
1. Member discovery gossip (Algorithm 3)
2. Message gossip (Algorithm 4)
These are well suited for communication in unstructured P2P networks.
All the system needs is a way to distribute messages in a byzantine-prone
environment.
4.3 Member Discovery Gossip (Algorithm 3):
Each process maintains a partial view Π_j(t) of the network.
Periodically, a process pushes its neighbor list to a random peer
and pulls neighbor lists from other peers.
4.4 Message Gossip (Algorithm 4):
Processes push unordered messages to random peers and pull missing
messages. Already ordered messages are pushed only as responses
to pull requests (stop criterion for push gossip).
This module implements both gossip protocols using asyncio for the
"run in parallel forever" loops described in the paper.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import random
import struct
from dataclasses import dataclass, field
from typing import Optional
from crisis.graph import LamportGraph
from crisis.message import Message, Vertex, NONCE_LENGTH, ID_LENGTH
from crisis.crypto import DIGEST_LENGTH
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Peer identity and network view
# ---------------------------------------------------------------------------
@dataclass
class PeerInfo:
"""Information about a known peer in the network."""
host: str
port: int
process_id: bytes = b"" # The peer's virtual process id, if known
@property
def address(self) -> tuple[str, int]:
return (self.host, self.port)
def __hash__(self):
return hash((self.host, self.port))
def __eq__(self, other):
if not isinstance(other, PeerInfo):
return NotImplemented
return self.host == other.host and self.port == other.port
@dataclass
class NetworkView:
"""Π_j(t): a process's partial view of the network at time t.
"No process must know the entire system and each j ∈ Π(t) might
have a partial view Π_j(t) only." (Section 4.3)
"""
peers: set[PeerInfo] = field(default_factory=set)
max_peers: int = 50 # Limit to prevent unbounded growth
def add_peer(self, peer: PeerInfo) -> None:
if len(self.peers) < self.max_peers:
self.peers.add(peer)
def remove_peer(self, peer: PeerInfo) -> None:
self.peers.discard(peer)
def random_peer(self) -> Optional[PeerInfo]:
if not self.peers:
return None
return random.choice(list(self.peers))
def random_subset(self, k: int) -> list[PeerInfo]:
peers_list = list(self.peers)
return random.sample(peers_list, min(k, len(peers_list)))
# ---------------------------------------------------------------------------
# Message serialization for network transport
# ---------------------------------------------------------------------------
def serialize_message(message: Message) -> bytes:
"""Serialize a Message for network transmission.
Format: [total_length:4][nonce:8][id:32][num_digests:2][digests...][payload]
"""
body = message.serialize()
length = len(body)
return struct.pack("!I", length) + body
def deserialize_message(data: bytes) -> Message:
"""Deserialize a Message from network bytes.
Parses the fixed-size fields and reconstructs the Message object.
"""
offset = 0
nonce = data[offset:offset + NONCE_LENGTH]
offset += NONCE_LENGTH
id_bytes = data[offset:offset + ID_LENGTH]
offset += ID_LENGTH
num_digests = int.from_bytes(data[offset:offset + 2], "big")
offset += 2
digests = []
for _ in range(num_digests):
d = data[offset:offset + DIGEST_LENGTH]
digests.append(d)
offset += DIGEST_LENGTH
payload = data[offset:]
return Message(
nonce=nonce,
id=id_bytes,
digests=tuple(digests),
payload=payload,
)
# ---------------------------------------------------------------------------
# Protocol message types
# ---------------------------------------------------------------------------
# Simple protocol: 1-byte type prefix
MSG_TYPE_PUSH_MESSAGE = b"\x01" # Push a crisis message
MSG_TYPE_PULL_REQUEST = b"\x02" # Request missing messages
MSG_TYPE_PULL_RESPONSE = b"\x03" # Response with requested messages
MSG_TYPE_PEER_PUSH = b"\x04" # Push peer list
MSG_TYPE_PEER_PULL = b"\x05" # Request peer list
MSG_TYPE_PEER_RESPONSE = b"\x06" # Response with peer list
# ---------------------------------------------------------------------------
# Gossip Server
# ---------------------------------------------------------------------------
class GossipServer:
"""Asyncio-based gossip server implementing Algorithms 3 and 4.
Runs two parallel loops:
1. Member discovery push & pull (Algorithm 3)
2. Message push & pull (Algorithm 4)
Plus a listener that handles incoming connections.
"""
def __init__(self, host: str, port: int, graph: LamportGraph,
network_view: NetworkView,
push_interval: float = 2.0,
discovery_interval: float = 5.0):
self.host = host
self.port = port
self.graph = graph
self.network_view = network_view
self.push_interval = push_interval
self.discovery_interval = discovery_interval
self._server: Optional[asyncio.Server] = None
self._running = False
async def start(self) -> None:
"""Start the gossip server and all gossip loops."""
self._running = True
self._server = await asyncio.start_server(
self._handle_connection, self.host, self.port
)
logger.info(f"Gossip server listening on {self.host}:{self.port}")
# Run the gossip loops concurrently (paper: "run in parallel forever")
await asyncio.gather(
self._server.serve_forever(),
self._discovery_push_loop(),
self._message_push_loop(),
)
async def stop(self) -> None:
"""Stop the gossip server."""
self._running = False
if self._server:
self._server.close()
await self._server.wait_closed()
# ------------------------------------------------------------------
# Algorithm 3: Member discovery push & pull
# ------------------------------------------------------------------
async def _discovery_push_loop(self) -> None:
"""Algorithm 3, lines 1-5: periodically push peer list to random peers."""
while self._running:
await asyncio.sleep(self.discovery_interval)
peer = self.network_view.random_peer()
if peer is None:
continue
try:
await self._send_peer_push(peer)
await self._send_peer_pull(peer)
except (ConnectionError, OSError) as e:
logger.debug(f"Discovery push to {peer.address} failed: {e}")
self.network_view.remove_peer(peer)
async def _send_peer_push(self, peer: PeerInfo) -> None:
"""Push our peer list to a remote peer."""
peer_data = self._encode_peer_list(list(self.network_view.peers))
await self._send_to_peer(peer, MSG_TYPE_PEER_PUSH + peer_data)
async def _send_peer_pull(self, peer: PeerInfo) -> None:
"""Request a peer list from a remote peer."""
response = await self._send_and_receive(peer, MSG_TYPE_PEER_PULL)
if response and response[0:1] == MSG_TYPE_PEER_RESPONSE:
new_peers = self._decode_peer_list(response[1:])
for p in new_peers:
if p.host != self.host or p.port != self.port:
self.network_view.add_peer(p)
# ------------------------------------------------------------------
# Algorithm 4: Message gossip push & pull
# ------------------------------------------------------------------
async def _message_push_loop(self) -> None:
"""Algorithm 4, lines 1-5: push unordered messages to random peers.
"Messages are retransmitted via push gossip, only if they don't
have a total order yet." (Section 4.4)
"""
while self._running:
await asyncio.sleep(self.push_interval)
peer = self.network_view.random_peer()
if peer is None:
continue
# Push messages that don't have total_position yet
unordered = [
v for v in self.graph.all_vertices()
if v.total_position is None
]
if not unordered:
continue
try:
for vertex in unordered:
msg_bytes = serialize_message(vertex.m)
await self._send_to_peer(
peer, MSG_TYPE_PUSH_MESSAGE + msg_bytes
)
except (ConnectionError, OSError) as e:
logger.debug(f"Message push to {peer.address} failed: {e}")
# ------------------------------------------------------------------
# Connection handler (incoming)
# ------------------------------------------------------------------
async def _handle_connection(self, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter) -> None:
"""Handle an incoming gossip connection.
Algorithm 3, lines 6-13 (peer data) and Algorithm 4, lines 6-13
(message data).
"""
try:
data = await asyncio.wait_for(reader.read(65536), timeout=10.0)
if not data:
return
msg_type = data[0:1]
payload = data[1:]
if msg_type == MSG_TYPE_PUSH_MESSAGE:
# Received a message: try to extend our Lamport graph
self._handle_push_message(payload)
elif msg_type == MSG_TYPE_PULL_REQUEST:
# Someone wants messages: send what we have
response = self._handle_pull_request(payload)
writer.write(response)
await writer.drain()
elif msg_type == MSG_TYPE_PEER_PUSH:
# Received a peer list: update our view
new_peers = self._decode_peer_list(payload)
for p in new_peers:
if p.host != self.host or p.port != self.port:
self.network_view.add_peer(p)
elif msg_type == MSG_TYPE_PEER_PULL:
# Someone wants our peer list
response = MSG_TYPE_PEER_RESPONSE + self._encode_peer_list(
list(self.network_view.peers)
)
writer.write(response)
await writer.drain()
except (asyncio.TimeoutError, ConnectionError):
pass
finally:
writer.close()
def _handle_push_message(self, data: bytes) -> Optional[Vertex]:
"""Process a pushed message: validate and extend graph if valid.
Algorithm 4, lines 7-8: "if MESSAGE_INTEGRITY(m, G) then
expand G with vertex v, such that v.m = m"
"""
try:
# Parse length prefix
if len(data) < 4:
return None
length = struct.unpack("!I", data[:4])[0]
msg_data = data[4:4 + length]
message = deserialize_message(msg_data)
return self.graph.extend(message)
except Exception as e:
logger.debug(f"Failed to process pushed message: {e}")
return None
def _handle_pull_request(self, data: bytes) -> bytes:
"""Respond to a pull request with messages the requester is missing.
Algorithm 4, lines 10-11: "respond with appropriate set of messages"
"""
# Data contains a list of digests the requester already has
known_digests = set()
offset = 0
while offset + DIGEST_LENGTH <= len(data):
known_digests.add(data[offset:offset + DIGEST_LENGTH])
offset += DIGEST_LENGTH
# Send messages the requester doesn't have
response_parts = [MSG_TYPE_PULL_RESPONSE]
for d, vertex in self.graph.vertices.items():
if d not in known_digests:
response_parts.append(serialize_message(vertex.m))
return b"".join(response_parts)
# ------------------------------------------------------------------
# Network I/O helpers
# ------------------------------------------------------------------
async def _send_to_peer(self, peer: PeerInfo, data: bytes) -> None:
"""Send data to a peer (fire-and-forget)."""
reader, writer = await asyncio.open_connection(peer.host, peer.port)
writer.write(data)
await writer.drain()
writer.close()
async def _send_and_receive(self, peer: PeerInfo, data: bytes) -> Optional[bytes]:
"""Send data and wait for a response."""
try:
reader, writer = await asyncio.open_connection(peer.host, peer.port)
writer.write(data)
await writer.drain()
response = await asyncio.wait_for(reader.read(65536), timeout=5.0)
writer.close()
return response
except Exception:
return None
# ------------------------------------------------------------------
# Peer list encoding
# ------------------------------------------------------------------
@staticmethod
def _encode_peer_list(peers: list[PeerInfo]) -> bytes:
"""Encode a list of peers as bytes: [count:2][host_len:1][host][port:2]..."""
parts = [struct.pack("!H", len(peers))]
for peer in peers:
host_bytes = peer.host.encode("utf-8")
parts.append(struct.pack("!B", len(host_bytes)))
parts.append(host_bytes)
parts.append(struct.pack("!H", peer.port))
return b"".join(parts)
@staticmethod
def _decode_peer_list(data: bytes) -> list[PeerInfo]:
"""Decode a peer list from bytes."""
if len(data) < 2:
return []
count = struct.unpack("!H", data[:2])[0]
offset = 2
peers = []
for _ in range(count):
if offset >= len(data):
break
host_len = data[offset]
offset += 1
host = data[offset:offset + host_len].decode("utf-8")
offset += host_len
port = struct.unpack("!H", data[offset:offset + 2])[0]
offset += 2
peers.append(PeerInfo(host=host, port=port))
return peers

479
src/crisis/graph.py Normal file
View file

@ -0,0 +1,479 @@
"""
Lamport Graphs (Section 3.2)
Lamport graphs represent the causal partial order between messages as a
directed acyclic graph. They are the central data structure of the Crisis
protocol -- all consensus state is derived from the graph structure.
Definition 3.5 (Lamport Graph):
Let V VERTEX be a finite set of vertices, such that all vertices
v_hat with v_hat v for all v V are in V, but no two vertices in V
are equivalent. Then the graph G = (V, A) with (v, v_hat) A if and
only if v -> v_hat is called a *Lamport graph*.
Key properties:
- Directed and acyclic (Proposition 3.6)
- The past of a vertex is invariant across Lamport graphs (Theorem 3.7)
- No two equivalent vertices exist in the same graph
This module implements:
- Algorithm 1: Message generation
- Algorithm 2: Message integrity checking and graph extension
- Causality queries (past, future, timelike, spacelike)
"""
from __future__ import annotations
import os
from typing import Optional
from crisis.crypto import digest
from crisis.message import Message, Vertex, Vote, ID_LENGTH, NONCE_LENGTH
from crisis.weight import ProofOfWorkWeight, WeightSystem
class LamportGraph:
"""A Lamport graph: a DAG of vertices connected by causal acknowledgement.
The graph is stored as:
- vertices: dict mapping message digest -> Vertex
- edges: dict mapping digest -> set of digests it references
(i.e. v -> v_hat means v acknowledges v_hat)
Invariants maintained:
- No two vertices have the same underlying message (no equivalence)
- All referenced digests either exist in the graph or are the empty digest
- The graph is acyclic (guaranteed by hash function properties)
"""
def __init__(self, weight_system: WeightSystem | None = None):
self.weight_system: WeightSystem = weight_system or ProofOfWorkWeight(min_leading_zeros=0)
# digest -> Vertex
self.vertices: dict[bytes, Vertex] = {}
# digest -> set of digests this vertex references (outgoing causal edges)
# An edge v -> v_hat means "v acknowledges v_hat" i.e. H(v_hat.m) ∈ v.m.digests
self.edges: dict[bytes, set[bytes]] = {}
# Reverse edges for efficient "future" queries
# digest -> set of digests that reference this vertex
self.reverse_edges: dict[bytes, set[bytes]] = {}
# ------------------------------------------------------------------
# Graph queries
# ------------------------------------------------------------------
def __len__(self) -> int:
return len(self.vertices)
def __contains__(self, digest_or_vertex) -> bool:
if isinstance(digest_or_vertex, Vertex):
return digest_or_vertex.message_digest in self.vertices
return digest_or_vertex in self.vertices
def get_vertex(self, msg_digest: bytes) -> Optional[Vertex]:
return self.vertices.get(msg_digest)
def all_vertices(self) -> list[Vertex]:
return list(self.vertices.values())
def vertex_count(self) -> int:
return len(self.vertices)
# ------------------------------------------------------------------
# Causality (Definition 3.2)
# ------------------------------------------------------------------
# m -> m_hat (m happens before m_hat) iff:
# - m = m_hat, OR
# - there is a chain m -> m1 -> ... -> mk -> m_hat
# In our DAG: v has an edge to v_hat means v acknowledges v_hat.
# So v is in the *future* of v_hat, and v_hat is in the *past* of v.
def direct_causes(self, v: Vertex) -> list[Vertex]:
"""Return the direct causes of v (vertices that v acknowledges).
These are the vertices whose digests appear in v.m.digests.
In graph terms: the outgoing neighbors of v.
"""
result = []
for d in self.edges.get(v.message_digest, set()):
vertex = self.vertices.get(d)
if vertex is not None:
result.append(vertex)
return result
def direct_effects(self, v: Vertex) -> list[Vertex]:
"""Return the direct effects of v (vertices that acknowledge v).
In graph terms: the incoming neighbors of v (who references v).
"""
result = []
for d in self.reverse_edges.get(v.message_digest, set()):
vertex = self.vertices.get(d)
if vertex is not None:
result.append(vertex)
return result
def past(self, v: Vertex) -> set[Vertex]:
"""G_v: the subgraph of G containing all causes of v.
Definition 3.5: "the subgraph G_v of G that contains all causes
of v is called the *past* of v".
Returns the set of all vertices that are causally before v
(including v itself -- reflexivity).
"""
visited: set[bytes] = set()
stack = [v.message_digest]
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
for neighbor in self.edges.get(current, set()):
if neighbor in self.vertices and neighbor not in visited:
stack.append(neighbor)
return {self.vertices[d] for d in visited if d in self.vertices}
def future(self, v: Vertex) -> set[Vertex]:
"""All vertices that are causally after v (including v itself)."""
visited: set[bytes] = set()
stack = [v.message_digest]
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
for neighbor in self.reverse_edges.get(current, set()):
if neighbor in self.vertices and neighbor not in visited:
stack.append(neighbor)
return {self.vertices[d] for d in visited if d in self.vertices}
def is_cause_of(self, v: Vertex, v_hat: Vertex) -> bool:
"""Check if v ≤ v_hat (v is in the past of v_hat).
Definition 3.4: v is said to happen before v_hat (v v_hat)
if there is a causality chain from v to v_hat.
"""
if v == v_hat:
return True
return v in self.past(v_hat)
def are_timelike(self, v: Vertex, v_hat: Vertex) -> bool:
"""Check if v and v_hat are timelike (comparable / causally related)."""
return self.is_cause_of(v, v_hat) or self.is_cause_of(v_hat, v)
def are_spacelike(self, v: Vertex, v_hat: Vertex) -> bool:
"""Check if v and v_hat are spacelike (incomparable / no causal relation).
Spacelike vertices are the ones that need the total order protocol
to become comparable. The protocol extends the timelike partial
order to cover spacelike vertices as well.
"""
return not self.are_timelike(v, v_hat)
# ------------------------------------------------------------------
# Mutations (Definition 4.2)
# ------------------------------------------------------------------
def find_mutations(self, vertex_id: bytes) -> list[list[Vertex]]:
"""Find mutations: vertices with the same id that are spacelike.
Definition 4.2: Two vertices v and v_hat in G are called a *mutation*
of a virtual process if they have the same id and are spacelike,
i.e. neither v v_hat nor v_hat v holds.
Mutations are the virtual voting equivalent of equivocation -- a
byzantine actor sending different votes to different processes.
Returns a list of groups of mutually spacelike same-id vertices.
"""
# Group vertices by id
by_id: dict[bytes, list[Vertex]] = {}
for v in self.vertices.values():
by_id.setdefault(v.id, []).append(v)
mutations = []
for vid, group in by_id.items():
if vid != vertex_id:
continue
# Find spacelike pairs within the group
spacelike_group = []
for i, v1 in enumerate(group):
for v2 in group[i + 1:]:
if self.are_spacelike(v1, v2):
if v1 not in spacelike_group:
spacelike_group.append(v1)
if v2 not in spacelike_group:
spacelike_group.append(v2)
if spacelike_group:
mutations.append(spacelike_group)
return mutations
# ------------------------------------------------------------------
# Byte-level correctness (part of Algorithm 2)
# ------------------------------------------------------------------
def _bytelevel_correctness(self, message: Message) -> bool:
"""BYTELEVEL_CORRECTNESS: basic structural validation of a message.
Checks that the message has valid field lengths and is well-formed.
"""
if len(message.nonce) != NONCE_LENGTH:
return False
if len(message.id) != ID_LENGTH:
return False
for d in message.digests:
if len(d) != 32: # DIGEST_LENGTH
return False
return True
def _payload_correctness(self, message: Message) -> bool:
"""PAYLOAD_CORRECTNESS: validate the payload against system rules.
In this PoC, any payload is accepted. A real system would enforce
application-specific validation here.
"""
return True
# ------------------------------------------------------------------
# Algorithm 2: Message integrity (Section 4.2)
# ------------------------------------------------------------------
def message_integrity(self, message: Message) -> bool:
"""Check whether a message can be validly added to this Lamport graph.
Algorithm 2 from the paper:
1. Check BYTELEVEL_CORRECTNESS(m)
2. Check w(m) > c_min (weight threshold)
3. Check PAYLOAD_CORRECTNESS(m.payload)
4. Check no equivalent vertex exists (no vertex with same digest)
5. For each digest in m.digests:
- It must reference a vertex in G
- All referenced vertices must have different id's
6. If there is a vertex v in G with v.id = m.id:
- One of m.digests must reference v (or a vertex in v's past)
Ensures the virtual process forms a chain, not a tree.
Returns True if the message passes integrity checks.
"""
# Step 1: byte-level structure
if not self._bytelevel_correctness(message):
return False
# Step 2: weight threshold
if not self.weight_system.is_valid_weight(message):
return False
# Step 3: payload rules
if not self._payload_correctness(message):
return False
msg_digest = message.compute_digest()
# Step 4: no duplicate (no equivalent vertex)
if msg_digest in self.vertices:
return False
# Step 5: all referenced digests must exist in G
# and all referenced vertices must have different ids
referenced_ids: set[bytes] = set()
for ref_digest in message.digests:
if ref_digest not in self.vertices:
return False
ref_vertex = self.vertices[ref_digest]
if ref_vertex.id in referenced_ids:
return False # Two references to same id
referenced_ids.add(ref_vertex.id)
# Step 6: if same id exists, must reference it (chain constraint)
# Find the "last vertex" with this id (not referenced by any other
# vertex with the same id)
same_id_vertices = [v for v in self.vertices.values() if v.id == message.id]
if same_id_vertices:
# Check that at least one digest references a same-id vertex
referenced_digests = set(message.digests)
found_chain_link = False
for v in same_id_vertices:
if v.message_digest in referenced_digests:
found_chain_link = True
break
if not found_chain_link:
return False
return True
# ------------------------------------------------------------------
# Lamport graph extension (Section 4.2)
# ------------------------------------------------------------------
def extend(self, message: Message) -> Optional[Vertex]:
"""Attempt to extend the Lamport graph with a new message.
If the message passes integrity checks (Algorithm 2), create a new
vertex and add it to the graph with appropriate edges.
Proposition 4.1 guarantees that the extension of a Lamport graph
by a valid message is itself a Lamport graph.
Returns the new Vertex if successful, None if integrity check fails.
"""
if not self.message_integrity(message):
return None
vertex = Vertex(m=message)
msg_digest = message.compute_digest()
# Add vertex
self.vertices[msg_digest] = vertex
# Add edges: this vertex -> each referenced vertex
self.edges[msg_digest] = set()
for ref_digest in message.digests:
self.edges[msg_digest].add(ref_digest)
# Reverse edge
if ref_digest not in self.reverse_edges:
self.reverse_edges[ref_digest] = set()
self.reverse_edges[ref_digest].add(msg_digest)
# Initialize reverse_edges entry for this vertex
if msg_digest not in self.reverse_edges:
self.reverse_edges[msg_digest] = set()
return vertex
# ------------------------------------------------------------------
# Algorithm 1: Message generation (Section 4.1)
# ------------------------------------------------------------------
def generate_message(self, process_id: bytes, payload: bytes,
weight_system: WeightSystem | None = None) -> Message:
"""Generate a valid message for a given virtual process id.
Algorithm 1 from the paper:
1. Find the last vertex v with v.id = id in G
2. Choose S {v.m | v G v G_v} such that all have different ids
3. Return message with digests = {H(v.m)} {H(m) | m S {v.m}}
The nonce is chosen so that w(m) > c_min (via mining if PoW).
"""
ws = weight_system or self.weight_system
# Find the last vertex with this process id
last_vertex = self._find_last_vertex(process_id)
# Collect digests: last same-id vertex + a sample of other vertices
digests_list: list[bytes] = []
if last_vertex is not None:
# Must reference the last vertex with same id
digests_list.append(last_vertex.message_digest)
# Add cross-references to vertices NOT in last_vertex's past
past_digests = {v.message_digest for v in self.past(last_vertex)}
candidates = [
v for d, v in self.vertices.items()
if d not in past_digests
and v.id != process_id
and d != last_vertex.message_digest
]
# Include candidates with different ids
seen_ids: set[bytes] = {process_id}
for candidate in candidates:
if candidate.id not in seen_ids:
digests_list.append(candidate.message_digest)
seen_ids.add(candidate.id)
else:
# First message for this id: reference a sample of existing vertices
seen_ids = {process_id}
for v in self.vertices.values():
if v.id not in seen_ids:
digests_list.append(v.message_digest)
seen_ids.add(v.id)
digests_tuple = tuple(digests_list)
# Mine a valid nonce (or just find one that meets threshold)
if isinstance(ws, ProofOfWorkWeight):
message = ws.mine_nonce(process_id, digests_tuple, payload)
else:
# For non-PoW systems, use a random nonce
nonce = os.urandom(NONCE_LENGTH)
message = Message(nonce=nonce, id=process_id, digests=digests_tuple, payload=payload)
return message
def _find_last_vertex(self, process_id: bytes) -> Optional[Vertex]:
"""Find the last vertex with a given process id.
A vertex is "last" for an id if no other vertex with the same id
references it (i.e. it has no same-id successor).
"""
same_id = [v for v in self.vertices.values() if v.id == process_id]
if not same_id:
return None
# Find the one that is not referenced by any other same-id vertex
referenced_by_same_id: set[bytes] = set()
for v in same_id:
for d in v.digests:
ref = self.vertices.get(d)
if ref is not None and ref.id == process_id:
referenced_by_same_id.add(d)
for v in same_id:
if v.message_digest not in referenced_by_same_id:
return v
# Fallback: return the one added most recently (by convention)
return same_id[-1]
# ------------------------------------------------------------------
# Vertices by id (for virtual process queries)
# ------------------------------------------------------------------
def vertices_by_id(self, process_id: bytes) -> list[Vertex]:
"""Return all vertices belonging to a given virtual process id."""
return [v for v in self.vertices.values() if v.id == process_id]
def all_process_ids(self) -> set[bytes]:
"""Return all unique virtual process ids in this graph."""
return {v.id for v in self.vertices.values()}
def last_vertices_by_id(self) -> dict[bytes, Vertex]:
"""Return the last vertex for each virtual process id."""
result = {}
for pid in self.all_process_ids():
last = self._find_last_vertex(pid)
if last is not None:
result[pid] = last
return result
# ------------------------------------------------------------------
# Weight queries
# ------------------------------------------------------------------
def vertex_weight(self, v: Vertex) -> int:
"""w(v) = w(v.m): the weight of a vertex is the weight of its message."""
return self.weight_system.weight(v.m)
def set_weight(self, vertices: set[Vertex] | list[Vertex]) -> int:
"""w(M) := ⊕_{m ∈ M} w(m): the combined weight of a set of vertices."""
total = 0
for v in vertices:
total = self.weight_system.weight_sum(total, self.vertex_weight(v))
return total
# ------------------------------------------------------------------
# Display
# ------------------------------------------------------------------
def __repr__(self) -> str:
return f"LamportGraph(vertices={len(self.vertices)}, ids={len(self.all_process_ids())})"

250
src/crisis/message.py Normal file
View file

@ -0,0 +1,250 @@
"""
Data Structures (Section 3)
3.1 Messages
-------------
Messages distribute payload across the network. The purpose of the protocol
is to establish a total order on those messages that respects causality.
A message is a byte string of variable length with the following structure
(paper, page 3):
struct Message {
byte[c1] nonce,
byte[c2] id,
byte[c3] num_digests,
byte[p * num_digests] digests,
byte[] payload
}
Where c1, c2, c3 are fixed protocol constants and p is the digest length.
The *nonce* is used by the weight function (e.g. PoW grinding).
The *id* groups messages into virtual processes.
The *digests* field encodes causal acknowledgement of other messages.
Key insight: a message that acknowledges other messages defines an inherent
natural causality -- this is the Lamport "happens-before" relation (1978).
m -> m_hat iff H(m_hat) is contained in m.digests (Eq. 2)
3.1.3 Vertices
---------------
To establish total order, messages are extended by local voting data that is
NOT transmitted. Votes are deduced from the causal relation between messages.
This is the key characteristic of virtual voting (Moser & Melliar-Smith).
struct Vertex {
Message m,
Option<uint> round,
Option<boolean> is_last,
Option<TotalOrderSet<uint>> svp, # safe voting pattern
Option<(Message, Option<bool>)> vote,
Option<uint> total_position
}
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from crisis.crypto import digest, DIGEST_LENGTH
# ---------------------------------------------------------------------------
# Protocol constants (c1, c2, c3 from the paper)
# ---------------------------------------------------------------------------
# These define the byte-lengths of the fixed-size fields in a message.
# Chosen for a practical PoC: generous enough for real use, compact enough
# for clarity.
NONCE_LENGTH = 8 # c1: 8 bytes of nonce (plenty for PoW search space)
ID_LENGTH = 32 # c2: 32 bytes for virtual process id (a hash)
NUM_DIGESTS_LENGTH = 2 # c3: 2 bytes => up to 65535 referenced digests
# ---------------------------------------------------------------------------
# Message
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Message:
"""An immutable Crisis message as defined in Section 3.1.
A message is the atomic unit of communication in the Crisis protocol.
It carries a payload and encodes causal history through its digests field.
Attributes:
nonce: Used by the weight function (e.g. PoW nonce grinding).
id: Groups this message into a virtual process.
digests: Tuple of digests of causally prior messages (H values).
payload: The actual application data being ordered.
"""
nonce: bytes
id: bytes
digests: tuple[bytes, ...] = ()
payload: bytes = b""
def __post_init__(self):
if len(self.nonce) != NONCE_LENGTH:
raise ValueError(f"nonce must be {NONCE_LENGTH} bytes, got {len(self.nonce)}")
if len(self.id) != ID_LENGTH:
raise ValueError(f"id must be {ID_LENGTH} bytes, got {len(self.id)}")
for i, d in enumerate(self.digests):
if len(d) != DIGEST_LENGTH:
raise ValueError(f"digest[{i}] must be {DIGEST_LENGTH} bytes")
def serialize(self) -> bytes:
"""Serialize this message to a canonical byte string.
The serialized form is what gets hashed to produce the message's digest.
Format: nonce | id | num_digests (2 bytes big-endian) | digests... | payload
"""
num = len(self.digests)
parts = [
self.nonce,
self.id,
num.to_bytes(NUM_DIGESTS_LENGTH, "big"),
]
for d in self.digests:
parts.append(d)
parts.append(self.payload)
return b"".join(parts)
def compute_digest(self) -> bytes:
"""Compute H(m) -- the digest of this message.
This is the value other messages include in their digests field
to acknowledge this message (establishing causality, Eq. 2).
"""
return digest(self.serialize())
@property
def num_digests(self) -> int:
return len(self.digests)
def __repr__(self) -> str:
h = self.compute_digest().hex()[:12]
return f"Message(id={self.id.hex()[:8]}..., digests={self.num_digests}, hash={h}...)"
# ---------------------------------------------------------------------------
# The empty message (paper: ∅ ∈ MESSAGE)
# ---------------------------------------------------------------------------
# "We postulate a special non-message ∅ ∈ MESSAGE" (Section 3.1)
# Acknowledgement of ∅ is defined as H(empty string).
EMPTY_MESSAGE_DIGEST = digest(b"")
# ---------------------------------------------------------------------------
# Vote
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Vote:
"""A virtual vote as computed locally by each vertex.
From the paper (Algorithm 7): v.vote(r) = (l, b) describes v's vote
on some message l, together with a possibly undecided binary value
b {, 0, 1} in a round r.
Attributes:
message: The message l being voted on (None = , the non-leader).
binary: The binary part of the vote: None= (undecided), 0, or 1.
"""
message: Optional[Message] = None
binary: Optional[int] = None # None = ⊥, 0, or 1
def __repr__(self) -> str:
msg_str = "" if self.message is None else self.message.compute_digest().hex()[:8]
bin_str = "" if self.binary is None else str(self.binary)
return f"Vote({msg_str}, {bin_str})"
# ---------------------------------------------------------------------------
# Vertex
# ---------------------------------------------------------------------------
@dataclass
class Vertex:
"""A vertex in a Lamport graph (Section 3.1.3).
A vertex wraps a message and adds locally-computed consensus state.
The additional fields (round, is_last, svp, vote, total_position) are
never transmitted -- they are deduced from the causal structure.
From the paper (page 5, Eq. 6):
w(v) <- w(v.m)
v.nonce <- v.m.nonce
v.id <- v.m.id
v.num_digests <- v.m.num_digests
v.digests <- v.m.digests
v.payload <- v.m.payload
Attributes:
m: The underlying message.
round: The virtual round number (Algorithm 5).
is_last: Whether this is a "last vertex" of its round (Alg 5).
svp: Safe voting pattern -- ordered set of round numbers.
vote: Per-round votes: round -> Vote.
total_position: Final position in the total order (Algorithm 9/10).
"""
m: Message
# Locally computed consensus state (initialized to None / ⊥)
round: Optional[int] = None
is_last: Optional[bool] = None
svp: list[int] = field(default_factory=list)
vote: dict[int, Vote] = field(default_factory=dict)
total_position: Optional[int] = None
# ------------------------------------------------------------------
# Convenience accessors that delegate to the underlying message
# ------------------------------------------------------------------
@property
def nonce(self) -> bytes:
return self.m.nonce
@property
def id(self) -> bytes:
return self.m.id
@property
def digests(self) -> tuple[bytes, ...]:
return self.m.digests
@property
def payload(self) -> bytes:
return self.m.payload
@property
def message_digest(self) -> bytes:
"""H(v.m) -- the digest that uniquely identifies this vertex's message."""
return self.m.compute_digest()
# ------------------------------------------------------------------
# Equivalence (Definition 3.3)
# ------------------------------------------------------------------
# "Two vertices v and v_hat are equivalent if v.m = v_hat.m"
# i.e. they wrap the same underlying message.
def equivalent_to(self, other: Vertex) -> bool:
"""Check vertex equivalence: same underlying message."""
return self.message_digest == other.message_digest
def __eq__(self, other: object) -> bool:
if not isinstance(other, Vertex):
return NotImplemented
return self.message_digest == other.message_digest
def __hash__(self) -> int:
return hash(self.message_digest)
def __repr__(self) -> str:
h = self.message_digest.hex()[:12]
round_str = str(self.round) if self.round is not None else "?"
last_str = "*" if self.is_last else ""
return f"Vertex({h}..., r={round_str}{last_str})"

336
src/crisis/node.py Normal file
View file

@ -0,0 +1,336 @@
"""
Crisis Node (Section 5.9 -- The Crisis Protocol)
This module ties all components together into a full Crisis node.
From the paper (Section 5.9):
"The overall algorithm works as follows: Member discovery (3) and
message gossip (4) are executed in infinite loops, concurrently to
the rest of the system. Ideally the message sending loop is executed
on as many parallel threads as possible. This implies that an overall
unbounded amount of new messages arrive over time due to our liveness
assumption. In addition each process may generate messages and write
them into its own Lamport graph."
The full node runs these concurrent loops:
1. Gossip: member discovery + message dissemination
2. Message generation: create new messages with PoW
3. Consensus: compute rounds, voting patterns, leader elections, order
Each loop runs independently and they communicate through the shared
Lamport graph.
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
from typing import Optional
from crisis.crypto import digest
from crisis.graph import LamportGraph
from crisis.gossip import GossipServer, NetworkView, PeerInfo
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
from crisis.order import LeaderStream, compute_order
from crisis.rounds import compute_rounds
from crisis.voting import compute_virtual_leader_election, compute_safe_voting_pattern
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
logger = logging.getLogger(__name__)
class CrisisNode:
"""A full Crisis protocol node.
Combines all protocol components into a single running process:
- Lamport graph (the shared DAG)
- Weight system (PoW)
- Difficulty oracle
- Gossip server (member discovery + message dissemination)
- Consensus engine (rounds, voting, ordering)
Attributes:
process_id: This node's virtual process identity.
graph: The local Lamport graph.
leader_stream: The evolving total order leader stream.
network_view: Known peers in the network.
"""
def __init__(self, host: str = "127.0.0.1", port: int = 9000,
min_pow_zeros: int = 1,
difficulty_constant: int = 4,
connectivity_k: int = 2,
message_interval: float = 3.0,
consensus_interval: float = 5.0,
seed_peers: list[tuple[str, int]] | None = None):
# Identity: use a hash of host:port as this node's virtual process id
self.process_id = digest(f"{host}:{port}".encode())[:ID_LENGTH]
self.host = host
self.port = port
# Protocol components
self.weight_system = ProofOfWorkWeight(min_leading_zeros=min_pow_zeros)
self.difficulty = DifficultyOracle(constant_difficulty=difficulty_constant)
self.connectivity_k = connectivity_k
self.graph = LamportGraph(weight_system=self.weight_system)
self.leader_stream = LeaderStream()
# Timing
self.message_interval = message_interval
self.consensus_interval = consensus_interval
# Network
self.network_view = NetworkView()
if seed_peers:
for h, p in seed_peers:
self.network_view.add_peer(PeerInfo(host=h, port=p))
self.gossip = GossipServer(
host=host, port=port,
graph=self.graph,
network_view=self.network_view,
)
# State
self._running = False
self._message_count = 0
# Callbacks for monitoring
self.on_new_vertex: Optional[callable] = None
self.on_round_update: Optional[callable] = None
self.on_order_update: Optional[callable] = None
# ------------------------------------------------------------------
# Main entry point
# ------------------------------------------------------------------
async def run(self) -> None:
"""Start all protocol loops concurrently.
This is the Crisis protocol (Section 5.9): three concurrent loops.
"""
self._running = True
logger.info(
f"Crisis node starting on {self.host}:{self.port} "
f"(id={self.process_id.hex()[:16]}...)"
)
try:
await asyncio.gather(
self._gossip_loop(),
self._message_generation_loop(),
self._consensus_loop(),
)
except asyncio.CancelledError:
logger.info("Crisis node shutting down")
finally:
self._running = False
async def stop(self) -> None:
self._running = False
await self.gossip.stop()
# ------------------------------------------------------------------
# Loop 1: Gossip (Algorithms 3 & 4)
# ------------------------------------------------------------------
async def _gossip_loop(self) -> None:
"""Run the gossip server (member discovery + message dissemination)."""
try:
await self.gossip.start()
except Exception as e:
logger.error(f"Gossip loop error: {e}")
# ------------------------------------------------------------------
# Loop 2: Message generation (Algorithm 1)
# ------------------------------------------------------------------
async def _message_generation_loop(self) -> None:
"""Periodically generate new messages and add them to the graph.
Each message:
1. References the last same-id message (chain constraint)
2. References a sample of other vertices (cross-links for connectivity)
3. Has a PoW nonce meeting the weight threshold
4. Carries an application payload
"""
while self._running:
await asyncio.sleep(self.message_interval)
try:
payload = self._generate_payload()
message = self.graph.generate_message(
self.process_id, payload, self.weight_system
)
vertex = self.graph.extend(message)
if vertex is not None:
self._message_count += 1
logger.debug(
f"Generated message #{self._message_count}: {vertex}"
)
if self.on_new_vertex:
self.on_new_vertex(vertex)
except Exception as e:
logger.error(f"Message generation error: {e}")
def _generate_payload(self) -> bytes:
"""Generate a payload for a new message.
In this PoC, payloads are simple timestamped entries.
A real application would put actual data here.
"""
self._message_count += 1
return f"msg-{self._message_count}-{time.time():.3f}".encode()
# ------------------------------------------------------------------
# Loop 3: Consensus (Algorithms 5, 6, 7, 9, 10)
# ------------------------------------------------------------------
async def _consensus_loop(self) -> None:
"""Periodically recompute consensus state.
From Section 5.9 and the proof section (Section 6):
"algorithms (5), (6) and (7) are executed in that order concurrently
on each vertex from V... the total order loop (9) runs concurrently
and waits for updates of the leader stream."
"""
while self._running:
await asyncio.sleep(self.consensus_interval)
if self.graph.vertex_count() == 0:
continue
try:
# Step 1: Compute rounds (Algorithm 5)
compute_rounds(self.graph, self.difficulty, self.connectivity_k)
if self.on_round_update:
self.on_round_update(self.graph)
# Step 2: Compute safe voting patterns (Algorithm 6)
for vertex in self.graph.all_vertices():
if vertex.is_last:
compute_safe_voting_pattern(
vertex, self.graph, self.difficulty,
self.connectivity_k
)
# Step 3: Virtual leader election (Algorithm 7)
leader_dict: dict[int, list[tuple[int, Message]]] = {}
for vertex in self.graph.all_vertices():
if vertex.svp:
compute_virtual_leader_election(
vertex, self.graph, self.difficulty,
self.connectivity_k, leader_dict
)
# Update leader stream from election results
for round_num, entries in leader_dict.items():
for deciding_round, leader_msg in entries:
self.leader_stream.update(
round_num, deciding_round, leader_msg
)
# Step 4: Compute total order (Algorithms 9 & 10)
ordered = compute_order(self.graph, self.leader_stream)
if ordered and self.on_order_update:
self.on_order_update(ordered)
logger.debug(
f"Consensus: {self.graph.vertex_count()} vertices, "
f"{len(self.leader_stream.leaders)} leaders, "
f"{len(ordered)} ordered"
)
except Exception as e:
logger.error(f"Consensus loop error: {e}")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def submit_message(self, payload: bytes) -> Optional[Vertex]:
"""Submit an application message to be ordered by the protocol."""
message = self.graph.generate_message(
self.process_id, payload, self.weight_system
)
return self.graph.extend(message)
def get_total_order(self) -> list[tuple[int, bytes]]:
"""Get the current total order as (position, payload) pairs."""
ordered = compute_order(self.graph, self.leader_stream)
return [
(v.total_position, v.payload)
for v in ordered
if v.total_position is not None
]
def status(self) -> dict:
"""Return a summary of this node's current state."""
from crisis.rounds import max_round as get_max_round
return {
"process_id": self.process_id.hex()[:16],
"address": f"{self.host}:{self.port}",
"vertices": self.graph.vertex_count(),
"process_ids": len(self.graph.all_process_ids()),
"max_round": get_max_round(self.graph),
"leaders": len(self.leader_stream.leaders),
"peers": len(self.network_view.peers),
"messages_generated": self._message_count,
}
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
"""Run a Crisis node from the command line."""
import argparse
parser = argparse.ArgumentParser(
description="Crisis Protocol Node",
epilog="Probabilistically self-organizing total order in P2P networks"
)
parser.add_argument("--host", default="127.0.0.1", help="Listen address")
parser.add_argument("--port", type=int, default=9000, help="Listen port")
parser.add_argument("--pow-zeros", type=int, default=1,
help="Min PoW leading zeros (weight threshold)")
parser.add_argument("--difficulty", type=int, default=4,
help="Difficulty oracle constant")
parser.add_argument("--msg-interval", type=float, default=3.0,
help="Seconds between message generation")
parser.add_argument("--peers", nargs="*", default=[],
help="Seed peers as host:port")
args = parser.parse_args()
seed_peers = []
for peer_str in args.peers:
h, p = peer_str.rsplit(":", 1)
seed_peers.append((h, int(p)))
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s"
)
node = CrisisNode(
host=args.host,
port=args.port,
min_pow_zeros=args.pow_zeros,
difficulty_constant=args.difficulty,
seed_peers=seed_peers,
message_interval=args.msg_interval,
)
asyncio.run(node.run())
if __name__ == "__main__":
main()

254
src/crisis/order.py Normal file
View file

@ -0,0 +1,254 @@
"""
Total Order (Section 5.8)
As time goes by and the Lamport graph grows, more and more round leaders
are computed and incorporated into the global leader stream LEADER_G(·).
Algorithm 9 (Order loop): watches for leader stream updates and recomputes
total order. Total order is achieved by topological sorting on the past
of appropriate vertices.
Algorithm 10 (Total order using Kahn's algorithm): generates total order
in linear runtime by processing vertices without outgoing causal edges first,
using voting weight to break ties among spacelike vertices.
The total order converges probabilistically: any two non-byzantine processes
will eventually compute the same total order (Proposition 6.21).
Definition 5.17 (Leader Stream):
LEADER_G : N -> Option<(uint, MESSAGE)>
is called the *global leader stream* of the Lamport graph.
Corollary 6.19 (Leader stream convergence):
If the probability for new rounds and safe voting pattern is not zero,
the leader streams of any two honest processes will converge.
"""
from __future__ import annotations
from typing import Optional
from crisis.graph import LamportGraph
from crisis.message import Message, Vertex
# ---------------------------------------------------------------------------
# Leader Stream (Definition 5.17)
# ---------------------------------------------------------------------------
class LeaderStream:
"""The global leader stream of a Lamport graph.
Maps round numbers to (deciding_round, leader_message) pairs.
Uses the Nakamoto longest chain rule: when a new leader is decided
in a later round, it may replace leaders decided in earlier rounds.
The leader stream converges to contain a single element per round
(Theorem 6.18), and honest processes' leader streams converge to
the same values (Corollary 6.19).
"""
def __init__(self):
# round_number -> (deciding_round, leader_message)
self.leaders: dict[int, tuple[int, Message]] = {}
def update(self, round_number: int, deciding_round: int,
leader_message: Message) -> bool:
"""Update the leader for a round via the Nakamoto longest chain rule.
Algorithm 8 (LONG_CHAIN): keep only the leader decided in the
highest round. Delete leaders from previous rounds that have
lower deciding rounds.
Returns True if the leader stream was modified.
"""
current = self.leaders.get(round_number)
if current is not None:
existing_deciding_round, _ = current
if existing_deciding_round >= deciding_round:
return False # Already have a leader from a higher round
self.leaders[round_number] = (deciding_round, leader_message)
# Prune: remove leaders with lower deciding rounds
# (longest chain rule -- keep only the longest)
max_deciding = max(dr for dr, _ in self.leaders.values())
to_remove = []
for r, (dr, _) in self.leaders.items():
if dr < max_deciding and r < round_number:
to_remove.append(r)
for r in to_remove:
del self.leaders[r]
return True
def get_leader(self, round_number: int) -> Optional[Message]:
"""Get the current leader message for a round, if any."""
entry = self.leaders.get(round_number)
return entry[1] if entry else None
def max_round(self) -> int:
"""Highest round with a decided leader."""
return max(self.leaders.keys()) if self.leaders else -1
def all_leaders(self) -> list[tuple[int, Message]]:
"""Return all leaders ordered by round number."""
return [(r, msg) for r, (_, msg) in sorted(self.leaders.items())]
def __repr__(self) -> str:
rounds = sorted(self.leaders.keys())
return f"LeaderStream(rounds={rounds})"
# ---------------------------------------------------------------------------
# Algorithm 9: Order Loop
# ---------------------------------------------------------------------------
def compute_order(graph: LamportGraph, leader_stream: LeaderStream) -> list[Vertex]:
"""Algorithm 9: compute total order from the leader stream.
Pseudocode:
1: loop order update loop
2: wait for LEADER_G(·) to change
3: s <- min round of all changed LEADER_G(t)
4: r <- max round of all LEADER_G(t)
5: v_{l_r} <- leader in highest round, smallest s in G
6: n <- max(v.total_position | v Ord_G(v_{l_{r-1}}))
7: for x t r do
8: randomly choose (p, l_t) LEADER_G(t)
9: if l_t then
10: ORDER(Ord_G(v_t), n) v_t.m = l_t
11: end if
12: end for
13: end loop
For this PoC, we compute the order in a single pass over the current
leader stream state.
"""
if not leader_stream.leaders:
return []
ordered: list[Vertex] = []
position = 0
# Process leaders in round order
for round_number, leader_message in leader_stream.all_leaders():
# Find the vertex corresponding to this leader message
leader_digest = leader_message.compute_digest()
leader_vertex = graph.get_vertex(leader_digest)
if leader_vertex is None:
continue
# Order the past of this leader vertex (excluding already-ordered)
past_vertices = graph.past(leader_vertex)
already_ordered = {v.message_digest for v in ordered}
new_vertices = [
v for v in past_vertices
if v.message_digest not in already_ordered
]
# Sort new vertices using Kahn's algorithm (Algorithm 10)
sorted_new = _kahns_total_order(new_vertices, graph)
for v in sorted_new:
v.total_position = position
ordered.append(v)
position += 1
return ordered
# ---------------------------------------------------------------------------
# Algorithm 10: Total Order using Kahn's Algorithm
# ---------------------------------------------------------------------------
def _kahns_total_order(vertices: list[Vertex], graph: LamportGraph) -> list[Vertex]:
"""Algorithm 10: generate total order using Kahn's algorithm.
Kahn's algorithm in its "arrow reversed" incarnation: we want to order
the past before the future in our Lamport graph.
Pseudocode from the paper:
1: procedure ORDER(dag:Ord(v), uint:last)
2: n <- last + 1
3: S <- set of all elements of Ord(v) with no outgoing edges
4: while S do
5: remove x with highest weight w(x) from S
6: x.total_position <- n
7: n <- n + 1
8: for each vertex y Ord(v) with edge e : y -> x do
9: remove edge e from Ord(v)
10: if y has no other outgoing edge then
11: S <- S {y}
12: end if
13: end for
14: end while
15: end procedure
Tie-breaking by voting weight ensures that all honest processes produce
the same total order from equivalent Lamport graphs.
"""
if not vertices:
return []
# Build a local subgraph for just these vertices
vertex_set = {v.message_digest for v in vertices}
# out_degree: for each vertex, count edges to other vertices in this set
out_edges: dict[bytes, set[bytes]] = {}
in_edges: dict[bytes, set[bytes]] = {}
for v in vertices:
d = v.message_digest
out_edges[d] = set()
in_edges[d] = set()
for v in vertices:
d = v.message_digest
for cause_d in graph.edges.get(d, set()):
if cause_d in vertex_set:
out_edges[d].add(cause_d)
in_edges[cause_d].add(d)
# Start with vertices that have no outgoing edges (sinks = earliest causes)
result: list[Vertex] = []
available = [
v for v in vertices
if len(out_edges[v.message_digest]) == 0
]
while available:
# Remove the vertex with highest weight (deterministic tie-breaking)
available.sort(key=lambda v: graph.vertex_weight(v), reverse=True)
chosen = available.pop(0)
result.append(chosen)
# Remove edges pointing to chosen
chosen_d = chosen.message_digest
for referrer_d in list(in_edges.get(chosen_d, set())):
out_edges[referrer_d].discard(chosen_d)
if len(out_edges[referrer_d]) == 0:
referrer_vertex = graph.get_vertex(referrer_d)
if referrer_vertex is not None and referrer_vertex not in result:
available.append(referrer_vertex)
return result
# ---------------------------------------------------------------------------
# Convenience: full pipeline
# ---------------------------------------------------------------------------
def total_order_positions(graph: LamportGraph,
leader_stream: LeaderStream) -> dict[bytes, int]:
"""Return a mapping of message digest -> total order position.
This is the final output of the Crisis protocol: a total order on
messages that respects causality and is probabilistically invariant
among all honest participants.
"""
ordered = compute_order(graph, leader_stream)
return {v.message_digest: v.total_position for v in ordered
if v.total_position is not None}

231
src/crisis/rounds.py Normal file
View file

@ -0,0 +1,231 @@
"""
Virtual Synchronous Rounds (Section 5.3)
Lamport graphs represent a timelike order between vertices that we interpret
as virtual communication channels. Going one step further, we can think from
inside the Lamport graph to define a virtual clock tick as a transition from
one vertex to another.
This simple idea allows for internal synchronism that enables us to execute
strongly synchronous agreement protocols like Feldman & Micali's BA*
virtually, but without any compromise in external asynchronism.
Algorithm 5 (Virtual synchronous rounds):
The algorithm computes *round numbers* and the *is_last* property
of any vertex.
- The round number is computed by taking the largest round of all
direct causes.
- If the vertex is a direct effect of a current round vertex with
the is_last property, a new round begins.
- If the vertex has enough last vertices of the previous round in its
past and it is k-reachable from all of them, the vertex becomes a
last vertex in its own round.
Definition 5.1 (k-reachability):
v_hat is said to be k-reachable from v, if the overall weight of all
vertices in all paths from v to v_hat is greater than k.
Proposition 5.3 (Round invariance):
The round number and is_last property do not depend on the actual
Lamport graph, but are the same for equivalent vertices.
"""
from __future__ import annotations
from crisis.graph import LamportGraph
from crisis.message import Vertex
from crisis.weight import DifficultyOracle
def compute_rounds(graph: LamportGraph, difficulty: DifficultyOracle,
connectivity_k: int = 2) -> None:
"""Execute Algorithm 5 on all vertices in the graph.
This computes v.round and v.is_last for every vertex v in the graph.
The algorithm processes vertices in causal order (causes before effects)
to ensure dependencies are resolved before they are needed.
Args:
graph: The Lamport graph to process.
difficulty: The difficulty oracle d : N -> W.
connectivity_k: The connectivity parameter k for k-reachability.
"""
# Process vertices in topological order (causes first)
ordered = _topological_sort(graph)
for vertex in ordered:
_compute_round_for_vertex(vertex, graph, difficulty, connectivity_k)
def _compute_round_for_vertex(vertex: Vertex, graph: LamportGraph,
difficulty: DifficultyOracle,
connectivity_k: int) -> None:
"""Algorithm 5: compute round number and is_last for a single vertex.
Pseudocode from the paper:
1: procedure ROUND(vertex:v, lamport_graph:G)
2: N_v <- {v_hat G | v -> v_hat} # direct causes
3: r <- max({v_hat.round | v_hat N_v} {0})
4: if there is a v_hat N_v with v_hat.is_last and v_hat.round = r then
5: v.round <- r + 1
6: else
7: v.round <- r
8: end if
9: S_r <- {v_hat G | v_hat.round = v.round - 1, v_hat.is_last, v_hat _k v}
10: if w(S_r) > 3 * d_r then
11: v.is_last <- true
12: else
13: v.is_last <- (r = 0)
14: end if
15: end procedure
"""
# Step 2: direct causes
direct_causes = graph.direct_causes(vertex)
# Step 3: max round of direct causes (default 0 if no causes)
if direct_causes:
max_round = max(
(dc.round if dc.round is not None else 0) for dc in direct_causes
)
else:
max_round = 0
# Steps 4-8: determine this vertex's round
# If any direct cause is a "last vertex" of the current max round,
# this vertex starts a new round.
has_last_cause_in_max_round = any(
dc.is_last and dc.round == max_round
for dc in direct_causes
if dc.round is not None and dc.is_last is not None
)
if has_last_cause_in_max_round:
vertex.round = max_round + 1
else:
vertex.round = max_round
# Steps 9-14: determine is_last
r = vertex.round
if r == 0:
# All round-0 vertices are "last" (bootstrapping)
vertex.is_last = True
return
# Find last vertices of the previous round that are k-reachable from v
d_r = difficulty.difficulty(r)
previous_round_lasts = [
v_hat for v_hat in graph.all_vertices()
if v_hat.round == r - 1
and v_hat.is_last
and _is_k_reachable(v_hat, vertex, graph, connectivity_k)
]
# Weight of k-reachable last vertices from previous round
weight_of_previous_lasts = graph.set_weight(previous_round_lasts)
if weight_of_previous_lasts > 3 * d_r:
vertex.is_last = True
else:
vertex.is_last = False
def _is_k_reachable(v_from: Vertex, v_to: Vertex,
graph: LamportGraph, k: int) -> bool:
"""Check k-reachability (Definition 5.1).
v_hat is k-reachable from v if the overall weight of all vertices in
all paths from v to v_hat is greater than k.
For simplicity in this PoC, we approximate this by checking if v_from
is in the past of v_to and the total weight along the path exceeds k.
The paper notes (page 11): "counting disjoint paths is computationally
expensive and not really necessary in our setting... all we need is some
insurance that information flows through enough real world processes."
We use total path weight as a simpler proxy.
"""
if v_from not in graph.past(v_to):
return False
# Compute the weight of all vertices in the path from v_from to v_to
# (all vertices that are in both the future of v_from and the past of v_to)
past_of_to = graph.past(v_to)
future_of_from = graph.future(v_from)
path_vertices = past_of_to & future_of_from
total_weight = graph.set_weight(path_vertices)
return total_weight > k
def _topological_sort(graph: LamportGraph) -> list[Vertex]:
"""Sort vertices in causal order: causes come before their effects.
Uses Kahn's algorithm. Vertices with no causes (sources) come first.
This ensures that when we process a vertex, all its causes already
have their round numbers computed.
"""
# Compute in-degree (number of causes each vertex has within the graph)
in_degree: dict[bytes, int] = {}
for d, v in graph.vertices.items():
in_degree[d] = 0
for d, v in graph.vertices.items():
for ref_d in graph.edges.get(d, set()):
if ref_d in graph.vertices:
# ref_d is a cause of d, so d has an additional in-edge
# But we want causal order: causes first
# edges go from effect -> cause, so we need reverse
pass
# Actually: edges[d] contains the causes of d (d -> cause).
# For topological sort where causes come first, we need:
# in_degree[d] = number of digests in edges[d] that are in the graph
for d in graph.vertices:
count = 0
for cause_d in graph.edges.get(d, set()):
if cause_d in graph.vertices:
count += 1
in_degree[d] = count
# Start with vertices that have no causes (in_degree = 0)
queue = [d for d, deg in in_degree.items() if deg == 0]
result = []
while queue:
current = queue.pop(0)
result.append(graph.vertices[current])
# For each vertex that current is a cause of (reverse edges)
for effect_d in graph.reverse_edges.get(current, set()):
if effect_d in in_degree:
in_degree[effect_d] -= 1
if in_degree[effect_d] == 0:
queue.append(effect_d)
return result
# ---------------------------------------------------------------------------
# Queries on computed rounds
# ---------------------------------------------------------------------------
def last_vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]:
"""Return all last vertices in a given round."""
return [
v for v in graph.all_vertices()
if v.round == round_number and v.is_last
]
def max_round(graph: LamportGraph) -> int:
"""Return the highest round number in the graph."""
rounds = [v.round for v in graph.all_vertices() if v.round is not None]
return max(rounds) if rounds else 0
def vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]:
"""Return all vertices in a given round."""
return [v for v in graph.all_vertices() if v.round == round_number]

527
src/crisis/voting.py Normal file
View file

@ -0,0 +1,527 @@
"""
Virtual Voting, Safe Voting Patterns, and Leader Election (Section 5)
This module implements the heart of the Crisis protocol: the virtual voting
mechanism that achieves total order without ever sending explicit vote messages.
Key concepts:
5.5 Virtual Process Sortition & Knowledge Graphs
- Knowledge graph (Def 5.8): quotient graph projecting vertices to virtual
processes, representing what each process "knows" about others.
- Quorum selector (Def 5.11): deterministically chooses a subset of virtual
processes for each round -- the quorum that participates in agreement.
5.6 Safe Voting Pattern
- Voting sets (Def 5.12): the set of vertices participating in round s
agreement, reachable with connectivity k from vertex v.
- Algorithm 6: computes the safe voting pattern -- a nested sequence of
rounds where voting took place with appropriately bounded byzantine weight.
5.7 Local Leader Election
- Algorithm 7: virtual leader elections -- an adaptation of Chen, Feldman
& Micali's BA* to virtual voting on Lamport graphs.
- Three stage types: initial proposal (δ=0), presorting/gradecast (δ{1,2}),
and BBA* binary agreement (δ3) with "coin fixed to 0/1" and "genuine
coin flip" sub-stages.
5.8 Longest Chain Rule
- Algorithm 8: maintains the leader stream by keeping only the longest
chain of round leaders (similar to Nakamoto's longest chain rule).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from crisis.crypto import digest, least_significant_bit
from crisis.graph import LamportGraph
from crisis.message import Message, Vertex, Vote, EMPTY_MESSAGE_DIGEST
from crisis.rounds import last_vertices_in_round, max_round
from crisis.weight import DifficultyOracle
# ---------------------------------------------------------------------------
# Knowledge Graph (Definition 5.8)
# ---------------------------------------------------------------------------
@dataclass
class KnowledgeGraph:
"""The round s knowledge graph of vertex v (Definition 5.8).
Given rounds s < r, a Lamport graph G, and v a last message in round r,
the knowledge graph Π^s_v is the quotient graph G^s_v / _id.
Each node in the knowledge graph represents a virtual process (identified
by its id). An edge from process id to id' means that some vertex with
v.id = id in round s has a vertex with v_hat.id = id' in its past.
This represents what each virtual process "knows" about others.
"""
# id -> set of ids that this process has edges to
edges: dict[bytes, set[bytes]] = field(default_factory=dict)
# id -> total weight of vertices in this equivalence class
weights: dict[bytes, int] = field(default_factory=dict)
def build_knowledge_graph(vertex: Vertex, round_s: int,
graph: LamportGraph) -> KnowledgeGraph:
"""Build the round s knowledge graph for vertex v.
Collects all round-s vertices in v's past, groups them by id,
and builds the quotient graph.
"""
kg = KnowledgeGraph()
past = graph.past(vertex)
# Find all round-s vertices in v's past
round_s_vertices = [v for v in past if v.round == round_s]
# Group by id and compute edges
for v_s in round_s_vertices:
vid = v_s.id
if vid not in kg.edges:
kg.edges[vid] = set()
if vid not in kg.weights:
kg.weights[vid] = 0
kg.weights[vid] = graph.weight_system.weight_sum(
kg.weights[vid], graph.vertex_weight(v_s)
)
# Add edges based on what this vertex references
for cause in graph.direct_causes(v_s):
if cause.round is not None and cause.round == round_s:
kg.edges[vid].add(cause.id)
return kg
# ---------------------------------------------------------------------------
# Quorum Selector (Definition 5.11)
# ---------------------------------------------------------------------------
def select_quorum(knowledge_graph: KnowledgeGraph, n: int = 3) -> set[bytes]:
"""Select a quorum from a knowledge graph (Definition 5.11).
Example 3 (Highest voting weight quorum):
Choose the weakly connected component with the highest combined voting
weight, then take the heaviest n virtual processes from it.
The quorum selector serves as a filter to reduce byzantine noise that
might appear in the voting process. By restricting to a heavily
connected component, faulty behavior based on graph partition is reduced.
"""
if not knowledge_graph.edges:
return set()
# Find weakly connected components using simple BFS
all_ids = set(knowledge_graph.edges.keys())
visited: set[bytes] = set()
components: list[set[bytes]] = []
for start_id in all_ids:
if start_id in visited:
continue
component: set[bytes] = set()
queue = [start_id]
while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
component.add(current)
# Follow edges in both directions (weakly connected)
for neighbor in knowledge_graph.edges.get(current, set()):
if neighbor not in visited and neighbor in all_ids:
queue.append(neighbor)
# Reverse edges
for other_id, neighbors in knowledge_graph.edges.items():
if current in neighbors and other_id not in visited:
queue.append(other_id)
components.append(component)
# Choose the component with highest total weight
def component_weight(comp: set[bytes]) -> int:
return sum(knowledge_graph.weights.get(pid, 0) for pid in comp)
best_component = max(components, key=component_weight)
# Take the n heaviest processes from this component
sorted_by_weight = sorted(
best_component,
key=lambda pid: knowledge_graph.weights.get(pid, 0),
reverse=True
)
return set(sorted_by_weight[:n])
# ---------------------------------------------------------------------------
# Voting Sets (Definition 5.12)
# ---------------------------------------------------------------------------
def voting_set(vertex: Vertex, round_s: int, connectivity_k: int,
graph: LamportGraph) -> set[Vertex]:
"""Compute S_v(s,k): v's round s voting set (Definition 5.12).
S_v(s,k) := { x | x.id Q(v,s) x _{(r-s)*k} v
x.round = s x.is_last = true }
The voting set consists of all last vertices in round s that:
1. Belong to a quorum-selected virtual process
2. Are k-reachable from v (with distance scaled by round gap)
3. Are in v's past
"""
if vertex.round is None:
return set()
r = vertex.round
if round_s >= r:
return set()
# Build knowledge graph and select quorum
kg = build_knowledge_graph(vertex, round_s, graph)
quorum = select_quorum(kg)
past_of_v = graph.past(vertex)
result = set()
for v_hat in past_of_v:
if (v_hat.round == round_s
and v_hat.is_last
and v_hat.id in quorum):
result.add(v_hat)
return result
# ---------------------------------------------------------------------------
# Algorithm 6: Safe Voting Pattern (Section 5.6)
# ---------------------------------------------------------------------------
def compute_safe_voting_pattern(vertex: Vertex, graph: LamportGraph,
difficulty: DifficultyOracle,
connectivity_k: int = 2) -> None:
"""Algorithm 6: compute the safe voting pattern for a vertex.
The safe voting pattern v.svp is a totally ordered set of round numbers
where "safe" voting took place. Safe means:
- The voting set has enough overall weight
- The svp of all members agree
- Byzantine weight is bounded
Pseudocode from the paper:
1: procedure SVP(vertex:v, lamport_graph:G)
2: v.svp <-
3: if v.is_last and [safe voting pattern conditions are met] then
4: s <- maximum of all such k
5: v.svp <- v.svp {s} for all t s
6: end if
7: end procedure
The procedure checks if the current vertex's round qualifies as a new
entry in the safe voting pattern by verifying weight and agreement
conditions from its voting set.
"""
vertex.svp = []
if not vertex.is_last or vertex.round is None or vertex.round == 0:
return
r = vertex.round
# Check each previous round for safe voting pattern membership
for s in range(r):
d_s = difficulty.difficulty(s)
# Get voting set for round s
vs = voting_set(vertex, s, connectivity_k, graph)
if not vs:
continue
total_weight = graph.set_weight(vs)
# Check if voting weight exceeds threshold (6 * d_s from Eq. 8)
if total_weight <= 6 * d_s:
continue
# Check that all members of the voting set have compatible svp
svps_agree = True
for x in vs:
for y in vs:
if x.svp != y.svp:
# Allow prefix agreement
min_len = min(len(x.svp), len(y.svp))
if x.svp[:min_len] != y.svp[:min_len]:
svps_agree = False
break
if not svps_agree:
break
if svps_agree:
vertex.svp.append(s)
# svp is a nested sequence: add current round
if vertex.svp:
vertex.svp.append(r)
# ---------------------------------------------------------------------------
# Initial Vote Function (Definition 5.16, Example 4)
# ---------------------------------------------------------------------------
def initial_vote(vertices: set[Vertex], graph: LamportGraph) -> Optional[Message]:
"""INITIAL_VOTE: deterministically choose a leader proposal (Def 5.16).
Example 4 (Highest weight): Choose the underlying message of the highest
voting weight vertex. Since we assume it is infeasible to have different
vertices of equal weight, this is practically deterministic.
The initial vote function is a system parameter. Different choices lead
to different long-term behavior. Ideally all members of a safe voting
pattern would compute the same initial vote.
"""
if not vertices:
return None
best_vertex = max(vertices, key=lambda v: graph.vertex_weight(v))
return best_vertex.m
# ---------------------------------------------------------------------------
# Algorithm 7: Virtual Leader Elections (Section 5.7)
# ---------------------------------------------------------------------------
def compute_virtual_leader_election(vertex: Vertex, graph: LamportGraph,
difficulty: DifficultyOracle,
connectivity_k: int,
leader_stream: dict[int, list[tuple[int, Message]]]) -> None:
"""Algorithm 7: compute votes for all rounds in v's safe voting pattern.
This is the core virtual BA* protocol. For each element t in v.svp,
the vertex computes a vote v.vote(t) = (l, b) based on the stage δ
(the position of that round in the svp).
Stage types (determined by δ = d_{v.svp}(s, t)):
δ = 0: Initial leader proposal
δ = 1: Leader presorting (gradecast step)
δ = 2: BBA* initialization (gradecast step)
δ 3: Binary agreement rounds
δ mod 3 = 0: Coin fixed to 0
δ mod 3 = 1: Coin fixed to 1
δ mod 3 = 2: Genuine coin flip
The paper notes: "every step is entirely virtual and no votes are
actually sent to other real world processes."
"""
if not vertex.svp:
return
s = max(vertex.svp) if vertex.svp else None
if s is None:
return
for t_idx, t in enumerate(vertex.svp):
delta = t_idx # stage = position in svp
_compute_vote_for_stage(vertex, t, delta, s, graph, difficulty,
connectivity_k, leader_stream)
def _compute_vote_for_stage(vertex: Vertex, t: int, delta: int, s: int,
graph: LamportGraph, difficulty: DifficultyOracle,
connectivity_k: int,
leader_stream: dict[int, list[tuple[int, Message]]]) -> None:
"""Compute vertex's vote for a specific stage of the virtual leader election.
Implements the branching logic of Algorithm 7 (pages 19-20 of the paper).
"""
d_s = difficulty.difficulty(s)
vs = voting_set(vertex, t, connectivity_k, graph)
n = graph.set_weight(vs)
NON_LEADER = None # ∅ in the paper
if delta == 0:
# Stage 0: Initial leader proposal
l = initial_vote(vs, graph)
vertex.vote[t] = Vote(message=l, binary=None) # (INITIAL_VOTE(S), ⊥)
elif delta == 1:
# Stage 1: Leader presorting
# Find message with highest round-t voting weight in S
l = _highest_weight_message(vs, graph)
if l is not None:
# Check if l has super majority weight
l_weight = _vote_weight_for(vs, t, l, None, graph) # votes for (l, ⊥)
if l_weight > n - d_s:
vertex.vote[t] = Vote(message=l, binary=None) # (l, ⊥)
else:
vertex.vote[t] = Vote(message=NON_LEADER, binary=None) # (∅, ⊥)
else:
vertex.vote[t] = Vote(message=NON_LEADER, binary=None)
elif delta == 2:
# Stage 2: BBA* initialization (gradecast)
l = _highest_weight_message(vs, graph)
if l is not None:
l_weight_undecided = _vote_weight_for(vs, t, l, None, graph)
if l_weight_undecided > n - d_s:
vertex.vote[t] = Vote(message=l, binary=0)
else:
l_weight_1 = _vote_weight_for(vs, t, l, 1, graph)
if l_weight_1 > d_s:
vertex.vote[t] = Vote(message=l, binary=1)
else:
vertex.vote[t] = Vote(message=NON_LEADER, binary=1)
else:
vertex.vote[t] = Vote(message=NON_LEADER, binary=1)
else:
# Stage δ ≥ 3: Binary agreement (BBA*)
coin_stage = delta % 3
l = _highest_weight_message(vs, graph)
if coin_stage == 0:
# Coin fixed to 0
_bba_coin_fixed(vertex, t, vs, l, n, d_s, graph,
leader_stream, s, fixed_value=0)
elif coin_stage == 1:
# Coin fixed to 1
_bba_coin_fixed(vertex, t, vs, l, n, d_s, graph,
leader_stream, s, fixed_value=1)
else:
# Genuine coin flip (coin_stage == 2)
_bba_genuine_coin(vertex, t, vs, l, n, d_s, graph)
def _bba_coin_fixed(vertex: Vertex, t: int, vs: set[Vertex],
l: Optional[Message], n: int, d_s: int,
graph: LamportGraph,
leader_stream: dict[int, list[tuple[int, Message]]],
s: int, fixed_value: int) -> None:
"""BBA* stage with coin fixed to 0 or 1."""
other_value = 1 - fixed_value
if l is not None:
weight_for_fixed = _vote_weight_for_binary(vs, t, fixed_value, graph)
if weight_for_fixed > n - d_s:
vertex.vote[t] = Vote(message=l, binary=fixed_value)
# If weight = n, we have agreement: update leader stream
if weight_for_fixed == n:
_update_leader_stream(leader_stream, l, s)
return
weight_for_other = _vote_weight_for_binary(vs, t, other_value, graph)
if weight_for_other > n - d_s:
vertex.vote[t] = Vote(message=l, binary=other_value)
return
vertex.vote[t] = Vote(message=l, binary=fixed_value)
def _bba_genuine_coin(vertex: Vertex, t: int, vs: set[Vertex],
l: Optional[Message], n: int, d_s: int,
graph: LamportGraph) -> None:
"""BBA* stage with genuine coin flip."""
if l is not None:
weight_0 = _vote_weight_for_binary(vs, t, 0, graph)
if weight_0 > n - d_s:
vertex.vote[t] = Vote(message=l, binary=0)
return
weight_1 = _vote_weight_for_binary(vs, t, 1, graph)
if weight_1 > n - d_s:
vertex.vote[t] = Vote(message=l, binary=1)
return
# Genuine coin flip: use LSB of hash of heaviest vertex's message
if vs:
heaviest = max(vs, key=lambda v: graph.vertex_weight(v))
h = heaviest.m.compute_digest()
b_coin = least_significant_bit(h)
vertex.vote[t] = Vote(message=l, binary=b_coin)
else:
vertex.vote[t] = Vote(message=l, binary=0)
# ---------------------------------------------------------------------------
# Helper functions for vote weight computation
# ---------------------------------------------------------------------------
def _highest_weight_message(vs: set[Vertex], graph: LamportGraph) -> Optional[Message]:
"""Find the message with the highest voting weight in a set."""
if not vs:
return None
best = max(vs, key=lambda v: graph.vertex_weight(v))
return best.m
def _vote_weight_for(vs: set[Vertex], round_t: int,
target_msg: Optional[Message], target_binary: Optional[int],
graph: LamportGraph) -> int:
"""Compute total voting weight for a specific vote (l, b) in a voting set."""
total = 0
for v in vs:
vote = v.vote.get(round_t)
if vote is None:
continue
msg_match = (vote.message is None and target_msg is None) or \
(vote.message is not None and target_msg is not None and
vote.message.compute_digest() == target_msg.compute_digest())
bin_match = vote.binary == target_binary
if msg_match and bin_match:
total = graph.weight_system.weight_sum(total, graph.vertex_weight(v))
return total
def _vote_weight_for_binary(vs: set[Vertex], round_t: int,
target_binary: int,
graph: LamportGraph) -> int:
"""Compute total voting weight for a specific binary value in a voting set."""
total = 0
for v in vs:
vote = v.vote.get(round_t)
if vote is not None and vote.binary == target_binary:
total = graph.weight_system.weight_sum(total, graph.vertex_weight(v))
return total
# ---------------------------------------------------------------------------
# Algorithm 8: Longest Chain Rule (Section 5.8)
# ---------------------------------------------------------------------------
def _update_leader_stream(leader_stream: dict[int, list[tuple[int, Message]]],
message: Message, round_number: int) -> None:
"""Algorithm 8: update the leader stream with a new leader candidate.
The longest chain rule keeps only the chain with the highest deciding
round for each round leader. When a new round leader is decided at a
higher deciding round, previous entries with lower deciding rounds are
replaced.
Pseudocode:
1: procedure LONG_CHAIN(set{(uint,MESSAGE)}:S, MESSAGE:m, uint:s)
2: if there is no (l, t) S with t > s then
3: S <- {(l, t) S | t < s} (s, m)
4: end if
5: return S
6: end procedure
"""
if round_number not in leader_stream:
leader_stream[round_number] = []
entries = leader_stream[round_number]
# Check if there's already an entry with a higher deciding round
has_higher = any(t > round_number for (t, _) in entries)
if has_higher:
return
# Remove entries with lower deciding rounds, add new one
leader_stream[round_number] = [
(t, m) for (t, m) in entries if t < round_number
] + [(round_number, message)]

190
src/crisis/weight.py Normal file
View file

@ -0,0 +1,190 @@
"""
Weight Systems (Section 3.1.1)
Definition 3.1 (Weight system): Let MESSAGE be the metric space of all
messages and (W, ) a totally ordered set. Then the tuple (W, w, , c_min)
is a *weight system* if w is a function
w : MESSAGE -> W (Eq. 3)
that assigns an element of W to any message, c_min W is a constant called
the *weight threshold*, and is a function
: W × W -> W (Eq. 4)
called the *weight sum*, such that:
- Tamper proof: w(m) >= c_min and m_hat m implies w(m_hat) < c_min
with high probability.
- Uniqueness: m m_hat implies w(m) w(m_hat) with high probability.
- Summability: (W, ) is a totally ordered, abelian group.
The weight w(m) is interpreted as the amount of voting power m holds to
influence total order generation.
This module provides:
1. An abstract WeightSystem protocol
2. A concrete Proof-of-Work implementation (Hashcash-style)
The PoW weight function counts leading zero bits of H(m), similar to Bitcoin's
difficulty mechanism (Nakamoto, 2009; Beck, 2002).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from crisis.crypto import digest, count_leading_zero_bits
from crisis.message import Message
# ---------------------------------------------------------------------------
# Abstract weight system
# ---------------------------------------------------------------------------
class WeightSystem(Protocol):
"""Protocol defining the weight system interface (Definition 3.1).
Any concrete weight system must provide:
- weight(): Compute w(m) for a message
- threshold: The minimum weight c_min
- weight_sum(): Compute for two weights
"""
@property
def threshold(self) -> int:
"""c_min: the minimum weight threshold.
Messages with weight below this are rejected. This prevents Sybil
attacks by ensuring every message requires a minimum investment.
"""
...
def weight(self, message: Message) -> int:
"""w(m): compute the weight of a message.
The weight represents the voting power of this message in the
consensus protocol.
"""
...
def weight_sum(self, a: int, b: int) -> int:
"""⊕: combine two weights.
Must form a totally ordered abelian group.
For our purposes, ordinary integer addition suffices.
"""
...
def is_valid_weight(self, message: Message) -> bool:
"""Check whether w(m) >= c_min."""
...
# ---------------------------------------------------------------------------
# Proof-of-Work weight system
# ---------------------------------------------------------------------------
@dataclass
class ProofOfWorkWeight:
"""A Hashcash-style Proof-of-Work weight system.
The weight of a message is the number of leading zero bits in H(m).
This is similar to Bitcoin's mining: finding a message whose hash starts
with k zero bits requires approximately 2^k hash evaluations on average.
The nonce field of the message is used to search for a valid hash,
analogous to Bitcoin's block header nonce.
Attributes:
min_leading_zeros: c_min -- minimum leading zero bits required.
A value of 1 means every message needs at least
1 leading zero bit (50% of hashes qualify).
"""
min_leading_zeros: int = 1
@property
def threshold(self) -> int:
return self.min_leading_zeros
def weight(self, message: Message) -> int:
"""Count leading zero bits in H(m).
More leading zeros = more work performed = higher voting weight.
"""
h = message.compute_digest()
return count_leading_zero_bits(h)
def weight_sum(self, a: int, b: int) -> int:
"""Simple integer addition for combining weights.
This satisfies the abelian group requirement: (Z, +) is a totally
ordered abelian group with identity 0.
"""
return a + b
def is_valid_weight(self, message: Message) -> bool:
"""Check w(m) >= c_min."""
return self.weight(message) >= self.threshold
def mine_nonce(self, id_bytes: bytes, digests: tuple[bytes, ...],
payload: bytes, target_weight: int | None = None) -> Message:
"""Search for a nonce that produces a message meeting the weight target.
This is the "nonce grinding" step: try successive nonce values until
H(m) has enough leading zero bits.
Args:
id_bytes: The virtual process id for this message.
digests: Causal acknowledgements (digests of prior messages).
payload: The application payload.
target_weight: Minimum weight to achieve. Defaults to c_min.
Returns:
A Message with a valid nonce.
"""
if target_weight is None:
target_weight = self.threshold
nonce_int = 0
while True:
nonce = nonce_int.to_bytes(8, "big")
msg = Message(nonce=nonce, id=id_bytes, digests=digests, payload=payload)
if self.weight(msg) >= target_weight:
return msg
nonce_int += 1
# ---------------------------------------------------------------------------
# Difficulty Oracle (Section 5.4, Definition 5.2)
# ---------------------------------------------------------------------------
@dataclass
class DifficultyOracle:
"""Maps round numbers to difficulty values (Definition 5.2).
The difficulty oracle d : N -> W maps natural numbers (rounds) onto
weights. The value d_r := d(r) is called the *round r difficulty*.
The difficulty is designed so that the overall voting weight per round
is bounded:
lim sum(w_s^G / d_s) <= 6 (Eq. 8)
for all time parameters t, where w_s^G is the overall voting weight of
last vertices in round s.
Example 1 (paper): A fixed constant that does not change over time.
This is the simplest starting point for a PoC.
"""
constant_difficulty: int = 4
def difficulty(self, round_number: int) -> int:
"""d(r): return the difficulty for round r.
For this PoC we use a fixed constant (paper Example 1).
A production system might adapt this based on observed voting
weight, similar to Bitcoin's difficulty adjustment.
"""
return self.constant_difficulty

0
tests/__init__.py Normal file
View file

57
tests/test_crypto.py Normal file
View file

@ -0,0 +1,57 @@
"""Tests for the crypto module (random oracle model)."""
from crisis.crypto import (
digest, digest_hex, verify_digest,
least_significant_bit, count_leading_zero_bits,
DIGEST_LENGTH,
)
def test_digest_returns_32_bytes():
h = digest(b"hello")
assert len(h) == DIGEST_LENGTH == 32
def test_digest_is_deterministic():
assert digest(b"test") == digest(b"test")
def test_digest_different_inputs_different_outputs():
assert digest(b"a") != digest(b"b")
def test_digest_hex_matches():
h = digest(b"hello")
assert digest_hex(b"hello") == h.hex()
def test_verify_digest():
h = digest(b"data")
assert verify_digest(b"data", h)
assert not verify_digest(b"other", h)
def test_least_significant_bit():
# 0x00 -> LSB = 0, 0x01 -> LSB = 1
assert least_significant_bit(b"\x00") == 0
assert least_significant_bit(b"\x01") == 1
assert least_significant_bit(b"\x02") == 0
assert least_significant_bit(b"\x03") == 1
assert least_significant_bit(b"\xff") == 1
assert least_significant_bit(b"\xfe") == 0
def test_count_leading_zero_bits():
assert count_leading_zero_bits(b"\xff") == 0
assert count_leading_zero_bits(b"\x7f") == 1
assert count_leading_zero_bits(b"\x3f") == 2
assert count_leading_zero_bits(b"\x00\xff") == 8
assert count_leading_zero_bits(b"\x00\x00\x01") == 23
assert count_leading_zero_bits(b"\x00") == 8
def test_empty_digest_is_well_defined():
"""Paper: 'Acknowledgement of the empty string is defined as H(∅)'."""
h = digest(b"")
assert len(h) == 32
assert h == digest(b"") # deterministic

218
tests/test_graph.py Normal file
View file

@ -0,0 +1,218 @@
"""Tests for the Lamport graph with integrity checks."""
import os
import pytest
from crisis.crypto import digest
from crisis.graph import LamportGraph
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
from crisis.weight import ProofOfWorkWeight
def make_id(name: str) -> bytes:
return digest(name.encode())[:ID_LENGTH]
def make_nonce(n: int = 0) -> bytes:
return n.to_bytes(NONCE_LENGTH, "big")
def make_graph(pow_zeros: int = 0) -> LamportGraph:
return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=pow_zeros))
class TestLamportGraphExtension:
def test_extend_single_message(self):
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"hello")
v = g.extend(msg)
assert v is not None
assert g.vertex_count() == 1
def test_extend_chain(self):
"""Messages from the same id must form a chain."""
g = make_graph()
m1 = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"first")
v1 = g.extend(m1)
assert v1 is not None
m2 = Message(
nonce=make_nonce(1), id=make_id("alice"),
digests=(m1.compute_digest(),),
payload=b"second"
)
v2 = g.extend(m2)
assert v2 is not None
assert g.vertex_count() == 2
def test_reject_duplicate(self):
"""No two equivalent vertices in the same graph."""
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"x")
g.extend(msg)
v2 = g.extend(msg)
assert v2 is None # Rejected: duplicate
assert g.vertex_count() == 1
def test_reject_missing_reference(self):
"""Digests must reference existing vertices."""
g = make_graph()
fake_digest = digest(b"nonexistent")
msg = Message(
nonce=make_nonce(), id=make_id("alice"),
digests=(fake_digest,), payload=b"orphan"
)
v = g.extend(msg)
assert v is None # Rejected
def test_reject_broken_chain(self):
"""Second message from same id must reference a same-id vertex."""
g = make_graph()
id_a = make_id("alice")
id_b = make_id("bob")
m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"first")
g.extend(m1)
m_bob = Message(nonce=make_nonce(1), id=id_b, payload=b"bob's msg")
g.extend(m_bob)
# Alice's second message references bob but not herself -> rejected
m2 = Message(
nonce=make_nonce(2), id=id_a,
digests=(m_bob.compute_digest(),),
payload=b"broken chain"
)
v = g.extend(m2)
assert v is None
class TestCausality:
def _build_chain(self):
"""Build a simple 3-message chain: m1 <- m2 <- m3."""
g = make_graph()
id_a = make_id("alice")
m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"m1")
v1 = g.extend(m1)
m2 = Message(nonce=make_nonce(1), id=id_a,
digests=(m1.compute_digest(),), payload=b"m2")
v2 = g.extend(m2)
m3 = Message(nonce=make_nonce(2), id=id_a,
digests=(m2.compute_digest(),), payload=b"m3")
v3 = g.extend(m3)
return g, v1, v2, v3
def test_direct_causes(self):
g, v1, v2, v3 = self._build_chain()
causes_of_v3 = g.direct_causes(v3)
assert v2 in causes_of_v3
assert v1 not in causes_of_v3
def test_direct_effects(self):
g, v1, v2, v3 = self._build_chain()
effects_of_v1 = g.direct_effects(v1)
assert v2 in effects_of_v1
assert v3 not in effects_of_v1 # v3 is indirect
def test_past(self):
"""G_v: the past of v contains all its causes."""
g, v1, v2, v3 = self._build_chain()
past_of_v3 = g.past(v3)
assert v1 in past_of_v3
assert v2 in past_of_v3
assert v3 in past_of_v3 # reflexive
def test_future(self):
g, v1, v2, v3 = self._build_chain()
future_of_v1 = g.future(v1)
assert v2 in future_of_v1
assert v3 in future_of_v1
assert v1 in future_of_v1 # reflexive
def test_is_cause_of(self):
g, v1, v2, v3 = self._build_chain()
assert g.is_cause_of(v1, v3)
assert g.is_cause_of(v1, v2)
assert not g.is_cause_of(v3, v1)
def test_timelike(self):
g, v1, v2, v3 = self._build_chain()
assert g.are_timelike(v1, v3)
assert g.are_timelike(v3, v1)
def test_spacelike(self):
"""Two independent vertices are spacelike."""
g = make_graph()
m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a")
m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b")
va = g.extend(m_a)
vb = g.extend(m_b)
assert g.are_spacelike(va, vb)
assert not g.are_timelike(va, vb)
class TestInvarianceOfThePast:
"""Theorem 3.7: The past of equivalent vertices in two Lamport graphs
have the same cardinality."""
def test_past_invariance_simple(self):
"""Same message in two different graphs has same-size past."""
g1 = make_graph()
g2 = make_graph()
id_a = make_id("alice")
m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"genesis")
m2 = Message(nonce=make_nonce(1), id=id_a,
digests=(m1.compute_digest(),), payload=b"second")
# Add to both graphs
g1.extend(m1)
v1_in_g1 = g1.extend(m2)
g2.extend(m1)
v1_in_g2 = g2.extend(m2)
# Past should be the same size
assert len(g1.past(v1_in_g1)) == len(g2.past(v1_in_g2))
class TestMessageGeneration:
def test_generate_first_message(self):
g = make_graph()
msg = g.generate_message(make_id("alice"), b"hello")
v = g.extend(msg)
assert v is not None
assert v.payload == b"hello"
def test_generate_chain(self):
g = make_graph()
pid = make_id("alice")
m1 = g.generate_message(pid, b"first")
g.extend(m1)
m2 = g.generate_message(pid, b"second")
v2 = g.extend(m2)
assert v2 is not None
# Second message should reference the first
assert m1.compute_digest() in m2.digests
def test_generate_cross_references(self):
"""Messages should reference vertices from other process ids."""
g = make_graph()
pid_a = make_id("alice")
pid_b = make_id("bob")
m_a = g.generate_message(pid_a, b"alice's msg")
g.extend(m_a)
m_b = g.generate_message(pid_b, b"bob's msg")
g.extend(m_b)
# Alice's second message should reference bob's message
m_a2 = g.generate_message(pid_a, b"alice second")
assert m_b.compute_digest() in m_a2.digests or m_a.compute_digest() in m_a2.digests

125
tests/test_message.py Normal file
View file

@ -0,0 +1,125 @@
"""Tests for the message and vertex data structures."""
import pytest
from crisis.crypto import digest, DIGEST_LENGTH
from crisis.message import (
Message, Vertex, Vote,
NONCE_LENGTH, ID_LENGTH, NUM_DIGESTS_LENGTH,
EMPTY_MESSAGE_DIGEST,
)
def make_id(name: str) -> bytes:
return digest(name.encode())[:ID_LENGTH]
def make_nonce(n: int = 0) -> bytes:
return n.to_bytes(NONCE_LENGTH, "big")
class TestMessage:
def test_create_minimal_message(self):
msg = Message(nonce=make_nonce(), id=make_id("test"), digests=(), payload=b"")
assert msg.num_digests == 0
def test_nonce_length_validation(self):
with pytest.raises(ValueError, match="nonce"):
Message(nonce=b"\x00", id=make_id("x"))
def test_id_length_validation(self):
with pytest.raises(ValueError, match="id"):
Message(nonce=make_nonce(), id=b"\x00")
def test_digest_length_validation(self):
with pytest.raises(ValueError, match="digest"):
Message(nonce=make_nonce(), id=make_id("x"),
digests=(b"\x00",))
def test_serialize_roundtrip_deterministic(self):
msg = Message(nonce=make_nonce(42), id=make_id("proc1"),
digests=(), payload=b"hello world")
serialized = msg.serialize()
assert isinstance(serialized, bytes)
# Same message serializes the same way
assert msg.serialize() == serialized
def test_compute_digest_deterministic(self):
msg = Message(nonce=make_nonce(), id=make_id("test"), payload=b"data")
d1 = msg.compute_digest()
d2 = msg.compute_digest()
assert d1 == d2
assert len(d1) == DIGEST_LENGTH
def test_different_messages_different_digests(self):
m1 = Message(nonce=make_nonce(1), id=make_id("a"), payload=b"x")
m2 = Message(nonce=make_nonce(2), id=make_id("a"), payload=b"x")
assert m1.compute_digest() != m2.compute_digest()
def test_message_with_digests(self):
parent = Message(nonce=make_nonce(), id=make_id("a"), payload=b"parent")
child = Message(
nonce=make_nonce(1), id=make_id("a"),
digests=(parent.compute_digest(),),
payload=b"child"
)
assert child.num_digests == 1
assert child.digests[0] == parent.compute_digest()
def test_message_is_immutable(self):
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"y")
with pytest.raises(AttributeError):
msg.nonce = b"\x00" * NONCE_LENGTH
class TestVertex:
def test_vertex_wraps_message(self):
msg = Message(nonce=make_nonce(), id=make_id("proc"), payload=b"data")
v = Vertex(m=msg)
assert v.nonce == msg.nonce
assert v.id == msg.id
assert v.payload == msg.payload
assert v.digests == msg.digests
def test_vertex_default_state(self):
msg = Message(nonce=make_nonce(), id=make_id("x"))
v = Vertex(m=msg)
assert v.round is None
assert v.is_last is None
assert v.svp == []
assert v.vote == {}
assert v.total_position is None
def test_vertex_equivalence(self):
"""Definition 3.3: two vertices are equivalent if v.m = v_hat.m"""
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"same")
v1 = Vertex(m=msg)
v2 = Vertex(m=msg)
assert v1.equivalent_to(v2)
assert v1 == v2
assert hash(v1) == hash(v2)
def test_vertex_non_equivalence(self):
m1 = Message(nonce=make_nonce(1), id=make_id("x"))
m2 = Message(nonce=make_nonce(2), id=make_id("x"))
v1 = Vertex(m=m1)
v2 = Vertex(m=m2)
assert not v1.equivalent_to(v2)
assert v1 != v2
class TestVote:
def test_vote_undecided(self):
v = Vote(message=None, binary=None)
assert "" in repr(v)
assert "" in repr(v)
def test_vote_with_message(self):
msg = Message(nonce=make_nonce(), id=make_id("x"))
v = Vote(message=msg, binary=1)
assert v.binary == 1
def test_empty_message_digest(self):
assert EMPTY_MESSAGE_DIGEST == digest(b"")

126
tests/test_order.py Normal file
View file

@ -0,0 +1,126 @@
"""Tests for total order computation."""
from crisis.crypto import digest
from crisis.graph import LamportGraph
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
from crisis.order import LeaderStream, compute_order, _kahns_total_order
from crisis.weight import ProofOfWorkWeight
def make_id(name: str) -> bytes:
return digest(name.encode())[:ID_LENGTH]
def make_nonce(n: int = 0) -> bytes:
return n.to_bytes(NONCE_LENGTH, "big")
def make_graph() -> LamportGraph:
return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=0))
class TestLeaderStream:
def test_empty_stream(self):
ls = LeaderStream()
assert ls.max_round() == -1
assert ls.get_leader(0) is None
def test_add_leader(self):
ls = LeaderStream()
msg = Message(nonce=make_nonce(), id=make_id("leader"), payload=b"L")
updated = ls.update(0, 1, msg)
assert updated is True
assert ls.get_leader(0) is msg
def test_higher_deciding_round_replaces(self):
ls = LeaderStream()
m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"old")
m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"new")
ls.update(0, 1, m1)
ls.update(0, 2, m2)
assert ls.get_leader(0) is m2
def test_lower_deciding_round_rejected(self):
ls = LeaderStream()
m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"first")
m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"late")
ls.update(0, 5, m1)
updated = ls.update(0, 3, m2)
assert updated is False
assert ls.get_leader(0) is m1
def test_all_leaders_sorted(self):
ls = LeaderStream()
m0 = Message(nonce=make_nonce(0), id=make_id("l0"), payload=b"r0")
m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"r1")
m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"r2")
ls.update(2, 3, m2)
ls.update(0, 1, m0)
ls.update(1, 2, m1)
leaders = ls.all_leaders()
rounds = [r for r, _ in leaders]
assert rounds == sorted(rounds)
class TestKahnsAlgorithm:
def test_empty_input(self):
g = make_graph()
result = _kahns_total_order([], g)
assert result == []
def test_single_vertex(self):
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"only")
v = g.extend(msg)
result = _kahns_total_order([v], g)
assert result == [v]
def test_chain_order(self):
"""A chain should be ordered causes-first."""
g = make_graph()
pid = make_id("alice")
m1 = Message(nonce=make_nonce(0), id=pid, payload=b"first")
v1 = g.extend(m1)
m2 = Message(nonce=make_nonce(1), id=pid,
digests=(m1.compute_digest(),), payload=b"second")
v2 = g.extend(m2)
m3 = Message(nonce=make_nonce(2), id=pid,
digests=(m2.compute_digest(),), payload=b"third")
v3 = g.extend(m3)
result = _kahns_total_order([v1, v2, v3], g)
# Causes come first
assert result.index(v1) < result.index(v2)
assert result.index(v2) < result.index(v3)
def test_respects_causality(self):
"""Total order must be consistent with causal order."""
g = make_graph()
m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a")
va = g.extend(m_a)
m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b")
vb = g.extend(m_b)
# Carol references both alice and bob
m_c = Message(
nonce=make_nonce(1), id=make_id("carol"),
digests=(m_a.compute_digest(), m_b.compute_digest()),
payload=b"c"
)
vc = g.extend(m_c)
result = _kahns_total_order([va, vb, vc], g)
# Carol must come after both alice and bob
assert result.index(va) < result.index(vc)
assert result.index(vb) < result.index(vc)

110
tests/test_rounds.py Normal file
View file

@ -0,0 +1,110 @@
"""Tests for virtual synchronous rounds."""
from crisis.crypto import digest
from crisis.graph import LamportGraph
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round, vertices_in_round
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
def make_id(name: str) -> bytes:
return digest(name.encode())[:ID_LENGTH]
def make_nonce(n: int = 0) -> bytes:
return n.to_bytes(NONCE_LENGTH, "big")
def make_graph() -> LamportGraph:
return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=0))
class TestRoundComputation:
def test_single_vertex_round_zero(self):
"""A single vertex with no causes is in round 0."""
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis")
v = g.extend(msg)
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
assert v.round == 0
def test_single_vertex_is_last(self):
"""Round 0 vertices are always 'last' (bootstrapping)."""
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("alice"))
v = g.extend(msg)
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
assert v.is_last is True
def test_chain_grows_rounds(self):
"""A chain of messages should produce increasing round numbers."""
g = make_graph()
pid = make_id("alice")
difficulty = DifficultyOracle(constant_difficulty=0) # Low difficulty
# Create a chain
prev_msg = None
vertices = []
for i in range(5):
digests = (prev_msg.compute_digest(),) if prev_msg else ()
msg = Message(nonce=make_nonce(i), id=pid, digests=digests, payload=f"msg{i}".encode())
v = g.extend(msg)
vertices.append(v)
prev_msg = msg
compute_rounds(g, difficulty, connectivity_k=0)
# All should have round numbers assigned
for v in vertices:
assert v.round is not None
# First vertex is round 0
assert vertices[0].round == 0
def test_max_round_empty_graph(self):
g = make_graph()
assert max_round(g) == 0
def test_max_round_with_vertices(self):
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("x"))
g.extend(msg)
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
assert max_round(g) == 0
def test_last_vertices_in_round(self):
g = make_graph()
msg = Message(nonce=make_nonce(), id=make_id("alice"))
g.extend(msg)
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
lasts = last_vertices_in_round(g, 0)
assert len(lasts) == 1
def test_multiple_ids_same_round(self):
"""Multiple independent vertices are all in round 0."""
g = make_graph()
for name in ["alice", "bob", "carol"]:
msg = Message(nonce=make_nonce(), id=make_id(name), payload=name.encode())
g.extend(msg)
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
r0 = vertices_in_round(g, 0)
assert len(r0) == 3
def test_round_invariance(self):
"""Proposition 5.3: equivalent vertices in different graphs have same round."""
g1 = make_graph()
g2 = make_graph()
difficulty = DifficultyOracle(constant_difficulty=1)
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis")
v1 = g1.extend(msg)
v2 = g2.extend(msg)
compute_rounds(g1, difficulty)
compute_rounds(g2, difficulty)
assert v1.round == v2.round
assert v1.is_last == v2.is_last

55
tests/test_simulation.py Normal file
View file

@ -0,0 +1,55 @@
"""Integration test: run the full simulation and verify basic properties."""
from crisis.demo import Simulation
class TestSimulation:
def test_simulation_runs(self):
"""The simulation should complete without errors."""
sim = Simulation(num_honest=3, num_byzantine=0, seed=42)
results = sim.run(num_steps=5, verbose=False)
assert len(results) == 5
def test_graphs_grow(self):
"""Each step should add messages to the graphs."""
sim = Simulation(num_honest=2, seed=42)
sim.run(num_steps=3, verbose=False)
for node in sim.nodes:
assert node.graph.vertex_count() > 0
def test_honest_nodes_same_graph_size(self):
"""All honest nodes should have the same number of vertices
(since all messages are delivered to all nodes)."""
sim = Simulation(num_honest=3, seed=42)
sim.run(num_steps=5, verbose=False)
sizes = [n.graph.vertex_count() for n in sim.nodes]
assert all(s == sizes[0] for s in sizes)
def test_rounds_are_computed(self):
"""After running, vertices should have round numbers."""
sim = Simulation(num_honest=3, seed=42)
sim.run(num_steps=5, verbose=False)
for node in sim.nodes:
for v in node.graph.all_vertices():
assert v.round is not None
def test_with_byzantine_node(self):
"""Simulation should handle byzantine nodes without crashing."""
sim = Simulation(num_honest=3, num_byzantine=1, seed=42)
results = sim.run(num_steps=5, verbose=False)
assert len(results) == 5
def test_deterministic_with_seed(self):
"""Same seed should produce the same results."""
sim1 = Simulation(num_honest=3, seed=123)
r1 = sim1.run(num_steps=3, verbose=False)
sim2 = Simulation(num_honest=3, seed=123)
r2 = sim2.run(num_steps=3, verbose=False)
# Same number of messages at each step
for s1, s2 in zip(r1, r2):
assert len(s1["new_messages"]) == len(s2["new_messages"])
for ns1, ns2 in zip(s1["node_states"], s2["node_states"]):
assert ns1["vertices"] == ns2["vertices"]

78
tests/test_weight.py Normal file
View file

@ -0,0 +1,78 @@
"""Tests for the weight system and difficulty oracle."""
from crisis.crypto import digest
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
def make_id(name: str) -> bytes:
return digest(name.encode())[:ID_LENGTH]
def make_nonce(n: int = 0) -> bytes:
return n.to_bytes(NONCE_LENGTH, "big")
class TestProofOfWorkWeight:
def test_weight_is_non_negative(self):
ws = ProofOfWorkWeight(min_leading_zeros=0)
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"test")
assert ws.weight(msg) >= 0
def test_weight_sum_is_additive(self):
ws = ProofOfWorkWeight()
assert ws.weight_sum(3, 5) == 8
assert ws.weight_sum(0, 0) == 0
def test_threshold(self):
ws = ProofOfWorkWeight(min_leading_zeros=2)
assert ws.threshold == 2
def test_is_valid_weight_with_zero_threshold(self):
ws = ProofOfWorkWeight(min_leading_zeros=0)
msg = Message(nonce=make_nonce(), id=make_id("x"))
assert ws.is_valid_weight(msg) # Everything passes with 0
def test_mine_nonce_finds_valid_message(self):
ws = ProofOfWorkWeight(min_leading_zeros=1)
msg = ws.mine_nonce(
id_bytes=make_id("miner"),
digests=(),
payload=b"test payload",
target_weight=1
)
assert ws.weight(msg) >= 1
assert ws.is_valid_weight(msg)
def test_different_nonces_different_weights(self):
"""Uniqueness property: different messages have different weights (w.h.p.)."""
ws = ProofOfWorkWeight()
weights = set()
for i in range(20):
msg = Message(nonce=make_nonce(i), id=make_id("x"), payload=b"same")
weights.add(ws.weight(msg))
# Not all the same (with overwhelming probability)
assert len(weights) > 1
def test_tamper_proof(self):
"""Changing a message should change its weight (w.h.p.)."""
ws = ProofOfWorkWeight()
msg1 = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"original")
msg2 = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"tampered")
# Weights differ because digests differ
# (this is probabilistic, but extremely likely)
assert msg1.compute_digest() != msg2.compute_digest()
class TestDifficultyOracle:
def test_constant_difficulty(self):
d = DifficultyOracle(constant_difficulty=5)
assert d.difficulty(0) == 5
assert d.difficulty(100) == 5
assert d.difficulty(999) == 5
def test_default_difficulty(self):
d = DifficultyOracle()
assert d.difficulty(0) == 4