mirror of
https://github.com/saymrwulf/crisis.git
synced 2026-05-14 20:37:54 +00:00
Initial implementation of the Crisis protocol (Richter, 2019)
Complete Python PoC of "Probabilistically Self Organizing Total Order in Unstructured P2P Networks". Implements all 10 algorithms from the paper: message generation, integrity checks, Lamport graphs, virtual synchronous rounds, safe voting patterns, virtual leader election (BA*), longest chain rule, total order via Kahn's algorithm, and push/pull gossip. Includes simulation harness, full node binary, and 72 passing tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
1df4790fb4
22 changed files with 3987 additions and 0 deletions
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environment
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
BIN
Crisis.mirco-richter-2019.pdf
Normal file
BIN
Crisis.mirco-richter-2019.pdf
Normal file
Binary file not shown.
30
pyproject.toml
Normal file
30
pyproject.toml
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
[project]
|
||||||
|
name = "crisis"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
license = "CC-BY-4.0"
|
||||||
|
authors = [
|
||||||
|
{ name = "Mirco Richter (paper)", email = "mirco.richter@mailbox.org" },
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"networkx>=3.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0",
|
||||||
|
"rich>=13.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
crisis-node = "crisis.node:main"
|
||||||
|
crisis-demo = "crisis.demo:main"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
21
src/crisis/__init__.py
Normal file
21
src/crisis/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""
|
||||||
|
Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks
|
||||||
|
|
||||||
|
A Python implementation of the Crisis protocol described by Mirco Richter (2019).
|
||||||
|
|
||||||
|
The protocol achieves total order on messages in fully open, unstructured
|
||||||
|
Peer-to-Peer networks through virtual voting -- votes are never sent explicitly
|
||||||
|
but are deduced from the causal relationships between messages encoded in
|
||||||
|
Lamport graphs.
|
||||||
|
|
||||||
|
Key components:
|
||||||
|
- crypto: Random oracle model (SHA-256 hash function)
|
||||||
|
- message: Message and Vertex data structures
|
||||||
|
- weight: Weight systems (PoW-based Sybil resistance)
|
||||||
|
- graph: Lamport graphs with integrity checking
|
||||||
|
- rounds: Virtual synchronous rounds
|
||||||
|
- voting: Safe voting patterns and virtual leader election (BA*)
|
||||||
|
- order: Total order via leader stream and topological sorting
|
||||||
|
- gossip: Push/pull gossip for member discovery and message dissemination
|
||||||
|
- node: Full Crisis node tying all components together
|
||||||
|
"""
|
||||||
104
src/crisis/crypto.py
Normal file
104
src/crisis/crypto.py
Normal file
|
|
@ -0,0 +1,104 @@
|
||||||
|
"""
|
||||||
|
Random Oracle Model (Section 2.1)
|
||||||
|
|
||||||
|
We work in the random oracle model, assuming the existence of a cryptographic
|
||||||
|
hash function that behaves like a random oracle:
|
||||||
|
|
||||||
|
H : {0,1}* -> {0,1}^p (Eq. 1)
|
||||||
|
|
||||||
|
We use SHA-256 as our concrete instantiation. H is assumed to be collision-,
|
||||||
|
preimage-, and second-preimage-resistant.
|
||||||
|
|
||||||
|
We call H(b) the *digest* of the binary string b.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
# The digest length in bytes (SHA-256 produces 32 bytes = 256 bits).
|
||||||
|
DIGEST_LENGTH = 32
|
||||||
|
|
||||||
|
|
||||||
|
def digest(data: Union[bytes, bytearray]) -> bytes:
    """Apply the random oracle H (SHA-256) to *data*.

    This is the single concrete instantiation of H used throughout the
    protocol; every mention of "the digest of" a message or byte string
    in the paper maps here.

    Returns:
        The 32-byte (256-bit) SHA-256 digest.
    """
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def digest_hex(data: Union[bytes, bytearray]) -> str:
    """Hex-string rendering of H(data), intended for display/logging."""
    raw = digest(data)
    return raw.hex()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_digest(data: bytes, expected: bytes) -> bool:
    """Return True iff H(data) equals the *expected* digest."""
    actual = digest(data)
    return actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Least significant bit helper (used in the virtual coin flip, Algorithm 7)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def least_significant_bit(h: bytes) -> int:
    """Least significant bit of a hash value.

    Algorithm 7 (virtual leader election) uses this as the "genuine coin
    flip" stage: the LSB of H(v_hat.m) determines the binary vote.

    Paper definition:
        b_coin := lsb(H(x.m)) for max weight x in S
    """
    final_byte = h[-1]
    return final_byte % 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Proof-of-Work helpers (used by the weight system, Section 3.1.1)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def count_leading_zero_bits(h: bytes) -> int:
    """Count the number of leading zero bits in a hash value.

    This is the standard measure of proof-of-work difficulty: a hash with
    k leading zero bits required roughly 2^k hash evaluations to find.

    Args:
        h: The hash value to inspect (any length, including empty).

    Returns:
        The count of zero bits before the first set bit; ``8 * len(h)``
        when every byte is zero (or *h* is empty).

    NOTE(review): the original body contained a dead no-op line
    (``count += (byte ^ 0xFF).bit_length() - (255 - byte).bit_length()``
    always adds 0, since ``byte ^ 0xFF == 255 - byte`` for a byte), and
    this function is redefined later in the module, which shadows this
    copy entirely — one of the two definitions should be removed.
    """
    total_bits = 8 * len(h)
    value = int.from_bytes(h, "big")
    if value == 0:
        # All-zero (or empty) input: every bit is a leading zero.
        return total_bits
    # bit_length() gives the position of the highest set bit.
    return total_bits - value.bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
def count_leading_zero_bits(h: bytes) -> int:
    """Number of leading zero bits in the hash value *h*.

    A hash with k leading zero bits required roughly 2^k evaluations to
    find; the PoW weight function uses this count to assign weight to
    messages.
    """
    width = 8 * len(h)
    as_int = int.from_bytes(h, byteorder="big")
    if as_int == 0:
        # Empty or all-zero input: every bit counts as a leading zero.
        return width
    return width - as_int.bit_length()
|
||||||
356
src/crisis/demo.py
Normal file
356
src/crisis/demo.py
Normal file
|
|
@ -0,0 +1,356 @@
|
||||||
|
"""
|
||||||
|
Demonstration / Simulation Harness
|
||||||
|
|
||||||
|
This module provides a deterministic, single-process simulation of the Crisis
|
||||||
|
protocol with N virtual nodes. It is designed as the foundation for a lecture
|
||||||
|
series: each phase of the protocol can be observed step by step.
|
||||||
|
|
||||||
|
The simulation bypasses the network layer entirely -- messages are delivered
|
||||||
|
directly between in-memory Lamport graphs. This makes the consensus mechanism
|
||||||
|
visible without network noise.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m crisis.demo # Run the full demo
|
||||||
|
python -m crisis.demo --nodes 5 # 5 honest nodes
|
||||||
|
python -m crisis.demo --byzantine 1 # 1 byzantine node
|
||||||
|
    python -m crisis.demo --steps 10      # Run for 10 simulation steps
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from crisis.crypto import digest
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||||
|
from crisis.order import LeaderStream, compute_order
|
||||||
|
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round
|
||||||
|
from crisis.voting import compute_safe_voting_pattern, compute_virtual_leader_election
|
||||||
|
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Simulated Node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class SimulatedNode:
    """An in-memory Crisis node used by the simulation.

    Every node owns its own Lamport graph and virtual process id; peers
    exchange Message objects directly, so no serialization is involved.
    """
    name: str
    process_id: bytes
    graph: LamportGraph
    leader_stream: LeaderStream = field(default_factory=LeaderStream)
    is_byzantine: bool = False
    messages_created: int = 0

    def generate_message(self, payload: str) -> Message:
        """Create and record a fresh message originating at this node."""
        self.messages_created += 1
        encoded = payload.encode()
        return self.graph.generate_message(self.process_id, encoded)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Simulation Engine
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class Simulation:
    """Deterministic simulation of N Crisis nodes.

    Runs the protocol in lock-step rounds:
      1. Each node generates a message
      2. Messages are gossiped (delivered to all nodes)
      3. Consensus is computed on each node
      4. State is displayed

    This allows observing how the Lamport graph grows, rounds emerge,
    and total order converges.
    """

    def __init__(self, num_honest: int = 3, num_byzantine: int = 0,
                 pow_zeros: int = 0, difficulty: int = 2,
                 connectivity_k: int = 1, seed: int = 42):
        self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty)
        self.connectivity_k = connectivity_k
        self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros)
        self.seed = seed
        # Seeds the module-level RNG so the byzantine nodes' random
        # choices (see _byzantine_message) are reproducible per seed.
        random.seed(seed)

        # Create nodes: each gets a deterministic process id derived from
        # its name, plus a private Lamport graph sharing one weight system.
        self.nodes: list[SimulatedNode] = []
        for i in range(num_honest):
            name = f"honest-{i}"
            pid = digest(name.encode())[:ID_LENGTH]
            graph = LamportGraph(weight_system=self.weight_system)
            self.nodes.append(SimulatedNode(
                name=name, process_id=pid, graph=graph
            ))

        for i in range(num_byzantine):
            name = f"byzantine-{i}"
            pid = digest(name.encode())[:ID_LENGTH]
            graph = LamportGraph(weight_system=self.weight_system)
            self.nodes.append(SimulatedNode(
                name=name, process_id=pid, graph=graph, is_byzantine=True
            ))

        self.step_count = 0
        # Global history of every message ever produced; re-delivered
        # each step to emulate pull gossip.
        self.all_messages: list[Message] = []

    def step(self) -> dict:
        """Execute one simulation step.

        Returns a dict with step results for display.
        """
        self.step_count += 1
        step_results = {
            "step": self.step_count,
            "new_messages": [],
            "node_states": [],
        }

        # Phase 1: Each node generates a message
        new_messages: list[tuple[SimulatedNode, Message]] = []
        for node in self.nodes:
            if node.is_byzantine:
                msg = self._byzantine_message(node)
            else:
                payload = f"step-{self.step_count}-{node.name}"
                msg = node.generate_message(payload)

            if msg is not None:
                new_messages.append((node, msg))
                step_results["new_messages"].append({
                    "from": node.name,
                    "digest": msg.compute_digest().hex()[:12],
                    "weight": self.weight_system.weight(msg),
                    "payload": msg.payload.decode(errors="replace"),
                })

        # Phase 2: Gossip -- deliver all messages to all nodes
        for source_node, msg in new_messages:
            self.all_messages.append(msg)
            for target_node in self.nodes:
                # Deliver to all nodes (including source, for consistency)
                target_node.graph.extend(msg)

        # Also re-deliver older messages that nodes might be missing
        # (simulates pull gossip catching up)
        # NOTE(review): re-delivering the full history every step grows
        # quadratically with step count -- acceptable for a demo harness.
        for msg in self.all_messages:
            for node in self.nodes:
                node.graph.extend(msg)  # extend() is idempotent (integrity check)

        # Phase 3: Compute consensus on each node
        for node in self.nodes:
            # Algorithm pipeline: rounds -> safe voting patterns on
            # round-final vertices -> virtual leader election -> order.
            compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k)

            for vertex in node.graph.all_vertices():
                if vertex.is_last:
                    compute_safe_voting_pattern(
                        vertex, node.graph, self.difficulty_oracle,
                        self.connectivity_k
                    )

            # round -> [(deciding_round, leader message), ...]
            leader_dict: dict[int, list[tuple[int, Message]]] = {}
            for vertex in node.graph.all_vertices():
                if vertex.svp:
                    compute_virtual_leader_election(
                        vertex, node.graph, self.difficulty_oracle,
                        self.connectivity_k, leader_dict
                    )

            for round_num, entries in leader_dict.items():
                for deciding_round, leader_msg in entries:
                    node.leader_stream.update(round_num, deciding_round, leader_msg)

            ordered = compute_order(node.graph, node.leader_stream)

            mr = max_round(node.graph)
            step_results["node_states"].append({
                "name": node.name,
                "vertices": node.graph.vertex_count(),
                "max_round": mr,
                "leaders": len(node.leader_stream.leaders),
                "ordered": len(ordered),
                "is_byzantine": node.is_byzantine,
            })

        return step_results

    def _byzantine_message(self, node: SimulatedNode) -> Optional[Message]:
        """Generate a byzantine message.

        Byzantine nodes can exhibit several faulty behaviors:
          - Mutations: same id, forking the causal chain
          - Strategic distribution: different messages to different peers
          - Time travel: referencing old rounds

        For this demo, we generate a message with a random payload that
        may not reference the latest same-id message (creating a mutation).
        """
        payload = f"byz-{self.step_count}-{node.name}-{random.randint(0, 999)}"

        # 50% chance of creating a mutation (not referencing last same-id vertex)
        if random.random() < 0.5 and node.graph.vertex_count() > 0:
            # Pick random digests instead of following the chain
            available = list(node.graph.vertices.keys())
            num_refs = min(random.randint(1, 3), len(available))
            digests = tuple(random.sample(available, num_refs))
            nonce = os.urandom(NONCE_LENGTH)
            # Hand-built Message bypasses generate_message's chain rules.
            return Message(
                nonce=nonce, id=node.process_id,
                digests=digests, payload=payload.encode()
            )
        else:
            return node.generate_message(payload)

    def run(self, num_steps: int = 10, verbose: bool = True) -> list[dict]:
        """Run the simulation for a number of steps."""
        results = []
        for _ in range(num_steps):
            result = self.step()
            results.append(result)
            if verbose:
                _print_step(result)

        if verbose:
            _print_convergence_summary(self)

        return results
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Display functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _print_step(result: dict) -> None:
|
||||||
|
"""Print the results of a simulation step."""
|
||||||
|
print(f"\n{'='*70}")
|
||||||
|
print(f" Step {result['step']}")
|
||||||
|
print(f"{'='*70}")
|
||||||
|
|
||||||
|
print(f"\n New messages:")
|
||||||
|
for msg in result["new_messages"]:
|
||||||
|
print(f" {msg['from']:>15s} -> {msg['digest']} "
|
||||||
|
f"w={msg['weight']} {msg['payload'][:40]}")
|
||||||
|
|
||||||
|
print(f"\n Node states:")
|
||||||
|
print(f" {'Name':>15s} {'Vertices':>8s} {'Round':>5s} "
|
||||||
|
f"{'Leaders':>7s} {'Ordered':>7s}")
|
||||||
|
print(f" {'-'*15} {'-'*8} {'-'*5} {'-'*7} {'-'*7}")
|
||||||
|
for ns in result["node_states"]:
|
||||||
|
byz = " [BYZ]" if ns["is_byzantine"] else ""
|
||||||
|
print(f" {ns['name']:>15s} {ns['vertices']:>8d} "
|
||||||
|
f"{ns['max_round']:>5d} {ns['leaders']:>7d} "
|
||||||
|
f"{ns['ordered']:>7d}{byz}")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_convergence_summary(sim: Simulation) -> None:
    """Print a summary showing whether honest nodes have converged."""
    print(f"\n{'='*70}")
    print(f" Convergence Summary")
    print(f"{'='*70}")

    # Byzantine nodes are excluded: only honest replicas are expected
    # to agree on the total order.
    honest_nodes = [n for n in sim.nodes if not n.is_byzantine]

    # Check if all honest nodes have the same total order
    # (compared via truncated hex digests of the ordered vertices).
    orders = []
    for node in honest_nodes:
        ordered = compute_order(node.graph, node.leader_stream)
        order_digests = [v.message_digest.hex()[:12] for v in ordered]
        orders.append(order_digests)

    if len(orders) >= 2:
        # Compare pairwise
        all_agree = all(o == orders[0] for o in orders[1:])
        if all_agree:
            print(f"\n All {len(honest_nodes)} honest nodes AGREE "
                  f"on total order ({len(orders[0])} messages)")
        else:
            print(f"\n Honest nodes have DIVERGENT total orders "
                  f"(convergence in progress)")
            # NOTE(review): the enumerate index `i` is unused; plain
            # zip(honest_nodes, orders) would suffice.
            for i, (node, order) in enumerate(zip(honest_nodes, orders)):
                print(f" {node.name}: {len(order)} ordered messages")

    # Show the total order from the first honest node
    if orders and orders[0]:
        print(f"\n Total order (from {honest_nodes[0].name}):")
        first_node = honest_nodes[0]
        ordered = compute_order(first_node.graph, first_node.leader_stream)
        for v in ordered[:20]:  # Show first 20
            print(f" pos={v.total_position:>3d} "
                  f"hash={v.message_digest.hex()[:12]} "
                  f"r={v.round} "
                  f"payload={v.payload.decode(errors='replace')[:40]}")
        if len(ordered) > 20:
            print(f" ... and {len(ordered) - 20} more")

    # Show leader stream
    if honest_nodes:
        ls = honest_nodes[0].leader_stream
        if ls.leaders:
            print(f"\n Leader stream ({honest_nodes[0].name}):")
            for round_num, (dec_round, msg) in sorted(ls.leaders.items()):
                print(f" round {round_num}: leader={msg.compute_digest().hex()[:12]} "
                      f"decided_in_round={dec_round}")

    print()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse options, announce them, run the simulation."""
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Crisis Protocol Simulation",
        epilog="Demonstrates probabilistic total order convergence"
    )
    # All options are integers; declare them table-driven.
    for flag, default, help_text in (
        ("--nodes", 3, "Number of honest nodes (default: 3)"),
        ("--byzantine", 0, "Number of byzantine nodes (default: 0)"),
        ("--steps", 10, "Number of simulation steps (default: 10)"),
        ("--pow-zeros", 0, "Min PoW leading zeros (default: 0 = no PoW)"),
        ("--difficulty", 2, "Difficulty oracle constant (default: 2)"),
        ("--seed", 42, "Random seed for reproducibility (default: 42)"),
    ):
        arg_parser.add_argument(flag, type=int, default=default, help=help_text)

    opts = arg_parser.parse_args()

    print(f"Crisis Protocol Simulation")
    print(f" Honest nodes: {opts.nodes}")
    print(f" Byzantine nodes: {opts.byzantine}")
    print(f" Steps: {opts.steps}")
    print(f" PoW zeros: {opts.pow_zeros}")
    print(f" Difficulty: {opts.difficulty}")
    print(f" Seed: {opts.seed}")

    simulation = Simulation(
        num_honest=opts.nodes,
        num_byzantine=opts.byzantine,
        pow_zeros=opts.pow_zeros,
        difficulty=opts.difficulty,
        seed=opts.seed,
    )

    simulation.run(num_steps=opts.steps)


if __name__ == "__main__":
    main()
|
||||||
414
src/crisis/gossip.py
Normal file
414
src/crisis/gossip.py
Normal file
|
|
@ -0,0 +1,414 @@
|
||||||
|
"""
|
||||||
|
Communication (Section 4)
|
||||||
|
|
||||||
|
Crisis is built on top of two simple push & pull gossip protocols:
|
||||||
|
1. Member discovery gossip (Algorithm 3)
|
||||||
|
2. Message gossip (Algorithm 4)
|
||||||
|
|
||||||
|
These are well suited for communication in unstructured P2P networks.
|
||||||
|
All the system needs is a way to distribute messages in a byzantine-prone
|
||||||
|
environment.
|
||||||
|
|
||||||
|
4.3 Member Discovery Gossip (Algorithm 3):
|
||||||
|
Each process maintains a partial view Π_j(t) of the network.
|
||||||
|
Periodically, a process pushes its neighbor list to a random peer
|
||||||
|
and pulls neighbor lists from other peers.
|
||||||
|
|
||||||
|
4.4 Message Gossip (Algorithm 4):
|
||||||
|
Processes push unordered messages to random peers and pull missing
|
||||||
|
messages. Already ordered messages are pushed only as responses
|
||||||
|
to pull requests (stop criterion for push gossip).
|
||||||
|
|
||||||
|
This module implements both gossip protocols using asyncio for the
|
||||||
|
"run in parallel forever" loops described in the paper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import struct
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, Vertex, NONCE_LENGTH, ID_LENGTH
|
||||||
|
from crisis.crypto import DIGEST_LENGTH
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Peer identity and network view
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class PeerInfo:
    """A known peer: transport endpoint plus (optionally) its identity."""
    host: str
    port: int
    process_id: bytes = b""  # The peer's virtual process id, if known

    @property
    def address(self) -> tuple[str, int]:
        """The (host, port) pair used to reach this peer."""
        return self.host, self.port

    def __hash__(self):
        # Identity is the endpoint only; process_id is intentionally
        # excluded so a peer stays the same entry once its id is learned.
        return hash(self.address)

    def __eq__(self, other):
        if isinstance(other, PeerInfo):
            return (self.host, self.port) == (other.host, other.port)
        return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class NetworkView:
    """Π_j(t): a process's partial view of the network at time t.

    "No process must know the entire system and each j ∈ Π(t) might
    have a partial view Π_j(t) only." (Section 4.3)
    """
    peers: set[PeerInfo] = field(default_factory=set)
    max_peers: int = 50  # Limit to prevent unbounded growth

    def add_peer(self, peer: PeerInfo) -> None:
        """Record *peer*, unless the view is already at capacity."""
        if len(self.peers) >= self.max_peers:
            return
        self.peers.add(peer)

    def remove_peer(self, peer: PeerInfo) -> None:
        """Forget *peer*; a no-op when it is not in the view."""
        self.peers.discard(peer)

    def random_peer(self) -> Optional[PeerInfo]:
        """Pick one known peer uniformly, or None if the view is empty."""
        candidates = list(self.peers)
        if candidates:
            return random.choice(candidates)
        return None

    def random_subset(self, k: int) -> list[PeerInfo]:
        """Sample up to *k* distinct peers (fewer when the view is smaller)."""
        pool = list(self.peers)
        return random.sample(pool, min(k, len(pool)))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Message serialization for network transport
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def serialize_message(message: Message) -> bytes:
    """Frame a Message for the wire.

    Format: [total_length:4][nonce:8][id:32][num_digests:2][digests...][payload]
    (a 4-byte big-endian length prefix followed by the message body).
    """
    wire_body = message.serialize()
    prefix = struct.pack("!I", len(wire_body))
    return prefix + wire_body
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_message(data: bytes) -> Message:
    """Rebuild a Message from its wire body (length prefix already stripped).

    Walks the fixed-size fields in order -- nonce, process id, digest
    count, digest list -- and treats every remaining byte as payload.
    """
    pos = 0
    nonce = data[pos:pos + NONCE_LENGTH]
    pos += NONCE_LENGTH

    sender_id = data[pos:pos + ID_LENGTH]
    pos += ID_LENGTH

    digest_count = int.from_bytes(data[pos:pos + 2], "big")
    pos += 2

    refs = []
    for _ in range(digest_count):
        refs.append(data[pos:pos + DIGEST_LENGTH])
        pos += DIGEST_LENGTH

    return Message(
        nonce=nonce,
        id=sender_id,
        digests=tuple(refs),
        payload=data[pos:],
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Protocol message types
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Simple protocol: 1-byte type prefix
|
||||||
|
MSG_TYPE_PUSH_MESSAGE = b"\x01" # Push a crisis message
|
||||||
|
MSG_TYPE_PULL_REQUEST = b"\x02" # Request missing messages
|
||||||
|
MSG_TYPE_PULL_RESPONSE = b"\x03" # Response with requested messages
|
||||||
|
MSG_TYPE_PEER_PUSH = b"\x04" # Push peer list
|
||||||
|
MSG_TYPE_PEER_PULL = b"\x05" # Request peer list
|
||||||
|
MSG_TYPE_PEER_RESPONSE = b"\x06" # Response with peer list
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Gossip Server
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class GossipServer:
|
||||||
|
"""Asyncio-based gossip server implementing Algorithms 3 and 4.
|
||||||
|
|
||||||
|
Runs two parallel loops:
|
||||||
|
1. Member discovery push & pull (Algorithm 3)
|
||||||
|
2. Message push & pull (Algorithm 4)
|
||||||
|
|
||||||
|
Plus a listener that handles incoming connections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(self, host: str, port: int, graph: LamportGraph,
                 network_view: NetworkView,
                 push_interval: float = 2.0,
                 discovery_interval: float = 5.0):
        """Configure the gossip server; no sockets are opened until start().

        Args:
            host: Local address to listen on for incoming gossip.
            port: Local TCP port to listen on.
            graph: The node's Lamport graph; its unordered vertices are
                what the message push loop disseminates.
            network_view: Partial view Π_j(t) of known peers, used to pick
                random gossip targets.
            push_interval: Seconds between message push rounds (Algorithm 4).
            discovery_interval: Seconds between member discovery rounds
                (Algorithm 3).
        """
        self.host = host
        self.port = port
        self.graph = graph
        self.network_view = network_view
        self.push_interval = push_interval
        self.discovery_interval = discovery_interval
        # Created by start(); stays None until then.
        self._server: Optional[asyncio.Server] = None
        # Polled by both gossip loops; cleared by stop() to end them.
        self._running = False
|
||||||
|
|
||||||
|
    async def start(self) -> None:
        """Start the gossip server and all gossip loops.

        Opens the listening socket, then runs the listener and both gossip
        loops concurrently. Since gather() returns only when all three
        coroutines finish, this call effectively blocks for the server's
        lifetime.
        """
        self._running = True
        self._server = await asyncio.start_server(
            self._handle_connection, self.host, self.port
        )
        logger.info(f"Gossip server listening on {self.host}:{self.port}")

        # Run the gossip loops concurrently (paper: "run in parallel forever")
        await asyncio.gather(
            self._server.serve_forever(),
            self._discovery_push_loop(),
            self._message_push_loop(),
        )
|
||||||
|
|
||||||
|
    async def stop(self) -> None:
        """Stop the gossip server.

        Clears the running flag (each gossip loop notices it on its next
        wake-up and exits) and closes the listening socket, waiting until
        it is fully closed.
        """
        self._running = False
        if self._server:
            self._server.close()
            await self._server.wait_closed()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Algorithm 3: Member discovery push & pull
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
    async def _discovery_push_loop(self) -> None:
        """Algorithm 3, lines 1-5: periodically push peer list to random peers.

        Every discovery_interval seconds, picks one random peer from the
        view and does a push (our peer list) followed by a pull (theirs).
        """
        while self._running:
            await asyncio.sleep(self.discovery_interval)

            peer = self.network_view.random_peer()
            if peer is None:
                # View is empty -- nothing to gossip with; retry next tick.
                continue

            try:
                await self._send_peer_push(peer)
                await self._send_peer_pull(peer)
            except (ConnectionError, OSError) as e:
                logger.debug(f"Discovery push to {peer.address} failed: {e}")
                # Treat an unreachable peer as departed and drop it from
                # the view; discovery can re-learn it later.
                self.network_view.remove_peer(peer)
|
||||||
|
|
||||||
|
    async def _send_peer_push(self, peer: PeerInfo) -> None:
        """Push our peer list to a remote peer (push half of Algorithm 3)."""
        # Snapshot the current view and ship it with the PEER_PUSH tag.
        peer_data = self._encode_peer_list(list(self.network_view.peers))
        await self._send_to_peer(peer, MSG_TYPE_PEER_PUSH + peer_data)
|
||||||
|
|
||||||
|
async def _send_peer_pull(self, peer: PeerInfo) -> None:
    """Ask a remote peer for its peer list and merge the answer into our view."""
    response = await self._send_and_receive(peer, MSG_TYPE_PEER_PULL)
    if not response or response[0:1] != MSG_TYPE_PEER_RESPONSE:
        return
    for candidate in self._decode_peer_list(response[1:]):
        # Never add our own listen address to our own view.
        if candidate.host != self.host or candidate.port != self.port:
            self.network_view.add_peer(candidate)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Algorithm 4: Message gossip push & pull
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _message_push_loop(self) -> None:
    """Algorithm 4, lines 1-5: push unordered messages to random peers.

    "Messages are retransmitted via push gossip, only if they don't
    have a total order yet." (Section 4.4)
    """
    while self._running:
        await asyncio.sleep(self.push_interval)

        peer = self.network_view.random_peer()
        if peer is None:
            continue

        # Only vertices without a total_position are still gossiped.
        pending = [
            v for v in self.graph.all_vertices() if v.total_position is None
        ]
        if not pending:
            continue

        try:
            for vertex in pending:
                payload = serialize_message(vertex.m)
                await self._send_to_peer(peer, MSG_TYPE_PUSH_MESSAGE + payload)
        except (ConnectionError, OSError) as e:
            logger.debug(f"Message push to {peer.address} failed: {e}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Connection handler (incoming)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _handle_connection(self, reader: asyncio.StreamReader,
                             writer: asyncio.StreamWriter) -> None:
    """Handle an incoming gossip connection.

    Algorithm 3, lines 6-13 (peer data) and Algorithm 4, lines 6-13
    (message data).

    One request per connection: the first byte selects the branch, the
    remainder of the read is that branch's payload. Pull-type requests
    write a response back on the same connection; push-type requests
    get no reply.
    """
    try:
        # NOTE(review): a single read() may return only part of a large
        # request -- this assumes each request fits into one 64 KiB read
        # and arrives in one burst; confirm against the sender side.
        data = await asyncio.wait_for(reader.read(65536), timeout=10.0)
        if not data:
            return

        msg_type = data[0:1]
        payload = data[1:]

        if msg_type == MSG_TYPE_PUSH_MESSAGE:
            # Received a message: try to extend our Lamport graph
            self._handle_push_message(payload)

        elif msg_type == MSG_TYPE_PULL_REQUEST:
            # Someone wants messages: send what we have
            response = self._handle_pull_request(payload)
            writer.write(response)
            await writer.drain()

        elif msg_type == MSG_TYPE_PEER_PUSH:
            # Received a peer list: update our view
            new_peers = self._decode_peer_list(payload)
            for p in new_peers:
                # Never add our own listen address to the view.
                if p.host != self.host or p.port != self.port:
                    self.network_view.add_peer(p)

        elif msg_type == MSG_TYPE_PEER_PULL:
            # Someone wants our peer list
            response = MSG_TYPE_PEER_RESPONSE + self._encode_peer_list(
                list(self.network_view.peers)
            )
            writer.write(response)
            await writer.drain()

    except (asyncio.TimeoutError, ConnectionError):
        # Slow or dropped clients are silently discarded (best-effort gossip).
        pass
    finally:
        # NOTE(review): close() without await wait_closed(); also OSError
        # and parsing errors from malformed wire data (e.g. a bad peer
        # list) are not caught here and would surface in asyncio's
        # exception handler -- confirm whether that is intended.
        writer.close()
|
||||||
|
|
||||||
|
def _handle_push_message(self, data: bytes) -> Optional[Vertex]:
    """Process a pushed message: validate and extend the graph if valid.

    Algorithm 4, lines 7-8: "if MESSAGE_INTEGRITY(m, G) then
    expand G with vertex v, such that v.m = m"
    """
    try:
        # Payload starts with a 4-byte big-endian length prefix.
        if len(data) < 4:
            return None
        (length,) = struct.unpack("!I", data[:4])
        message = deserialize_message(data[4:4 + length])
        return self.graph.extend(message)
    except Exception as e:
        # Best-effort: anything malformed from the wire is just dropped.
        logger.debug(f"Failed to process pushed message: {e}")
        return None
|
||||||
|
|
||||||
|
def _handle_pull_request(self, data: bytes) -> bytes:
    """Respond to a pull request with messages the requester is missing.

    Algorithm 4, lines 10-11: "respond with appropriate set of messages"
    """
    # The request body is a flat concatenation of digests the requester
    # already has; trailing partial digests are ignored.
    known_digests = {
        data[i:i + DIGEST_LENGTH]
        for i in range(0, len(data) - DIGEST_LENGTH + 1, DIGEST_LENGTH)
    }

    # Serialize every message the requester does not know about.
    parts = [MSG_TYPE_PULL_RESPONSE]
    for d, vertex in self.graph.vertices.items():
        if d not in known_digests:
            parts.append(serialize_message(vertex.m))
    return b"".join(parts)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Network I/O helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _send_to_peer(self, peer: PeerInfo, data: bytes) -> None:
|
||||||
|
"""Send data to a peer (fire-and-forget)."""
|
||||||
|
reader, writer = await asyncio.open_connection(peer.host, peer.port)
|
||||||
|
writer.write(data)
|
||||||
|
await writer.drain()
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
async def _send_and_receive(self, peer: PeerInfo, data: bytes) -> Optional[bytes]:
|
||||||
|
"""Send data and wait for a response."""
|
||||||
|
try:
|
||||||
|
reader, writer = await asyncio.open_connection(peer.host, peer.port)
|
||||||
|
writer.write(data)
|
||||||
|
await writer.drain()
|
||||||
|
response = await asyncio.wait_for(reader.read(65536), timeout=5.0)
|
||||||
|
writer.close()
|
||||||
|
return response
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Peer list encoding
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encode_peer_list(peers: list[PeerInfo]) -> bytes:
|
||||||
|
"""Encode a list of peers as bytes: [count:2][host_len:1][host][port:2]..."""
|
||||||
|
parts = [struct.pack("!H", len(peers))]
|
||||||
|
for peer in peers:
|
||||||
|
host_bytes = peer.host.encode("utf-8")
|
||||||
|
parts.append(struct.pack("!B", len(host_bytes)))
|
||||||
|
parts.append(host_bytes)
|
||||||
|
parts.append(struct.pack("!H", peer.port))
|
||||||
|
return b"".join(parts)
|
||||||
|
|
||||||
|
@staticmethod
def _decode_peer_list(data: bytes) -> list[PeerInfo]:
    """Decode a peer list from bytes ([count:2][host_len:1][host][port:2]...).

    The input comes straight off the wire, so it is treated as untrusted:
    decoding stops at the first truncated or undecodable record instead
    of raising, and only fully-parsed peers are returned.

    Returns:
        The successfully decoded peers (possibly fewer than the count
        field claims, possibly empty).
    """
    if len(data) < 2:
        return []
    (count,) = struct.unpack("!H", data[:2])
    offset = 2
    peers: list[PeerInfo] = []
    for _ in range(count):
        # [host_len:1]
        if offset + 1 > len(data):
            break
        host_len = data[offset]
        offset += 1
        # [host:host_len][port:2] -- bug fix: the original only bounds-
        # checked before reading host_len, so truncated input produced a
        # short host slice and then raised struct.error on the port read,
        # letting a malformed peer list crash the connection handler.
        if offset + host_len + 2 > len(data):
            break
        try:
            host = data[offset:offset + host_len].decode("utf-8")
        except UnicodeDecodeError:
            # Bug fix: invalid UTF-8 in a host field previously propagated.
            break
        offset += host_len
        (port,) = struct.unpack("!H", data[offset:offset + 2])
        offset += 2
        peers.append(PeerInfo(host=host, port=port))
    return peers
|
||||||
479
src/crisis/graph.py
Normal file
479
src/crisis/graph.py
Normal file
|
|
@ -0,0 +1,479 @@
|
||||||
|
"""
|
||||||
|
Lamport Graphs (Section 3.2)
|
||||||
|
|
||||||
|
Lamport graphs represent the causal partial order between messages as a
|
||||||
|
directed acyclic graph. They are the central data structure of the Crisis
|
||||||
|
protocol -- all consensus state is derived from the graph structure.
|
||||||
|
|
||||||
|
Definition 3.5 (Lamport Graph):
|
||||||
|
Let V ⊂ VERTEX be a finite set of vertices, such that all vertices
|
||||||
|
v_hat with v_hat ≤ v for all v ∈ V are in V, but no two vertices in V
|
||||||
|
are equivalent. Then the graph G = (V, A) with (v, v_hat) ∈ A if and
|
||||||
|
only if v -> v_hat is called a *Lamport graph*.
|
||||||
|
|
||||||
|
Key properties:
|
||||||
|
- Directed and acyclic (Proposition 3.6)
|
||||||
|
- The past of a vertex is invariant across Lamport graphs (Theorem 3.7)
|
||||||
|
- No two equivalent vertices exist in the same graph
|
||||||
|
|
||||||
|
This module implements:
|
||||||
|
- Algorithm 1: Message generation
|
||||||
|
- Algorithm 2: Message integrity checking and graph extension
|
||||||
|
- Causality queries (past, future, timelike, spacelike)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations

import os
from typing import Optional

from crisis.crypto import DIGEST_LENGTH, digest
from crisis.message import ID_LENGTH, Message, NONCE_LENGTH, Vertex, Vote
from crisis.weight import ProofOfWorkWeight, WeightSystem
|
||||||
|
|
||||||
|
|
||||||
|
class LamportGraph:
    """A Lamport graph: a DAG of vertices connected by causal acknowledgement.

    The graph is stored as:
      - vertices: dict mapping message digest -> Vertex
      - edges: dict mapping digest -> set of digests it references
        (i.e. v -> v_hat means v acknowledges v_hat)

    Invariants maintained:
      - No two vertices have the same underlying message (no equivalence)
      - All referenced digests either exist in the graph or are the empty digest
      - The graph is acyclic (guaranteed by hash function properties)
    """

    def __init__(self, weight_system: WeightSystem | None = None):
        # Weight function w(m) used by integrity checks and chain weight.
        # Defaults to proof-of-work with zero difficulty (every nonce valid).
        self.weight_system: WeightSystem = weight_system or ProofOfWorkWeight(min_leading_zeros=0)

        # digest -> Vertex
        self.vertices: dict[bytes, Vertex] = {}

        # digest -> set of digests this vertex references (outgoing causal edges)
        # An edge v -> v_hat means "v acknowledges v_hat" i.e. H(v_hat.m) ∈ v.m.digests
        self.edges: dict[bytes, set[bytes]] = {}

        # Reverse edges for efficient "future" queries
        # digest -> set of digests that reference this vertex
        self.reverse_edges: dict[bytes, set[bytes]] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Graph queries
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """len(G): number of vertices in the graph."""
    return len(self.vertices)

def __contains__(self, digest_or_vertex) -> bool:
    """Membership test, by Vertex object or by raw message digest."""
    if isinstance(digest_or_vertex, Vertex):
        return digest_or_vertex.message_digest in self.vertices
    return digest_or_vertex in self.vertices

def get_vertex(self, msg_digest: bytes) -> Optional[Vertex]:
    """Look a vertex up by its message digest; None when absent."""
    return self.vertices.get(msg_digest)

def all_vertices(self) -> list[Vertex]:
    """Snapshot list of all vertices, in insertion order."""
    return list(self.vertices.values())

def vertex_count(self) -> int:
    """Number of vertices; equivalent to len(self)."""
    return len(self.vertices)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Causality (Definition 3.2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# m -> m_hat (m happens before m_hat) iff:
|
||||||
|
# - m = m_hat, OR
|
||||||
|
# - there is a chain m -> m1 -> ... -> mk -> m_hat
|
||||||
|
# In our DAG: v has an edge to v_hat means v acknowledges v_hat.
|
||||||
|
# So v is in the *future* of v_hat, and v_hat is in the *past* of v.
|
||||||
|
|
||||||
|
def direct_causes(self, v: Vertex) -> list[Vertex]:
    """Return the direct causes of v (the vertices v acknowledges).

    These are the vertices whose digests appear in v.m.digests --
    the outgoing neighbors of v in the DAG.
    """
    neighbor_digests = self.edges.get(v.message_digest, set())
    return [
        self.vertices[d] for d in neighbor_digests if d in self.vertices
    ]
|
||||||
|
|
||||||
|
def direct_effects(self, v: Vertex) -> list[Vertex]:
    """Return the direct effects of v (the vertices that acknowledge v).

    These are the incoming neighbors of v in the DAG.
    """
    referrers = self.reverse_edges.get(v.message_digest, set())
    return [
        self.vertices[d] for d in referrers if d in self.vertices
    ]
|
||||||
|
|
||||||
|
def past(self, v: Vertex) -> set[Vertex]:
    """G_v: the set of all causes of v, including v itself (reflexivity).

    Definition 3.5: "the subgraph G_v of G that contains all causes
    of v is called the *past* of v".
    """
    seen: set[bytes] = set()
    frontier = [v.message_digest]

    # Depth-first walk along forward (acknowledgement) edges.
    while frontier:
        d = frontier.pop()
        if d in seen:
            continue
        seen.add(d)
        frontier.extend(
            nd for nd in self.edges.get(d, set())
            if nd in self.vertices and nd not in seen
        )

    return {self.vertices[d] for d in seen if d in self.vertices}
|
||||||
|
|
||||||
|
def future(self, v: Vertex) -> set[Vertex]:
    """All vertices causally after v, including v itself."""
    seen: set[bytes] = set()
    frontier = [v.message_digest]

    # Depth-first walk along reverse (is-acknowledged-by) edges.
    while frontier:
        d = frontier.pop()
        if d in seen:
            continue
        seen.add(d)
        frontier.extend(
            nd for nd in self.reverse_edges.get(d, set())
            if nd in self.vertices and nd not in seen
        )

    return {self.vertices[d] for d in seen if d in self.vertices}
|
||||||
|
|
||||||
|
def is_cause_of(self, v: Vertex, v_hat: Vertex) -> bool:
    """Check v ≤ v_hat: v lies in the past of v_hat (Definition 3.4).

    Reflexive: every vertex is a cause of itself.
    """
    return v == v_hat or v in self.past(v_hat)
|
||||||
|
|
||||||
|
def are_timelike(self, v: Vertex, v_hat: Vertex) -> bool:
    """True when v and v_hat are causally related in either direction."""
    if self.is_cause_of(v, v_hat):
        return True
    return self.is_cause_of(v_hat, v)
|
||||||
|
|
||||||
|
def are_spacelike(self, v: Vertex, v_hat: Vertex) -> bool:
    """True when v and v_hat have no causal relation (are incomparable).

    Spacelike pairs are exactly the ones the total order protocol must
    decide on; the timelike partial order already fixes everything else.
    """
    timelike = self.are_timelike(v, v_hat)
    return not timelike
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Mutations (Definition 4.2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def find_mutations(self, vertex_id: bytes) -> list[list[Vertex]]:
    """Find mutations: same-id vertices that are pairwise spacelike.

    Definition 4.2: Two vertices v and v_hat in G are called a *mutation*
    of a virtual process if they have the same id and are spacelike,
    i.e. neither v ≤ v_hat nor v_hat ≤ v holds.

    Mutations are the virtual voting equivalent of equivocation -- a
    byzantine actor sending different votes to different processes.

    Args:
        vertex_id: the virtual process id to inspect.

    Returns:
        A list containing at most one group: the vertices with this id
        that participate in at least one spacelike pair, in first-seen
        order. Empty list when no mutation exists.
    """
    # Improvement: the original grouped EVERY vertex in the graph by id
    # and then discarded all groups except the requested one -- O(|V|)
    # wasted work and memory per call. Filter to the requested id directly.
    group = [v for v in self.vertices.values() if v.id == vertex_id]

    # Collect every vertex of the group involved in a spacelike pair.
    spacelike_group: list[Vertex] = []
    for i, v1 in enumerate(group):
        for v2 in group[i + 1:]:
            if self.are_spacelike(v1, v2):
                if v1 not in spacelike_group:
                    spacelike_group.append(v1)
                if v2 not in spacelike_group:
                    spacelike_group.append(v2)

    return [spacelike_group] if spacelike_group else []
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Byte-level correctness (part of Algorithm 2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _bytelevel_correctness(self, message: Message) -> bool:
    """BYTELEVEL_CORRECTNESS: basic structural validation of a message.

    Verifies that every fixed-size field has exactly the length the
    wire format prescribes (Section 3.1): nonce, id, and each
    referenced digest.

    Returns:
        True when all field lengths are valid.
    """
    if len(message.nonce) != NONCE_LENGTH:
        return False
    if len(message.id) != ID_LENGTH:
        return False
    # Consistency fix: use the shared DIGEST_LENGTH constant from
    # crisis.crypto instead of a hard-coded 32 that could silently
    # drift from the real digest size.
    return all(len(d) == DIGEST_LENGTH for d in message.digests)
|
||||||
|
|
||||||
|
def _payload_correctness(self, message: Message) -> bool:
|
||||||
|
"""PAYLOAD_CORRECTNESS: validate the payload against system rules.
|
||||||
|
|
||||||
|
In this PoC, any payload is accepted. A real system would enforce
|
||||||
|
application-specific validation here.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Algorithm 2: Message integrity (Section 4.2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def message_integrity(self, message: Message) -> bool:
    """Check whether a message can be validly added to this Lamport graph.

    Algorithm 2 from the paper:

    1. Check BYTELEVEL_CORRECTNESS(m)
    2. Check w(m) > c_min (weight threshold)
    3. Check PAYLOAD_CORRECTNESS(m.payload)
    4. Check no equivalent vertex exists (no vertex with same digest)
    5. For each digest in m.digests:
       - It must reference a vertex in G
       - All referenced vertices must have different id's
    6. If there is a vertex v in G with v.id = m.id:
       - One of m.digests must reference v (or a vertex in v's past)
       Ensures the virtual process forms a chain, not a tree.

    Returns True if the message passes integrity checks.
    """
    # Step 1: byte-level structure
    if not self._bytelevel_correctness(message):
        return False

    # Step 2: weight threshold
    if not self.weight_system.is_valid_weight(message):
        return False

    # Step 3: payload rules
    if not self._payload_correctness(message):
        return False

    msg_digest = message.compute_digest()

    # Step 4: no duplicate (no equivalent vertex)
    if msg_digest in self.vertices:
        return False

    # Step 5: all referenced digests must exist in G
    # and all referenced vertices must have different ids
    referenced_ids: set[bytes] = set()
    for ref_digest in message.digests:
        if ref_digest not in self.vertices:
            return False
        ref_vertex = self.vertices[ref_digest]
        if ref_vertex.id in referenced_ids:
            return False  # Two references to same id
        referenced_ids.add(ref_vertex.id)

    # Step 6: if same id exists, must reference it (chain constraint)
    # Find the "last vertex" with this id (not referenced by any other
    # vertex with the same id)
    # NOTE(review): the code below accepts a direct reference to ANY
    # same-id vertex, not specifically the last one and not one in its
    # past, although the docstring wording is broader -- confirm against
    # Algorithm 2 in the paper.
    same_id_vertices = [v for v in self.vertices.values() if v.id == message.id]
    if same_id_vertices:
        # Check that at least one digest references a same-id vertex
        referenced_digests = set(message.digests)
        found_chain_link = False
        for v in same_id_vertices:
            if v.message_digest in referenced_digests:
                found_chain_link = True
                break
        if not found_chain_link:
            return False

    return True
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Lamport graph extension (Section 4.2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def extend(self, message: Message) -> Optional[Vertex]:
    """Attempt to extend the Lamport graph with a new message.

    If the message passes the integrity checks (Algorithm 2), create a
    new vertex and insert it together with its causal edges.
    Proposition 4.1 guarantees that the extension of a Lamport graph by
    a valid message is itself a Lamport graph.

    Args:
        message: the candidate message to add.

    Returns:
        The new Vertex on success, None when the integrity check fails.
    """
    if not self.message_integrity(message):
        return None

    vertex = Vertex(m=message)
    msg_digest = message.compute_digest()

    # Add the vertex itself.
    self.vertices[msg_digest] = vertex

    # Forward edges: this vertex -> each vertex it acknowledges.
    self.edges[msg_digest] = set(message.digests)

    # Reverse edges, kept for efficient "future" queries. Idiom fix:
    # dict.setdefault replaces the original's manual "if key missing,
    # insert empty set" sequences.
    for ref_digest in message.digests:
        self.reverse_edges.setdefault(ref_digest, set()).add(msg_digest)
    self.reverse_edges.setdefault(msg_digest, set())

    return vertex
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Algorithm 1: Message generation (Section 4.1)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_message(self, process_id: bytes, payload: bytes,
                     weight_system: WeightSystem | None = None) -> Message:
    """Generate a valid message for a given virtual process id.

    Algorithm 1 from the paper:
    1. Find the last vertex v with v.id = id in G
    2. Choose S ⊂ {v.m | v ∈ G ∧ v ∉ G_v} such that all have different ids
    3. Return message with digests = {H(v.m)} ∪ {H(m) | m ∈ S ∪ {v.m}}

    The nonce is chosen so that w(m) > c_min (via mining if PoW).

    Args:
        process_id: virtual process id the new message belongs to.
        payload: application data to be carried by the message.
        weight_system: overrides the graph's own weight system if given.
    """
    ws = weight_system or self.weight_system

    # Find the last vertex with this process id
    last_vertex = self._find_last_vertex(process_id)

    # Collect digests: last same-id vertex + a sample of other vertices
    digests_list: list[bytes] = []

    if last_vertex is not None:
        # Must reference the last vertex with same id (chain constraint
        # enforced by message_integrity step 6).
        digests_list.append(last_vertex.message_digest)

        # Add cross-references to vertices NOT in last_vertex's past
        past_digests = {v.message_digest for v in self.past(last_vertex)}
        candidates = [
            v for d, v in self.vertices.items()
            if d not in past_digests
            and v.id != process_id
            and d != last_vertex.message_digest
        ]

        # Include candidates with different ids -- at most one reference
        # per foreign process id, in graph insertion order.
        seen_ids: set[bytes] = {process_id}
        for candidate in candidates:
            if candidate.id not in seen_ids:
                digests_list.append(candidate.message_digest)
                seen_ids.add(candidate.id)
    else:
        # First message for this id: reference a sample of existing
        # vertices, again at most one per foreign id.
        seen_ids = {process_id}
        for v in self.vertices.values():
            if v.id not in seen_ids:
                digests_list.append(v.message_digest)
                seen_ids.add(v.id)

    digests_tuple = tuple(digests_list)

    # Mine a valid nonce (or just find one that meets threshold)
    # NOTE(review): mine_nonce presumably constructs the Message with a
    # nonce satisfying the PoW threshold -- confirm its signature and
    # that it uses the same field order as the Message constructor.
    if isinstance(ws, ProofOfWorkWeight):
        message = ws.mine_nonce(process_id, digests_tuple, payload)
    else:
        # For non-PoW systems, use a random nonce
        nonce = os.urandom(NONCE_LENGTH)
        message = Message(nonce=nonce, id=process_id, digests=digests_tuple, payload=payload)

    return message
|
||||||
|
|
||||||
|
def _find_last_vertex(self, process_id: bytes) -> Optional[Vertex]:
|
||||||
|
"""Find the last vertex with a given process id.
|
||||||
|
|
||||||
|
A vertex is "last" for an id if no other vertex with the same id
|
||||||
|
references it (i.e. it has no same-id successor).
|
||||||
|
"""
|
||||||
|
same_id = [v for v in self.vertices.values() if v.id == process_id]
|
||||||
|
if not same_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the one that is not referenced by any other same-id vertex
|
||||||
|
referenced_by_same_id: set[bytes] = set()
|
||||||
|
for v in same_id:
|
||||||
|
for d in v.digests:
|
||||||
|
ref = self.vertices.get(d)
|
||||||
|
if ref is not None and ref.id == process_id:
|
||||||
|
referenced_by_same_id.add(d)
|
||||||
|
|
||||||
|
for v in same_id:
|
||||||
|
if v.message_digest not in referenced_by_same_id:
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Fallback: return the one added most recently (by convention)
|
||||||
|
return same_id[-1]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Vertices by id (for virtual process queries)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def vertices_by_id(self, process_id: bytes) -> list[Vertex]:
    """Every vertex belonging to the given virtual process, in insertion order."""
    return [
        vertex
        for vertex in self.vertices.values()
        if vertex.id == process_id
    ]
|
||||||
|
|
||||||
|
def all_process_ids(self) -> set[bytes]:
    """The set of distinct virtual process ids present in the graph."""
    ids: set[bytes] = set()
    for vertex in self.vertices.values():
        ids.add(vertex.id)
    return ids
|
||||||
|
|
||||||
|
def last_vertices_by_id(self) -> dict[bytes, Vertex]:
    """Map each virtual process id to its last (chain-tip) vertex."""
    return {
        pid: tip
        for pid in self.all_process_ids()
        if (tip := self._find_last_vertex(pid)) is not None
    }
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Weight queries
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def vertex_weight(self, v: Vertex) -> int:
    """w(v) = w(v.m): a vertex weighs exactly what its message weighs."""
    ws = self.weight_system
    return ws.weight(v.m)
|
||||||
|
|
||||||
|
def set_weight(self, vertices: set[Vertex] | list[Vertex]) -> int:
    """w(M) := ⊕_{m ∈ M} w(m): combined weight of a collection of vertices."""
    combine = self.weight_system.weight_sum
    total = 0
    for vertex in vertices:
        total = combine(total, self.vertex_weight(vertex))
    return total
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Display
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"LamportGraph(vertices={len(self.vertices)}, ids={len(self.all_process_ids())})"
|
||||||
250
src/crisis/message.py
Normal file
250
src/crisis/message.py
Normal file
|
|
@ -0,0 +1,250 @@
|
||||||
|
"""
|
||||||
|
Data Structures (Section 3)
|
||||||
|
|
||||||
|
3.1 Messages
|
||||||
|
-------------
|
||||||
|
Messages distribute payload across the network. The purpose of the protocol
|
||||||
|
is to establish a total order on those messages that respects causality.
|
||||||
|
|
||||||
|
A message is a byte string of variable length with the following structure
|
||||||
|
(paper, page 3):
|
||||||
|
|
||||||
|
struct Message {
|
||||||
|
byte[c1] nonce,
|
||||||
|
byte[c2] id,
|
||||||
|
byte[c3] num_digests,
|
||||||
|
byte[p * num_digests] digests,
|
||||||
|
byte[] payload
|
||||||
|
}
|
||||||
|
|
||||||
|
Where c1, c2, c3 are fixed protocol constants and p is the digest length.
|
||||||
|
|
||||||
|
The *nonce* is used by the weight function (e.g. PoW grinding).
|
||||||
|
The *id* groups messages into virtual processes.
|
||||||
|
The *digests* field encodes causal acknowledgement of other messages.
|
||||||
|
|
||||||
|
Key insight: a message that acknowledges other messages defines an inherent
|
||||||
|
natural causality -- this is the Lamport "happens-before" relation (1978).
|
||||||
|
|
||||||
|
m -> m_hat iff H(m_hat) is contained in m.digests (Eq. 2)
|
||||||
|
|
||||||
|
3.1.3 Vertices
|
||||||
|
---------------
|
||||||
|
To establish total order, messages are extended by local voting data that is
|
||||||
|
NOT transmitted. Votes are deduced from the causal relation between messages.
|
||||||
|
This is the key characteristic of virtual voting (Moser & Melliar-Smith).
|
||||||
|
|
||||||
|
struct Vertex {
|
||||||
|
Message m,
|
||||||
|
Option<uint> round,
|
||||||
|
Option<boolean> is_last,
|
||||||
|
Option<TotalOrderSet<uint>> svp, # safe voting pattern
|
||||||
|
Option<(Message, Option<bool>)> vote,
|
||||||
|
Option<uint> total_position
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from crisis.crypto import digest, DIGEST_LENGTH
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Protocol constants (c1, c2, c3 from the paper)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# These define the byte-lengths of the fixed-size fields in a message.
|
||||||
|
# Chosen for a practical PoC: generous enough for real use, compact enough
|
||||||
|
# for clarity.
|
||||||
|
|
||||||
|
NONCE_LENGTH = 8 # c1: 8 bytes of nonce (plenty for PoW search space)
|
||||||
|
ID_LENGTH = 32 # c2: 32 bytes for virtual process id (a hash)
|
||||||
|
NUM_DIGESTS_LENGTH = 2 # c3: 2 bytes => up to 65535 referenced digests
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Message
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class Message:
    """An immutable Crisis message as defined in Section 3.1.

    A message is the atomic unit of communication in the Crisis protocol.
    It carries a payload and encodes causal history through its digests field.

    Attributes:
        nonce: Used by the weight function (e.g. PoW nonce grinding).
        id: Groups this message into a virtual process.
            (Intentionally shadows the builtin ``id`` to match the paper.)
        digests: Tuple of digests of causally prior messages (H values).
        payload: The actual application data being ordered.

    Raises:
        ValueError: from __post_init__ if any fixed-size field is malformed
            or the digest count cannot be encoded in NUM_DIGESTS_LENGTH bytes.
    """
    nonce: bytes
    id: bytes
    digests: tuple[bytes, ...] = ()
    payload: bytes = b""

    def __post_init__(self):
        """Validate all fields eagerly so no unserializable Message exists.

        The digest-count bound is checked here; previously an oversized
        digests tuple only failed later inside serialize() with an
        OverflowError from int.to_bytes.
        """
        if len(self.nonce) != NONCE_LENGTH:
            raise ValueError(f"nonce must be {NONCE_LENGTH} bytes, got {len(self.nonce)}")
        if len(self.id) != ID_LENGTH:
            raise ValueError(f"id must be {ID_LENGTH} bytes, got {len(self.id)}")
        # Maximum count representable in the NUM_DIGESTS_LENGTH-byte field.
        max_digests = (1 << (8 * NUM_DIGESTS_LENGTH)) - 1
        if len(self.digests) > max_digests:
            raise ValueError(
                f"too many digests: {len(self.digests)} > {max_digests}"
            )
        for i, d in enumerate(self.digests):
            if len(d) != DIGEST_LENGTH:
                raise ValueError(f"digest[{i}] must be {DIGEST_LENGTH} bytes")

    def serialize(self) -> bytes:
        """Serialize this message to a canonical byte string.

        The serialized form is what gets hashed to produce the message's digest.
        Format: nonce | id | num_digests (2 bytes big-endian) | digests... | payload
        """
        return b"".join((
            self.nonce,
            self.id,
            len(self.digests).to_bytes(NUM_DIGESTS_LENGTH, "big"),
            *self.digests,
            self.payload,
        ))

    def compute_digest(self) -> bytes:
        """Compute H(m) -- the digest of this message.

        This is the value other messages include in their digests field
        to acknowledge this message (establishing causality, Eq. 2).
        """
        return digest(self.serialize())

    @property
    def num_digests(self) -> int:
        """Number of causally prior messages this message acknowledges."""
        return len(self.digests)

    def __repr__(self) -> str:
        h = self.compute_digest().hex()[:12]
        return f"Message(id={self.id.hex()[:8]}..., digests={self.num_digests}, hash={h}...)"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# The empty message (paper: ∅ ∈ MESSAGE)
# ---------------------------------------------------------------------------
# "We postulate a special non-message ∅ ∈ MESSAGE" (Section 3.1)
# Acknowledgement of ∅ is defined as H(empty string).

# Digest of the empty byte string -- the canonical acknowledgement of ∅.
EMPTY_MESSAGE_DIGEST = digest(b"")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Vote
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class Vote:
    """A locally computed virtual vote (Algorithm 7).

    v.vote(r) = (l, b) records v's vote on some message l together with a
    possibly undecided binary value b ∈ {⊥, 0, 1} for round r.

    Attributes:
        message: The message l being voted on; None encodes ∅ (no leader).
        binary: The binary component: None encodes ⊥ (undecided), else 0 or 1.
    """
    message: Optional[Message] = None
    binary: Optional[int] = None  # None = ⊥, 0, or 1

    def __repr__(self) -> str:
        if self.message is None:
            target = "∅"
        else:
            target = self.message.compute_digest().hex()[:8]
        decision = str(self.binary) if self.binary is not None else "⊥"
        return f"Vote({target}, {decision})"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Vertex
# ---------------------------------------------------------------------------

@dataclass
class Vertex:
    """A vertex in a Lamport graph (Section 3.1.3).

    Wraps a message and carries locally-computed consensus state. The extra
    fields (round, is_last, svp, vote, total_position) are never sent over
    the wire; every node deduces them from the causal structure alone.

    From the paper (page 5, Eq. 6) a vertex forwards all message fields:
        w(v) <- w(v.m), v.nonce <- v.m.nonce, v.id <- v.m.id,
        v.num_digests <- v.m.num_digests, v.digests <- v.m.digests,
        v.payload <- v.m.payload

    Attributes:
        m: The underlying message.
        round: The virtual round number (Algorithm 5).
        is_last: Whether this is a "last vertex" of its round (Alg 5).
        svp: Safe voting pattern -- ordered set of round numbers.
        vote: Per-round votes: round -> Vote.
        total_position: Final position in the total order (Algorithm 9/10).
    """
    m: Message

    # Consensus state, all initially undecided (None / ⊥ / empty).
    round: Optional[int] = None
    is_last: Optional[bool] = None
    svp: list[int] = field(default_factory=list)
    vote: dict[int, Vote] = field(default_factory=dict)
    total_position: Optional[int] = None

    # ------------------------------------------------------------------
    # Read-only views onto the wrapped message (Eq. 6)
    # ------------------------------------------------------------------

    @property
    def nonce(self) -> bytes:
        return self.m.nonce

    @property
    def id(self) -> bytes:
        return self.m.id

    @property
    def digests(self) -> tuple[bytes, ...]:
        return self.m.digests

    @property
    def payload(self) -> bytes:
        return self.m.payload

    @property
    def message_digest(self) -> bytes:
        """H(v.m) -- the digest that uniquely identifies this vertex's message."""
        return self.m.compute_digest()

    # ------------------------------------------------------------------
    # Equivalence (Definition 3.3): v ~ v_hat  iff  v.m = v_hat.m
    # ------------------------------------------------------------------

    def equivalent_to(self, other: Vertex) -> bool:
        """True when both vertices wrap the same underlying message."""
        return other.message_digest == self.message_digest

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Vertex):
            return other.message_digest == self.message_digest
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.message_digest)

    def __repr__(self) -> str:
        shown_round = "?" if self.round is None else str(self.round)
        suffix = "*" if self.is_last else ""
        return f"Vertex({self.message_digest.hex()[:12]}..., r={shown_round}{suffix})"
|
||||||
336
src/crisis/node.py
Normal file
336
src/crisis/node.py
Normal file
|
|
@ -0,0 +1,336 @@
|
||||||
|
"""
|
||||||
|
Crisis Node (Section 5.9 -- The Crisis Protocol)
|
||||||
|
|
||||||
|
This module ties all components together into a full Crisis node.
|
||||||
|
|
||||||
|
From the paper (Section 5.9):
|
||||||
|
"The overall algorithm works as follows: Member discovery (3) and
|
||||||
|
message gossip (4) are executed in infinite loops, concurrently to
|
||||||
|
the rest of the system. Ideally the message sending loop is executed
|
||||||
|
on as many parallel threads as possible. This implies that an overall
|
||||||
|
unbounded amount of new messages arrive over time due to our liveness
|
||||||
|
assumption. In addition each process may generate messages and write
|
||||||
|
them into its own Lamport graph."
|
||||||
|
|
||||||
|
The full node runs these concurrent loops:
|
||||||
|
1. Gossip: member discovery + message dissemination
|
||||||
|
2. Message generation: create new messages with PoW
|
||||||
|
3. Consensus: compute rounds, voting patterns, leader elections, order
|
||||||
|
|
||||||
|
Each loop runs independently and they communicate through the shared
|
||||||
|
Lamport graph.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from crisis.crypto import digest
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.gossip import GossipServer, NetworkView, PeerInfo
|
||||||
|
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||||
|
from crisis.order import LeaderStream, compute_order
|
||||||
|
from crisis.rounds import compute_rounds
|
||||||
|
from crisis.voting import compute_virtual_leader_election, compute_safe_voting_pattern
|
||||||
|
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CrisisNode:
    """A full Crisis protocol node.

    Combines all protocol components into a single running process:
      - Lamport graph (the shared DAG)
      - Weight system (PoW)
      - Difficulty oracle
      - Gossip server (member discovery + message dissemination)
      - Consensus engine (rounds, voting, ordering)

    Attributes:
        process_id: This node's virtual process identity.
        graph: The local Lamport graph.
        leader_stream: The evolving total order leader stream.
        network_view: Known peers in the network.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 9000,
                 min_pow_zeros: int = 1,
                 difficulty_constant: int = 4,
                 connectivity_k: int = 2,
                 message_interval: float = 3.0,
                 consensus_interval: float = 5.0,
                 seed_peers: list[tuple[str, int]] | None = None):
        # Identity: use a hash of host:port as this node's virtual process id
        self.process_id = digest(f"{host}:{port}".encode())[:ID_LENGTH]
        self.host = host
        self.port = port

        # Protocol components
        self.weight_system = ProofOfWorkWeight(min_leading_zeros=min_pow_zeros)
        self.difficulty = DifficultyOracle(constant_difficulty=difficulty_constant)
        self.connectivity_k = connectivity_k
        self.graph = LamportGraph(weight_system=self.weight_system)
        self.leader_stream = LeaderStream()

        # Timing
        self.message_interval = message_interval
        self.consensus_interval = consensus_interval

        # Network
        self.network_view = NetworkView()
        if seed_peers:
            for h, p in seed_peers:
                self.network_view.add_peer(PeerInfo(host=h, port=p))

        self.gossip = GossipServer(
            host=host, port=port,
            graph=self.graph,
            network_view=self.network_view,
        )

        # State
        self._running = False
        # Count of locally generated payloads. It is advanced ONLY inside
        # _generate_payload() so it matches the sequence number embedded in
        # each payload. (The original code incremented it a second time in
        # the generation loop, double-counting every message.)
        self._message_count = 0

        # Optional monitoring callbacks; each may stay None.
        self.on_new_vertex: Optional[callable] = None
        self.on_round_update: Optional[callable] = None
        self.on_order_update: Optional[callable] = None

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    async def run(self) -> None:
        """Start all protocol loops concurrently.

        This is the Crisis protocol (Section 5.9): three concurrent loops
        (gossip, message generation, consensus) sharing one Lamport graph.
        """
        self._running = True
        logger.info(
            "Crisis node starting on %s:%s (id=%s...)",
            self.host, self.port, self.process_id.hex()[:16],
        )

        try:
            await asyncio.gather(
                self._gossip_loop(),
                self._message_generation_loop(),
                self._consensus_loop(),
            )
        except asyncio.CancelledError:
            logger.info("Crisis node shutting down")
        finally:
            self._running = False

    async def stop(self) -> None:
        """Signal the loops to exit and shut down the gossip server."""
        self._running = False
        await self.gossip.stop()

    # ------------------------------------------------------------------
    # Loop 1: Gossip (Algorithms 3 & 4)
    # ------------------------------------------------------------------

    async def _gossip_loop(self) -> None:
        """Run the gossip server (member discovery + message dissemination)."""
        try:
            await self.gossip.start()
        except Exception as e:
            # Keep the node alive even if networking fails to come up.
            logger.error("Gossip loop error: %s", e)

    # ------------------------------------------------------------------
    # Loop 2: Message generation (Algorithm 1)
    # ------------------------------------------------------------------

    async def _message_generation_loop(self) -> None:
        """Periodically generate new messages and add them to the graph.

        Each message:
        1. References the last same-id message (chain constraint)
        2. References a sample of other vertices (cross-links for connectivity)
        3. Has a PoW nonce meeting the weight threshold
        4. Carries an application payload
        """
        while self._running:
            await asyncio.sleep(self.message_interval)

            try:
                payload = self._generate_payload()
                message = self.graph.generate_message(
                    self.process_id, payload, self.weight_system
                )
                vertex = self.graph.extend(message)

                if vertex is not None:
                    # _generate_payload() already advanced the counter, so
                    # the logged number matches the payload's own sequence
                    # number; do NOT increment again here.
                    logger.debug(
                        "Generated message #%d: %s", self._message_count, vertex
                    )
                    if self.on_new_vertex:
                        self.on_new_vertex(vertex)

            except Exception as e:
                logger.error("Message generation error: %s", e)

    def _generate_payload(self) -> bytes:
        """Generate a payload for a new message.

        In this PoC, payloads are simple timestamped entries carrying a
        per-node sequence number. A real application would put actual data
        here. Side effect: advances ``_message_count`` by one.
        """
        self._message_count += 1
        return f"msg-{self._message_count}-{time.time():.3f}".encode()

    # ------------------------------------------------------------------
    # Loop 3: Consensus (Algorithms 5, 6, 7, 9, 10)
    # ------------------------------------------------------------------

    async def _consensus_loop(self) -> None:
        """Periodically recompute consensus state.

        From Section 5.9 and the proof section (Section 6):
        "algorithms (5), (6) and (7) are executed in that order concurrently
        on each vertex from V... the total order loop (9) runs concurrently
        and waits for updates of the leader stream."
        """
        while self._running:
            await asyncio.sleep(self.consensus_interval)

            if self.graph.vertex_count() == 0:
                continue

            try:
                # Step 1: Compute rounds (Algorithm 5)
                compute_rounds(self.graph, self.difficulty, self.connectivity_k)

                if self.on_round_update:
                    self.on_round_update(self.graph)

                # Step 2: Compute safe voting patterns (Algorithm 6)
                # -- only "last" vertices of a round carry a pattern.
                for vertex in self.graph.all_vertices():
                    if vertex.is_last:
                        compute_safe_voting_pattern(
                            vertex, self.graph, self.difficulty,
                            self.connectivity_k
                        )

                # Step 3: Virtual leader election (Algorithm 7)
                leader_dict: dict[int, list[tuple[int, Message]]] = {}
                for vertex in self.graph.all_vertices():
                    if vertex.svp:
                        compute_virtual_leader_election(
                            vertex, self.graph, self.difficulty,
                            self.connectivity_k, leader_dict
                        )

                # Feed election results into the leader stream
                # (longest chain rule is applied inside update()).
                for round_num, entries in leader_dict.items():
                    for deciding_round, leader_msg in entries:
                        self.leader_stream.update(
                            round_num, deciding_round, leader_msg
                        )

                # Step 4: Compute total order (Algorithms 9 & 10)
                ordered = compute_order(self.graph, self.leader_stream)

                if ordered and self.on_order_update:
                    self.on_order_update(ordered)

                logger.debug(
                    "Consensus: %d vertices, %d leaders, %d ordered",
                    self.graph.vertex_count(),
                    len(self.leader_stream.leaders),
                    len(ordered),
                )

            except Exception as e:
                logger.error("Consensus loop error: %s", e)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def submit_message(self, payload: bytes) -> Optional[Vertex]:
        """Submit an application message to be ordered by the protocol."""
        message = self.graph.generate_message(
            self.process_id, payload, self.weight_system
        )
        return self.graph.extend(message)

    def get_total_order(self) -> list[tuple[int, bytes]]:
        """Get the current total order as (position, payload) pairs."""
        ordered = compute_order(self.graph, self.leader_stream)
        return [
            (v.total_position, v.payload)
            for v in ordered
            if v.total_position is not None
        ]

    def status(self) -> dict:
        """Return a summary of this node's current state."""
        # Local import avoids a module-level cycle with crisis.rounds.
        from crisis.rounds import max_round as get_max_round
        return {
            "process_id": self.process_id.hex()[:16],
            "address": f"{self.host}:{self.port}",
            "vertices": self.graph.vertex_count(),
            "process_ids": len(self.graph.all_process_ids()),
            "max_round": get_max_round(self.graph),
            "leaders": len(self.leader_stream.leaders),
            "peers": len(self.network_view.peers),
            "messages_generated": self._message_count,
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------

def main():
    """Run a Crisis node from the command line.

    Parses CLI flags, configures logging, constructs a CrisisNode, and runs
    its event loop until interrupted. Exits via argparse with a readable
    message on a malformed --peers entry (previously a raw traceback).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Crisis Protocol Node",
        epilog="Probabilistically self-organizing total order in P2P networks"
    )
    parser.add_argument("--host", default="127.0.0.1", help="Listen address")
    parser.add_argument("--port", type=int, default=9000, help="Listen port")
    parser.add_argument("--pow-zeros", type=int, default=1,
                        help="Min PoW leading zeros (weight threshold)")
    parser.add_argument("--difficulty", type=int, default=4,
                        help="Difficulty oracle constant")
    parser.add_argument("--msg-interval", type=float, default=3.0,
                        help="Seconds between message generation")
    # Expose the remaining CrisisNode knobs; defaults mirror the constructor.
    parser.add_argument("--consensus-interval", type=float, default=5.0,
                        help="Seconds between consensus recomputations")
    parser.add_argument("--connectivity-k", type=int, default=2,
                        help="k parameter for k-reachability")
    parser.add_argument("--peers", nargs="*", default=[],
                        help="Seed peers as host:port")

    args = parser.parse_args()

    seed_peers = []
    for peer_str in args.peers:
        try:
            h, p = peer_str.rsplit(":", 1)
            seed_peers.append((h, int(p)))
        except ValueError:
            parser.error(f"invalid peer address {peer_str!r} (expected host:port)")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s"
    )

    node = CrisisNode(
        host=args.host,
        port=args.port,
        min_pow_zeros=args.pow_zeros,
        difficulty_constant=args.difficulty,
        connectivity_k=args.connectivity_k,
        seed_peers=seed_peers,
        message_interval=args.msg_interval,
        consensus_interval=args.consensus_interval,
    )

    asyncio.run(node.run())


if __name__ == "__main__":
    main()
|
||||||
254
src/crisis/order.py
Normal file
254
src/crisis/order.py
Normal file
|
|
@ -0,0 +1,254 @@
|
||||||
|
"""
|
||||||
|
Total Order (Section 5.8)
|
||||||
|
|
||||||
|
As time goes by and the Lamport graph grows, more and more round leaders
|
||||||
|
are computed and incorporated into the global leader stream LEADER_G(·).
|
||||||
|
|
||||||
|
Algorithm 9 (Order loop): watches for leader stream updates and recomputes
|
||||||
|
total order. Total order is achieved by topological sorting on the past
|
||||||
|
of appropriate vertices.
|
||||||
|
|
||||||
|
Algorithm 10 (Total order using Kahn's algorithm): generates total order
|
||||||
|
in linear runtime by processing vertices without outgoing causal edges first,
|
||||||
|
using voting weight to break ties among spacelike vertices.
|
||||||
|
|
||||||
|
The total order converges probabilistically: any two non-byzantine processes
|
||||||
|
will eventually compute the same total order (Proposition 6.21).
|
||||||
|
|
||||||
|
Definition 5.17 (Leader Stream):
|
||||||
|
LEADER_G : N -> Option<(uint, MESSAGE)>
|
||||||
|
is called the *global leader stream* of the Lamport graph.
|
||||||
|
|
||||||
|
Corollary 6.19 (Leader stream convergence):
|
||||||
|
If the probability for new rounds and safe voting pattern is not zero,
|
||||||
|
the leader streams of any two honest processes will converge.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, Vertex
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Leader Stream (Definition 5.17)
# ---------------------------------------------------------------------------

class LeaderStream:
    """The global leader stream of a Lamport graph (Definition 5.17).

    Maps round numbers to (deciding_round, leader_message) pairs and applies
    the Nakamoto longest chain rule: a leader decided in a later round may
    displace leaders decided earlier. The stream converges to one element
    per round (Theorem 6.18), and honest processes' streams converge to the
    same values (Corollary 6.19).
    """

    def __init__(self):
        # round_number -> (deciding_round, leader_message)
        self.leaders: dict[int, tuple[int, Message]] = {}

    def update(self, round_number: int, deciding_round: int,
               leader_message: Message) -> bool:
        """Apply the Nakamoto longest chain rule (Algorithm 8, LONG_CHAIN).

        Keeps only the leader decided in the highest round; earlier rounds
        whose deciding round falls behind are pruned.

        Returns True if the leader stream was modified.
        """
        previous = self.leaders.get(round_number)
        if previous is not None and previous[0] >= deciding_round:
            # A leader from an equal-or-higher deciding round already wins.
            return False

        self.leaders[round_number] = (deciding_round, leader_message)

        # Longest chain rule: drop earlier rounds whose deciding round
        # lags behind the best one currently known.
        longest = max(dr for dr, _ in self.leaders.values())
        stale = [
            r for r, (dr, _) in self.leaders.items()
            if dr < longest and r < round_number
        ]
        for r in stale:
            del self.leaders[r]

        return True

    def get_leader(self, round_number: int) -> Optional[Message]:
        """Current leader message for a round, or None if undecided."""
        record = self.leaders.get(round_number)
        if record is None:
            return None
        return record[1]

    def max_round(self) -> int:
        """Highest round with a decided leader (-1 when empty)."""
        if not self.leaders:
            return -1
        return max(self.leaders)

    def all_leaders(self) -> list[tuple[int, Message]]:
        """All (round, leader_message) pairs, ordered by round number."""
        pairs = []
        for r in sorted(self.leaders):
            pairs.append((r, self.leaders[r][1]))
        return pairs

    def __repr__(self) -> str:
        rounds = sorted(self.leaders.keys())
        return f"LeaderStream(rounds={rounds})"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Algorithm 9: Order Loop
# ---------------------------------------------------------------------------

def compute_order(graph: LamportGraph, leader_stream: LeaderStream) -> list[Vertex]:
    """Algorithm 9: compute total order from the leader stream.

    For each decided leader (in round order) the leader vertex's past is
    ordered with Kahn's algorithm (Algorithm 10); vertices already placed
    by an earlier leader are skipped. Side effect: assigns
    ``total_position`` on every ordered vertex.

    For this PoC, the order is computed in a single pass over the current
    leader stream state (the paper's version is an event-driven loop that
    waits for LEADER_G changes).

    Returns:
        The ordered vertices; empty when no leader has been decided yet.
    """
    if not leader_stream.leaders:
        return []

    ordered: list[Vertex] = []
    # Maintained incrementally; the original rebuilt this set from
    # ``ordered`` for every leader, an accidental O(n^2).
    already_ordered: set[bytes] = set()
    position = 0

    # Process leaders in round order.
    for _round_number, leader_message in leader_stream.all_leaders():
        # Locate the vertex carrying this leader message; a leader whose
        # vertex has not arrived locally yet is simply skipped.
        leader_vertex = graph.get_vertex(leader_message.compute_digest())
        if leader_vertex is None:
            continue

        # Order only the not-yet-ordered part of this leader's past.
        new_vertices = [
            v for v in graph.past(leader_vertex)
            if v.message_digest not in already_ordered
        ]

        for v in _kahns_total_order(new_vertices, graph):
            v.total_position = position
            already_ordered.add(v.message_digest)
            ordered.append(v)
            position += 1

    return ordered


# ---------------------------------------------------------------------------
# Algorithm 10: Total Order using Kahn's Algorithm
# ---------------------------------------------------------------------------

def _kahns_total_order(vertices: list[Vertex], graph: LamportGraph) -> list[Vertex]:
    """Algorithm 10: generate total order using Kahn's algorithm.

    Kahn's algorithm in its "arrow reversed" incarnation: causes (sinks in
    the causal edge direction) are emitted before their effects. Among
    simultaneously available (spacelike) vertices, the one with the highest
    weight is emitted first, which makes the order deterministic across
    honest processes holding equivalent graphs.

    Fixes over the original:
      - membership is tracked in a digest set instead of scanning ``result``
        (O(1) instead of O(n) per edge);
      - newly available vertices are taken from the input list itself via a
        local digest->vertex map, instead of round-tripping through
        ``graph.get_vertex`` (which could return a different object or a
        vertex outside the requested subgraph).
    """
    if not vertices:
        return []

    # Index the induced subgraph by message digest.
    by_digest = {v.message_digest: v for v in vertices}

    # Edges restricted to this vertex set: out = causal references,
    # in = reverse index used when releasing referrers.
    out_edges: dict[bytes, set[bytes]] = {d: set() for d in by_digest}
    in_edges: dict[bytes, set[bytes]] = {d: set() for d in by_digest}
    for v in vertices:
        d = v.message_digest
        for cause_d in graph.edges.get(d, set()):
            if cause_d in by_digest:
                out_edges[d].add(cause_d)
                in_edges[cause_d].add(d)

    result: list[Vertex] = []
    placed: set[bytes] = set()
    # Start with vertices that have no outgoing edges (sinks = earliest causes).
    available = [v for v in vertices if not out_edges[v.message_digest]]

    while available:
        # Emit the available vertex with the highest weight; the stable sort
        # breaks exact-weight ties by insertion order, deterministically.
        available.sort(key=lambda v: graph.vertex_weight(v), reverse=True)
        chosen = available.pop(0)
        result.append(chosen)
        chosen_d = chosen.message_digest
        placed.add(chosen_d)

        # Remove edges pointing at the chosen vertex; referrers whose last
        # outgoing edge disappears become available.
        for referrer_d in list(in_edges[chosen_d]):
            out_edges[referrer_d].discard(chosen_d)
            if not out_edges[referrer_d] and referrer_d not in placed:
                available.append(by_digest[referrer_d])

    return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Convenience: full pipeline
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def total_order_positions(graph: LamportGraph,
                          leader_stream: LeaderStream) -> dict[bytes, int]:
    """Map each ordered message digest to its position in the total order.

    This is the protocol's final output: a causality-respecting total
    order on messages that is probabilistically the same for every
    honest participant.
    """
    positions: dict[bytes, int] = {}
    for vertex in compute_order(graph, leader_stream):
        if vertex.total_position is not None:
            positions[vertex.message_digest] = vertex.total_position
    return positions
|
||||||
231
src/crisis/rounds.py
Normal file
231
src/crisis/rounds.py
Normal file
|
|
@ -0,0 +1,231 @@
|
||||||
|
"""
|
||||||
|
Virtual Synchronous Rounds (Section 5.3)
|
||||||
|
|
||||||
|
Lamport graphs represent a timelike order between vertices that we interpret
|
||||||
|
as virtual communication channels. Going one step further, we can think from
|
||||||
|
inside the Lamport graph to define a virtual clock tick as a transition from
|
||||||
|
one vertex to another.
|
||||||
|
|
||||||
|
This simple idea allows for internal synchronism that enables us to execute
|
||||||
|
strongly synchronous agreement protocols like Feldman & Micali's BA*
|
||||||
|
virtually, but without any compromise in external asynchronism.
|
||||||
|
|
||||||
|
Algorithm 5 (Virtual synchronous rounds):
|
||||||
|
The algorithm computes *round numbers* and the *is_last* property
|
||||||
|
of any vertex.
|
||||||
|
|
||||||
|
- The round number is computed by taking the largest round of all
|
||||||
|
direct causes.
|
||||||
|
- If the vertex is a direct effect of a current round vertex with
|
||||||
|
the is_last property, a new round begins.
|
||||||
|
- If the vertex has enough last vertices of the previous round in its
|
||||||
|
past and it is k-reachable from all of them, the vertex becomes a
|
||||||
|
last vertex in its own round.
|
||||||
|
|
||||||
|
Definition 5.1 (k-reachability):
|
||||||
|
v_hat is said to be k-reachable from v, if the overall weight of all
|
||||||
|
vertices in all paths from v to v_hat is greater than k.
|
||||||
|
|
||||||
|
Proposition 5.3 (Round invariance):
|
||||||
|
The round number and is_last property do not depend on the actual
|
||||||
|
Lamport graph, but are the same for equivalent vertices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Vertex
|
||||||
|
from crisis.weight import DifficultyOracle
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rounds(graph: LamportGraph, difficulty: DifficultyOracle,
                   connectivity_k: int = 2) -> None:
    """Run Algorithm 5 over every vertex of the graph.

    Assigns ``v.round`` and ``v.is_last`` to each vertex v.  Vertices
    are visited in causal order (all causes before their effects) so
    that every dependency is already resolved when a vertex is
    processed.

    Args:
        graph: Lamport graph whose vertices are annotated in place.
        difficulty: The difficulty oracle d : N -> W.
        connectivity_k: Connectivity parameter k for k-reachability.
    """
    for vertex in _topological_sort(graph):
        _compute_round_for_vertex(vertex, graph, difficulty, connectivity_k)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_round_for_vertex(vertex: Vertex, graph: LamportGraph,
                              difficulty: DifficultyOracle,
                              connectivity_k: int) -> None:
    """Algorithm 5: compute round number and is_last for a single vertex.

    Mutates ``vertex.round`` and ``vertex.is_last`` in place.  Callers
    must process vertices in causal order so that every direct cause
    already carries its round/is_last values (see compute_rounds).

    Pseudocode from the paper:
        1: procedure ROUND(vertex:v, lamport_graph:G)
        2:   N_v <- {v_hat ∈ G | v -> v_hat}  # direct causes
        3:   r <- max({v_hat.round | v_hat ∈ N_v} ∪ {0})
        4:   if there is a v_hat ∈ N_v with v_hat.is_last and v_hat.round = r then
        5:     v.round <- r + 1
        6:   else
        7:     v.round <- r
        8:   end if
        9:   S_r <- {v_hat ∈ G | v_hat.round = v.round - 1, v_hat.is_last, v_hat ≤_k v}
        10:  if w(S_r) > 3 * d_r then
        11:    v.is_last <- true
        12:  else
        13:    v.is_last <- (r = 0)
        14:  end if
        15: end procedure
    """
    # Step 2: direct causes
    direct_causes = graph.direct_causes(vertex)

    # Step 3: max round of direct causes (default 0 if no causes).
    # A cause whose round was never computed is treated as round 0.
    if direct_causes:
        max_round = max(
            (dc.round if dc.round is not None else 0) for dc in direct_causes
        )
    else:
        max_round = 0

    # Steps 4-8: determine this vertex's round.
    # If any direct cause is a "last vertex" of the current max round,
    # this vertex starts a new round.  Causes with unset round/is_last
    # are skipped by the filter clause rather than counted as False.
    has_last_cause_in_max_round = any(
        dc.is_last and dc.round == max_round
        for dc in direct_causes
        if dc.round is not None and dc.is_last is not None
    )

    if has_last_cause_in_max_round:
        vertex.round = max_round + 1
    else:
        vertex.round = max_round

    # Steps 9-14: determine is_last
    r = vertex.round
    if r == 0:
        # All round-0 vertices are "last" (bootstrapping).  This early
        # return realizes pseudocode step 13, v.is_last <- (r = 0).
        vertex.is_last = True
        return

    # Find last vertices of the previous round that are k-reachable from v.
    # d_r is the difficulty threshold for the current round.
    d_r = difficulty.difficulty(r)

    previous_round_lasts = [
        v_hat for v_hat in graph.all_vertices()
        if v_hat.round == r - 1
        and v_hat.is_last
        and _is_k_reachable(v_hat, vertex, graph, connectivity_k)
    ]

    # Weight of k-reachable last vertices from previous round
    weight_of_previous_lasts = graph.set_weight(previous_round_lasts)

    # Step 10's threshold w(S_r) > 3 * d_r decides the is_last flag.
    if weight_of_previous_lasts > 3 * d_r:
        vertex.is_last = True
    else:
        vertex.is_last = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_k_reachable(v_from: Vertex, v_to: Vertex,
                    graph: LamportGraph, k: int) -> bool:
    """Approximate k-reachability test (Definition 5.1).

    v_to is k-reachable from v_from when the overall weight of all
    vertices on all paths between them exceeds k.  As the paper remarks
    (page 11), counting disjoint paths exactly is computationally
    expensive and unnecessary here — all that is needed is assurance
    that information flowed through enough real-world processes — so
    this PoC uses the total weight of the causal interval between the
    two vertices as a cheaper proxy.
    """
    past_of_to = graph.past(v_to)
    if v_from not in past_of_to:
        return False

    # Causal interval: everything after v_from and before v_to.
    between = past_of_to & graph.future(v_from)
    return graph.set_weight(between) > k
|
||||||
|
|
||||||
|
|
||||||
|
def _topological_sort(graph: LamportGraph) -> list[Vertex]:
|
||||||
|
"""Sort vertices in causal order: causes come before their effects.
|
||||||
|
|
||||||
|
Uses Kahn's algorithm. Vertices with no causes (sources) come first.
|
||||||
|
This ensures that when we process a vertex, all its causes already
|
||||||
|
have their round numbers computed.
|
||||||
|
"""
|
||||||
|
# Compute in-degree (number of causes each vertex has within the graph)
|
||||||
|
in_degree: dict[bytes, int] = {}
|
||||||
|
for d, v in graph.vertices.items():
|
||||||
|
in_degree[d] = 0
|
||||||
|
|
||||||
|
for d, v in graph.vertices.items():
|
||||||
|
for ref_d in graph.edges.get(d, set()):
|
||||||
|
if ref_d in graph.vertices:
|
||||||
|
# ref_d is a cause of d, so d has an additional in-edge
|
||||||
|
# But we want causal order: causes first
|
||||||
|
# edges go from effect -> cause, so we need reverse
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Actually: edges[d] contains the causes of d (d -> cause).
|
||||||
|
# For topological sort where causes come first, we need:
|
||||||
|
# in_degree[d] = number of digests in edges[d] that are in the graph
|
||||||
|
for d in graph.vertices:
|
||||||
|
count = 0
|
||||||
|
for cause_d in graph.edges.get(d, set()):
|
||||||
|
if cause_d in graph.vertices:
|
||||||
|
count += 1
|
||||||
|
in_degree[d] = count
|
||||||
|
|
||||||
|
# Start with vertices that have no causes (in_degree = 0)
|
||||||
|
queue = [d for d, deg in in_degree.items() if deg == 0]
|
||||||
|
result = []
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
current = queue.pop(0)
|
||||||
|
result.append(graph.vertices[current])
|
||||||
|
|
||||||
|
# For each vertex that current is a cause of (reverse edges)
|
||||||
|
for effect_d in graph.reverse_edges.get(current, set()):
|
||||||
|
if effect_d in in_degree:
|
||||||
|
in_degree[effect_d] -= 1
|
||||||
|
if in_degree[effect_d] == 0:
|
||||||
|
queue.append(effect_d)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Queries on computed rounds
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def last_vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]:
    """Collect every vertex of ``round_number`` carrying the is_last flag."""
    lasts: list[Vertex] = []
    for candidate in graph.all_vertices():
        if candidate.round == round_number and candidate.is_last:
            lasts.append(candidate)
    return lasts
|
||||||
|
|
||||||
|
|
||||||
|
def max_round(graph: LamportGraph) -> int:
    """Return the largest assigned round number, or 0 if none is set."""
    return max(
        (v.round for v in graph.all_vertices() if v.round is not None),
        default=0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]:
    """List every vertex whose computed round equals ``round_number``."""
    matches: list[Vertex] = []
    for v_hat in graph.all_vertices():
        if v_hat.round == round_number:
            matches.append(v_hat)
    return matches
|
||||||
527
src/crisis/voting.py
Normal file
527
src/crisis/voting.py
Normal file
|
|
@ -0,0 +1,527 @@
|
||||||
|
"""
|
||||||
|
Virtual Voting, Safe Voting Patterns, and Leader Election (Section 5)
|
||||||
|
|
||||||
|
This module implements the heart of the Crisis protocol: the virtual voting
|
||||||
|
mechanism that achieves total order without ever sending explicit vote messages.
|
||||||
|
|
||||||
|
Key concepts:
|
||||||
|
|
||||||
|
5.5 Virtual Process Sortition & Knowledge Graphs
|
||||||
|
- Knowledge graph (Def 5.8): quotient graph projecting vertices to virtual
|
||||||
|
processes, representing what each process "knows" about others.
|
||||||
|
- Quorum selector (Def 5.11): deterministically chooses a subset of virtual
|
||||||
|
processes for each round -- the quorum that participates in agreement.
|
||||||
|
|
||||||
|
5.6 Safe Voting Pattern
|
||||||
|
- Voting sets (Def 5.12): the set of vertices participating in round s
|
||||||
|
agreement, reachable with connectivity k from vertex v.
|
||||||
|
- Algorithm 6: computes the safe voting pattern -- a nested sequence of
|
||||||
|
rounds where voting took place with appropriately bounded byzantine weight.
|
||||||
|
|
||||||
|
5.7 Local Leader Election
|
||||||
|
- Algorithm 7: virtual leader elections -- an adaptation of Chen, Feldman
|
||||||
|
& Micali's BA* to virtual voting on Lamport graphs.
|
||||||
|
- Three stage types: initial proposal (δ=0), presorting/gradecast (δ∈{1,2}),
|
||||||
|
and BBA* binary agreement (δ≥3) with "coin fixed to 0/1" and "genuine
|
||||||
|
coin flip" sub-stages.
|
||||||
|
|
||||||
|
5.8 Longest Chain Rule
|
||||||
|
- Algorithm 8: maintains the leader stream by keeping only the longest
|
||||||
|
chain of round leaders (similar to Nakamoto's longest chain rule).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from crisis.crypto import digest, least_significant_bit
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, Vertex, Vote, EMPTY_MESSAGE_DIGEST
|
||||||
|
from crisis.rounds import last_vertices_in_round, max_round
|
||||||
|
from crisis.weight import DifficultyOracle
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Knowledge Graph (Definition 5.8)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class KnowledgeGraph:
    """The round s knowledge graph of vertex v (Definition 5.8).

    Given rounds s < r, a Lamport graph G, and v a last message in round r,
    the knowledge graph Π^s_v is the quotient graph G^s_v / ≃_id.

    Each node in the knowledge graph represents a virtual process (identified
    by its id). An edge from process id to id' means that some vertex with
    v.id = id in round s has a vertex with v_hat.id = id' in its past.

    This represents what each virtual process "knows" about others.
    Instances are populated by build_knowledge_graph and consumed by
    select_quorum.
    """
    # Adjacency of the quotient graph: virtual process id -> set of
    # process ids it has an edge to (i.e. "knows about") in round s.
    edges: dict[bytes, set[bytes]] = field(default_factory=dict)
    # Per-process voting weight: id -> combined weight of all round-s
    # vertices that project onto this virtual process.
    weights: dict[bytes, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def build_knowledge_graph(vertex: Vertex, round_s: int,
                          graph: LamportGraph) -> KnowledgeGraph:
    """Construct the round-s knowledge graph Π^s_v for ``vertex``.

    Every round-s vertex in v's past is projected onto its virtual
    process id.  Per-id weights are accumulated with the graph's weight
    system, and an edge id -> id' is recorded whenever a round-s vertex
    of process id directly references another round-s vertex of id'.
    """
    kg = KnowledgeGraph()

    for v_s in graph.past(vertex):
        if v_s.round != round_s:
            continue

        vid = v_s.id
        kg.edges.setdefault(vid, set())
        kg.weights[vid] = graph.weight_system.weight_sum(
            kg.weights.get(vid, 0), graph.vertex_weight(v_s)
        )

        # Quotient edges: only same-round direct causes contribute.
        for cause in graph.direct_causes(v_s):
            if cause.round == round_s:
                kg.edges[vid].add(cause.id)

    return kg
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quorum Selector (Definition 5.11)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def select_quorum(knowledge_graph: KnowledgeGraph, n: int = 3) -> set[bytes]:
    """Select a round quorum from a knowledge graph (Definition 5.11).

    Implements Example 3 (highest voting weight quorum): among the
    weakly connected components of the knowledge graph, pick the one
    with the largest combined voting weight, then keep its ``n``
    heaviest virtual processes.  Restricting the vote to one heavily
    connected component filters byzantine noise that relies on
    partitioning the graph.
    """
    if not knowledge_graph.edges:
        return set()

    process_ids = set(knowledge_graph.edges.keys())

    # Undirected adjacency restricted to known ids, so that the
    # components found below are the *weakly* connected ones.
    adjacency: dict[bytes, set[bytes]] = {pid: set() for pid in process_ids}
    for source, targets in knowledge_graph.edges.items():
        for target in targets:
            if target in process_ids:
                adjacency[source].add(target)
                adjacency[target].add(source)

    seen: set[bytes] = set()
    components: list[set[bytes]] = []
    for root in process_ids:
        if root in seen:
            continue
        component: set[bytes] = set()
        stack = [root]
        while stack:
            pid = stack.pop()
            if pid in seen:
                continue
            seen.add(pid)
            component.add(pid)
            stack.extend(p for p in adjacency[pid] if p not in seen)
        components.append(component)

    def _total_weight(component: set[bytes]) -> int:
        return sum(knowledge_graph.weights.get(pid, 0) for pid in component)

    heaviest_component = max(components, key=_total_weight)

    ranked = sorted(
        heaviest_component,
        key=lambda pid: knowledge_graph.weights.get(pid, 0),
        reverse=True,
    )
    return set(ranked[:n])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Voting Sets (Definition 5.12)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def voting_set(vertex: Vertex, round_s: int, connectivity_k: int,
               graph: LamportGraph) -> set[Vertex]:
    """Compute S_v(s,k): v's round-s voting set (Definition 5.12).

    Members are the last vertices of round ``round_s`` in v's past whose
    virtual process id was chosen by the quorum selector.  An empty set
    is returned when v has no round or when ``round_s`` is not strictly
    earlier than v's round.

    NOTE(review): Definition 5.12 additionally requires each member to
    be (r-s)*k-reachable from v, but ``connectivity_k`` is currently
    unused in this body — confirm whether the reachability filter was
    intentionally dropped in this PoC.
    """
    r = vertex.round
    if r is None or round_s >= r:
        return set()

    # Quorum of virtual processes eligible for round-s voting.
    quorum = select_quorum(build_knowledge_graph(vertex, round_s, graph))

    return {
        v_hat for v_hat in graph.past(vertex)
        if v_hat.round == round_s and v_hat.is_last and v_hat.id in quorum
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Algorithm 6: Safe Voting Pattern (Section 5.6)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def compute_safe_voting_pattern(vertex: Vertex, graph: LamportGraph,
                                difficulty: DifficultyOracle,
                                connectivity_k: int = 2) -> None:
    """Algorithm 6: compute the safe voting pattern for a vertex.

    Mutates ``vertex.svp`` in place.  The safe voting pattern is a
    totally ordered set of round numbers where "safe" voting took
    place. Safe means:
      - The voting set has enough overall weight
      - The svp of all members agree
      - Byzantine weight is bounded

    Pseudocode from the paper:
        1: procedure SVP(vertex:v, lamport_graph:G)
        2:   v.svp <- ∅
        3:   if v.is_last and [safe voting pattern conditions are met] then
        4:     s <- maximum of all such k
        5:     v.svp <- v.svp ∪ {s} for all t ≤ s
        6:   end if
        7: end procedure

    The procedure checks if the current vertex's round qualifies as a new
    entry in the safe voting pattern by verifying weight and agreement
    conditions from its voting set.
    """
    vertex.svp = []

    # Only last vertices of a positive round can carry a pattern.
    if not vertex.is_last or vertex.round is None or vertex.round == 0:
        return

    r = vertex.round

    # Check each previous round for safe voting pattern membership
    for s in range(r):
        d_s = difficulty.difficulty(s)

        # Get voting set for round s
        vs = voting_set(vertex, s, connectivity_k, graph)
        if not vs:
            continue

        total_weight = graph.set_weight(vs)

        # Check if voting weight exceeds threshold (6 * d_s from Eq. 8)
        if total_weight <= 6 * d_s:
            continue

        # Check that all members of the voting set have compatible svp.
        # Compatibility is prefix agreement: two patterns agree when the
        # shorter one is a prefix of the longer.  This pairwise check is
        # O(|vs|^2) comparisons; voting sets are expected to be small.
        svps_agree = True
        for x in vs:
            for y in vs:
                if x.svp != y.svp:
                    # Allow prefix agreement
                    min_len = min(len(x.svp), len(y.svp))
                    if x.svp[:min_len] != y.svp[:min_len]:
                        svps_agree = False
                        break
            if not svps_agree:
                break

        if svps_agree:
            vertex.svp.append(s)

    # svp is a nested sequence: add current round.
    # NOTE(review): r is appended only when at least one earlier round
    # qualified — confirm this matches the paper's intent for a pattern
    # consisting of the current round alone.
    if vertex.svp:
        vertex.svp.append(r)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Initial Vote Function (Definition 5.16, Example 4)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def initial_vote(vertices: set[Vertex], graph: LamportGraph) -> Optional[Message]:
|
||||||
|
"""INITIAL_VOTE: deterministically choose a leader proposal (Def 5.16).
|
||||||
|
|
||||||
|
Example 4 (Highest weight): Choose the underlying message of the highest
|
||||||
|
voting weight vertex. Since we assume it is infeasible to have different
|
||||||
|
vertices of equal weight, this is practically deterministic.
|
||||||
|
|
||||||
|
The initial vote function is a system parameter. Different choices lead
|
||||||
|
to different long-term behavior. Ideally all members of a safe voting
|
||||||
|
pattern would compute the same initial vote.
|
||||||
|
"""
|
||||||
|
if not vertices:
|
||||||
|
return None
|
||||||
|
|
||||||
|
best_vertex = max(vertices, key=lambda v: graph.vertex_weight(v))
|
||||||
|
return best_vertex.m
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Algorithm 7: Virtual Leader Elections (Section 5.7)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def compute_virtual_leader_election(vertex: Vertex, graph: LamportGraph,
                                    difficulty: DifficultyOracle,
                                    connectivity_k: int,
                                    leader_stream: dict[int, list[tuple[int, Message]]]) -> None:
    """Algorithm 7: compute votes for all rounds in v's safe voting pattern.

    The core virtual BA* protocol.  For each element t of ``vertex.svp``
    a vote v.vote(t) = (l, b) is derived, where the stage δ is t's
    position inside the svp:

        δ = 0:  initial leader proposal
        δ = 1:  leader presorting (gradecast step)
        δ = 2:  BBA* initialization (gradecast step)
        δ ≥ 3:  binary agreement; δ mod 3 selects "coin fixed to 0",
                "coin fixed to 1", or a genuine coin flip.

    Every step is entirely virtual — no vote is ever sent to another
    real-world process.  Results are written into ``vertex.vote``, and
    decided leaders are pushed into ``leader_stream``.

    Args:
        vertex: The last vertex whose votes are being computed.
        graph: The Lamport graph the votes are derived from.
        difficulty: The difficulty oracle d : N -> W.
        connectivity_k: Connectivity parameter for voting sets.
        leader_stream: Mutable mapping of round -> decided leaders,
            updated when a stage reaches full agreement.
    """
    if not vertex.svp:
        return

    # Highest round in the pattern: the round this election decides for.
    # (The original re-guarded vertex.svp here; that branch was dead
    # after the early return above.)
    s = max(vertex.svp)

    for delta, t in enumerate(vertex.svp):
        # The stage δ is t's position within the safe voting pattern.
        _compute_vote_for_stage(vertex, t, delta, s, graph, difficulty,
                                connectivity_k, leader_stream)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_vote_for_stage(vertex: Vertex, t: int, delta: int, s: int,
                            graph: LamportGraph, difficulty: DifficultyOracle,
                            connectivity_k: int,
                            leader_stream: dict[int, list[tuple[int, Message]]]) -> None:
    """Compute vertex's vote for a specific stage of the virtual leader election.

    Implements the branching logic of Algorithm 7 (pages 19-20 of the
    paper).  Writes the resulting (message, binary) pair into
    ``vertex.vote[t]``; coin-fixed BBA* stages may also update
    ``leader_stream`` on full agreement.
    """
    # Threshold d_s for the deciding round, voting set for round t,
    # and n = total voting weight of that set.
    d_s = difficulty.difficulty(s)
    vs = voting_set(vertex, t, connectivity_k, graph)
    n = graph.set_weight(vs)

    NON_LEADER = None  # ∅ in the paper

    if delta == 0:
        # Stage 0: Initial leader proposal
        l = initial_vote(vs, graph)
        vertex.vote[t] = Vote(message=l, binary=None)  # (INITIAL_VOTE(S), ⊥)

    elif delta == 1:
        # Stage 1: Leader presorting
        # Find message with highest round-t voting weight in S
        l = _highest_weight_message(vs, graph)

        if l is not None:
            # Check if l has super majority weight: votes for (l, ⊥)
            # must exceed n - d_s.
            l_weight = _vote_weight_for(vs, t, l, None, graph)  # votes for (l, ⊥)
            if l_weight > n - d_s:
                vertex.vote[t] = Vote(message=l, binary=None)  # (l, ⊥)
            else:
                vertex.vote[t] = Vote(message=NON_LEADER, binary=None)  # (∅, ⊥)
        else:
            vertex.vote[t] = Vote(message=NON_LEADER, binary=None)

    elif delta == 2:
        # Stage 2: BBA* initialization (gradecast)
        l = _highest_weight_message(vs, graph)

        if l is not None:
            # Super-majority for (l, ⊥) initializes the binary vote to 0
            # (i.e. "agree on l"); otherwise fall back on the weight of
            # explicit (l, 1) votes.
            l_weight_undecided = _vote_weight_for(vs, t, l, None, graph)
            if l_weight_undecided > n - d_s:
                vertex.vote[t] = Vote(message=l, binary=0)
            else:
                l_weight_1 = _vote_weight_for(vs, t, l, 1, graph)
                if l_weight_1 > d_s:
                    vertex.vote[t] = Vote(message=l, binary=1)
                else:
                    vertex.vote[t] = Vote(message=NON_LEADER, binary=1)
        else:
            vertex.vote[t] = Vote(message=NON_LEADER, binary=1)

    else:
        # Stage δ ≥ 3: Binary agreement (BBA*).  The sub-stage cycles
        # through coin-fixed-to-0, coin-fixed-to-1, genuine coin flip.
        coin_stage = delta % 3
        l = _highest_weight_message(vs, graph)

        if coin_stage == 0:
            # Coin fixed to 0
            _bba_coin_fixed(vertex, t, vs, l, n, d_s, graph,
                            leader_stream, s, fixed_value=0)
        elif coin_stage == 1:
            # Coin fixed to 1
            _bba_coin_fixed(vertex, t, vs, l, n, d_s, graph,
                            leader_stream, s, fixed_value=1)
        else:
            # Genuine coin flip (coin_stage == 2)
            _bba_genuine_coin(vertex, t, vs, l, n, d_s, graph)
|
||||||
|
|
||||||
|
|
||||||
|
def _bba_coin_fixed(vertex: Vertex, t: int, vs: set[Vertex],
                    l: Optional[Message], n: int, d_s: int,
                    graph: LamportGraph,
                    leader_stream: dict[int, list[tuple[int, Message]]],
                    s: int, fixed_value: int) -> None:
    """BBA* stage with the coin fixed to 0 or 1.

    Writes (l, b) into ``vertex.vote[t]``.  A super-majority
    (weight > n - d_s) for the fixed value adopts it; unanimous weight
    (== n) additionally records l as a decided leader for round s in
    ``leader_stream``.  Failing that, a super-majority for the opposite
    value adopts the opposite value; otherwise the vote defaults to the
    fixed coin value.
    """
    other_value = 1 - fixed_value

    if l is not None:
        weight_for_fixed = _vote_weight_for_binary(vs, t, fixed_value, graph)
        if weight_for_fixed > n - d_s:
            vertex.vote[t] = Vote(message=l, binary=fixed_value)
            # If weight = n, we have agreement: update leader stream
            if weight_for_fixed == n:
                _update_leader_stream(leader_stream, l, s)
            return

        weight_for_other = _vote_weight_for_binary(vs, t, other_value, graph)
        if weight_for_other > n - d_s:
            vertex.vote[t] = Vote(message=l, binary=other_value)
            return

    # Default: keep the coin's fixed value.  NOTE(review): when l is
    # None the vote is recorded as (None, fixed_value) — confirm this
    # matches the paper's (∅, b) non-leader vote semantics.
    vertex.vote[t] = Vote(message=l, binary=fixed_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _bba_genuine_coin(vertex: Vertex, t: int, vs: set[Vertex],
                      l: Optional[Message], n: int, d_s: int,
                      graph: LamportGraph) -> None:
    """BBA* stage with a genuine coin flip.

    Writes (l, b) into ``vertex.vote[t]``.  A super-majority
    (weight > n - d_s) for either binary value adopts that value; with
    no super-majority the bit is drawn from a shared pseudo-random
    coin: the least significant bit of the digest of the heaviest
    voting-set member's message, which every honest observer of the
    same voting set computes identically.
    """
    if l is not None:
        weight_0 = _vote_weight_for_binary(vs, t, 0, graph)
        if weight_0 > n - d_s:
            vertex.vote[t] = Vote(message=l, binary=0)
            return

        weight_1 = _vote_weight_for_binary(vs, t, 1, graph)
        if weight_1 > n - d_s:
            vertex.vote[t] = Vote(message=l, binary=1)
            return

    # Genuine coin flip: use LSB of hash of heaviest vertex's message
    if vs:
        heaviest = max(vs, key=lambda v: graph.vertex_weight(v))
        h = heaviest.m.compute_digest()
        b_coin = least_significant_bit(h)
        vertex.vote[t] = Vote(message=l, binary=b_coin)
    else:
        # Empty voting set: fall back to a deterministic 0 bit.
        vertex.vote[t] = Vote(message=l, binary=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper functions for vote weight computation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _highest_weight_message(vs: set[Vertex], graph: LamportGraph) -> Optional[Message]:
|
||||||
|
"""Find the message with the highest voting weight in a set."""
|
||||||
|
if not vs:
|
||||||
|
return None
|
||||||
|
best = max(vs, key=lambda v: graph.vertex_weight(v))
|
||||||
|
return best.m
|
||||||
|
|
||||||
|
|
||||||
|
def _vote_weight_for(vs: set[Vertex], round_t: int,
|
||||||
|
target_msg: Optional[Message], target_binary: Optional[int],
|
||||||
|
graph: LamportGraph) -> int:
|
||||||
|
"""Compute total voting weight for a specific vote (l, b) in a voting set."""
|
||||||
|
total = 0
|
||||||
|
for v in vs:
|
||||||
|
vote = v.vote.get(round_t)
|
||||||
|
if vote is None:
|
||||||
|
continue
|
||||||
|
msg_match = (vote.message is None and target_msg is None) or \
|
||||||
|
(vote.message is not None and target_msg is not None and
|
||||||
|
vote.message.compute_digest() == target_msg.compute_digest())
|
||||||
|
bin_match = vote.binary == target_binary
|
||||||
|
if msg_match and bin_match:
|
||||||
|
total = graph.weight_system.weight_sum(total, graph.vertex_weight(v))
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def _vote_weight_for_binary(vs: set[Vertex], round_t: int,
|
||||||
|
target_binary: int,
|
||||||
|
graph: LamportGraph) -> int:
|
||||||
|
"""Compute total voting weight for a specific binary value in a voting set."""
|
||||||
|
total = 0
|
||||||
|
for v in vs:
|
||||||
|
vote = v.vote.get(round_t)
|
||||||
|
if vote is not None and vote.binary == target_binary:
|
||||||
|
total = graph.weight_system.weight_sum(total, graph.vertex_weight(v))
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Algorithm 8: Longest Chain Rule (Section 5.8)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _update_leader_stream(leader_stream: dict[int, list[tuple[int, Message]]],
                          message: Message, round_number: int) -> None:
    """Algorithm 8: update the leader stream with a new leader candidate.

    The longest chain rule keeps only the chain with the highest deciding
    round for each round leader. When a new round leader is decided at a
    higher deciding round, previous entries with lower deciding rounds are
    replaced.

    Pseudocode:
    1: procedure LONG_CHAIN(set{(uint,MESSAGE)}:S, MESSAGE:m, uint:s)
    2: if there is no (l, t) ∈ S with t > s then
    3: S <- {(l, t) ∈ S | t < s} ∪ (s, m)
    4: end if
    5: return S
    6: end procedure

    Args:
        leader_stream: Mutated in place; maps a round to its list of
            (deciding round, leader message) entries.
        message: The newly decided leader candidate m.
        round_number: Used both as the key into ``leader_stream`` and as
            the deciding round ``s`` compared against stored ``t`` values.

    NOTE(review): in the pseudocode the leader round and the deciding
    round ``s`` are distinct quantities, but this function receives a
    single ``round_number`` playing both roles — confirm callers pass the
    intended value here.
    """
    if round_number not in leader_stream:
        leader_stream[round_number] = []

    entries = leader_stream[round_number]

    # Check if there's already an entry with a higher deciding round
    # (pseudocode line 2: the new candidate then loses and nothing changes).
    has_higher = any(t > round_number for (t, _) in entries)
    if has_higher:
        return

    # Remove entries with lower deciding rounds, add new one.
    # An existing entry with t == round_number is also dropped and replaced,
    # matching the strict t < s filter of pseudocode line 3.
    leader_stream[round_number] = [
        (t, m) for (t, m) in entries if t < round_number
    ] + [(round_number, message)]
|
||||||
190
src/crisis/weight.py
Normal file
190
src/crisis/weight.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
"""
|
||||||
|
Weight Systems (Section 3.1.1)
|
||||||
|
|
||||||
|
Definition 3.1 (Weight system): Let MESSAGE be the metric space of all
|
||||||
|
messages and (W, ≤) a totally ordered set. Then the tuple (W, w, ⊕, c_min)
|
||||||
|
is a *weight system* if w is a function
|
||||||
|
|
||||||
|
w : MESSAGE -> W (Eq. 3)
|
||||||
|
|
||||||
|
that assigns an element of W to any message, c_min ∈ W is a constant called
|
||||||
|
the *weight threshold*, and ⊕ is a function
|
||||||
|
|
||||||
|
⊕ : W × W -> W (Eq. 4)
|
||||||
|
|
||||||
|
called the *weight sum*, such that:
|
||||||
|
|
||||||
|
- Tamper proof: w(m) >= c_min and m_hat ≠ m implies w(m_hat) < c_min
|
||||||
|
with high probability.
|
||||||
|
- Uniqueness: m ≠ m_hat implies w(m) ≠ w(m_hat) with high probability.
|
||||||
|
- Summability: (W, ⊕) is a totally ordered, abelian group.
|
||||||
|
|
||||||
|
The weight w(m) is interpreted as the amount of voting power m holds to
|
||||||
|
influence total order generation.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
1. An abstract WeightSystem protocol
|
||||||
|
2. A concrete Proof-of-Work implementation (Hashcash-style)
|
||||||
|
|
||||||
|
The PoW weight function counts leading zero bits of H(m), similar to Bitcoin's
|
||||||
|
difficulty mechanism (Nakamoto, 2009; Back, 2002).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from crisis.crypto import digest, count_leading_zero_bits
|
||||||
|
from crisis.message import Message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Abstract weight system
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class WeightSystem(Protocol):
    """Protocol defining the weight system interface (Definition 3.1).

    Any concrete weight system must provide:
    - weight(): Compute w(m) for a message
    - threshold: The minimum weight c_min
    - weight_sum(): Compute ⊕ for two weights

    This is a structural (duck-typed) interface: implementations such as
    ProofOfWorkWeight do not need to inherit from it, only to match its
    method signatures.
    """

    @property
    def threshold(self) -> int:
        """c_min: the minimum weight threshold.

        Messages with weight below this are rejected. This prevents Sybil
        attacks by ensuring every message requires a minimum investment.
        """
        ...

    def weight(self, message: Message) -> int:
        """w(m): compute the weight of a message.

        The weight represents the voting power of this message in the
        consensus protocol.
        """
        ...

    def weight_sum(self, a: int, b: int) -> int:
        """⊕: combine two weights.

        Must form a totally ordered, abelian group (Definition 3.1,
        summability). For our purposes, ordinary integer addition suffices.
        """
        ...

    def is_valid_weight(self, message: Message) -> bool:
        """Check whether w(m) >= c_min."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Proof-of-Work weight system
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class ProofOfWorkWeight:
    """A Hashcash-style Proof-of-Work weight system (Definition 3.1 instance).

    The weight of a message is the number of leading zero bits in H(m).
    Finding a message whose hash starts with k zero bits costs roughly 2^k
    hash evaluations on average, so the zero count measures the work
    invested and therefore the voting power of the message — the same
    mechanism as Bitcoin mining. The message's nonce field is ground to
    search for a qualifying hash.

    Attributes:
        min_leading_zeros: c_min — minimum leading zero bits required.
            The default of 1 means at least one leading zero bit
            (about 50% of random hashes qualify).
    """
    min_leading_zeros: int = 1

    @property
    def threshold(self) -> int:
        """c_min: the minimum acceptable weight."""
        return self.min_leading_zeros

    def weight(self, message: Message) -> int:
        """w(m): leading zero bits of H(m).

        More leading zeros = more work performed = higher voting weight.
        """
        return count_leading_zero_bits(message.compute_digest())

    def weight_sum(self, a: int, b: int) -> int:
        """⊕: plain integer addition.

        (Z, +) is a totally ordered abelian group with identity 0, which
        satisfies the summability requirement of Definition 3.1.
        """
        return a + b

    def is_valid_weight(self, message: Message) -> bool:
        """True iff w(m) >= c_min."""
        return self.weight(message) >= self.threshold

    def mine_nonce(self, id_bytes: bytes, digests: tuple[bytes, ...],
                   payload: bytes, target_weight: int | None = None) -> Message:
        """Grind nonces until the resulting message reaches the weight target.

        Tries successive nonce values 0, 1, 2, ... until H(m) has enough
        leading zero bits; expected cost is about 2^target hash evaluations.

        Args:
            id_bytes: The virtual process id for this message.
            digests: Causal acknowledgements (digests of prior messages).
            payload: The application payload.
            target_weight: Minimum weight to achieve. Defaults to c_min.

        Returns:
            The first Message (by increasing nonce) with a valid nonce.
        """
        goal = self.threshold if target_weight is None else target_weight

        candidate = 0
        while True:
            # 8-byte big-endian nonce — assumes NONCE_LENGTH == 8; TODO confirm
            attempt = Message(nonce=candidate.to_bytes(8, "big"),
                              id=id_bytes, digests=digests, payload=payload)
            if self.weight(attempt) >= goal:
                return attempt
            candidate += 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Difficulty Oracle (Section 5.4, Definition 5.2)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
class DifficultyOracle:
    """Round-number-to-difficulty mapping d : N -> W (Definition 5.2).

    The value d_r := d(r) is called the *round r difficulty*. The oracle is
    designed so the overall voting weight admitted per round stays bounded:

        lim sum(w_s^G / d_s) <= 6                                  (Eq. 8)

    for all time parameters t, where w_s^G is the overall voting weight of
    last vertices in round s.

    Following Example 1 of the paper, this PoC uses a fixed constant that
    does not change over time — the simplest possible oracle.

    Attributes:
        constant_difficulty: the fixed value returned for every round.
    """
    constant_difficulty: int = 4

    def difficulty(self, round_number: int) -> int:
        """d(r): the difficulty of round *round_number*.

        Constant for this PoC (paper Example 1). A production system might
        adapt it from observed voting weight, similar to Bitcoin's
        difficulty retargeting.
        """
        # round_number is intentionally ignored: fixed-constant oracle.
        return self.constant_difficulty
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
57
tests/test_crypto.py
Normal file
57
tests/test_crypto.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""Tests for the crypto module (random oracle model)."""
|
||||||
|
|
||||||
|
from crisis.crypto import (
|
||||||
|
digest, digest_hex, verify_digest,
|
||||||
|
least_significant_bit, count_leading_zero_bits,
|
||||||
|
DIGEST_LENGTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_digest_returns_32_bytes():
    """H always emits DIGEST_LENGTH (= 32) bytes."""
    out = digest(b"hello")
    assert DIGEST_LENGTH == 32
    assert len(out) == DIGEST_LENGTH
|
||||||
|
|
||||||
|
|
||||||
|
def test_digest_is_deterministic():
    """Equal inputs must hash to equal digests."""
    first = digest(b"test")
    second = digest(b"test")
    assert first == second
|
||||||
|
|
||||||
|
|
||||||
|
def test_digest_different_inputs_different_outputs():
    """Distinct inputs should not collide."""
    outputs = {digest(b"a"), digest(b"b")}
    assert len(outputs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_digest_hex_matches():
    """digest_hex is exactly the hex encoding of digest."""
    raw = digest(b"hello")
    assert digest_hex(b"hello") == raw.hex()
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_digest():
    """verify_digest accepts the matching preimage and rejects others."""
    fingerprint = digest(b"data")
    assert verify_digest(b"data", fingerprint)
    assert not verify_digest(b"other", fingerprint)
|
||||||
|
|
||||||
|
|
||||||
|
def test_least_significant_bit():
    """LSB is the parity of the final byte: even -> 0, odd -> 1."""
    cases = {
        b"\x00": 0,
        b"\x01": 1,
        b"\x02": 0,
        b"\x03": 1,
        b"\xff": 1,
        b"\xfe": 0,
    }
    for data, expected in cases.items():
        assert least_significant_bit(data) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_leading_zero_bits():
    """Zero bits are counted from the most significant bit downward."""
    expectations = [
        (b"\xff", 0),
        (b"\x7f", 1),
        (b"\x3f", 2),
        (b"\x00\xff", 8),
        (b"\x00\x00\x01", 23),
        (b"\x00", 8),  # an all-zero input counts every bit
    ]
    for data, expected in expectations:
        assert count_leading_zero_bits(data) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_digest_is_well_defined():
    """Paper: 'Acknowledgement of the empty string is defined as H(∅)'."""
    empty_hash = digest(b"")
    assert len(empty_hash) == 32
    # H(∅) must be deterministic like any other digest.
    assert digest(b"") == empty_hash
|
||||||
218
tests/test_graph.py
Normal file
218
tests/test_graph.py
Normal file
|
|
@ -0,0 +1,218 @@
|
||||||
|
"""Tests for the Lamport graph with integrity checks."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from crisis.crypto import digest
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||||
|
from crisis.weight import ProofOfWorkWeight
|
||||||
|
|
||||||
|
|
||||||
|
def make_id(name: str) -> bytes:
    """Derive a deterministic, fixed-width process id from a readable name."""
    full_hash = digest(name.encode())
    return full_hash[:ID_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def make_nonce(n: int = 0) -> bytes:
    """Encode *n* as a fixed-width big-endian nonce."""
    encoded = n.to_bytes(NONCE_LENGTH, "big")
    return encoded
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph(pow_zeros: int = 0) -> LamportGraph:
    """Fresh LamportGraph backed by a PoW weight system of given difficulty."""
    weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros)
    return LamportGraph(weight_system=weight_system)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLamportGraphExtension:
    """Integrity rules enforced by LamportGraph.extend (Algorithm 2)."""

    def test_extend_single_message(self):
        """A well-formed first message is accepted as a vertex."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"hello")
        v = g.extend(msg)
        assert v is not None
        assert g.vertex_count() == 1

    def test_extend_chain(self):
        """Messages from the same id must form a chain."""
        g = make_graph()
        m1 = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"first")
        v1 = g.extend(m1)
        assert v1 is not None

        # The second message acknowledges the first via its digest.
        m2 = Message(
            nonce=make_nonce(1), id=make_id("alice"),
            digests=(m1.compute_digest(),),
            payload=b"second"
        )
        v2 = g.extend(m2)
        assert v2 is not None
        assert g.vertex_count() == 2

    def test_reject_duplicate(self):
        """No two equivalent vertices in the same graph."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"x")
        g.extend(msg)
        v2 = g.extend(msg)
        assert v2 is None  # Rejected: duplicate
        assert g.vertex_count() == 1

    def test_reject_missing_reference(self):
        """Digests must reference existing vertices."""
        g = make_graph()
        fake_digest = digest(b"nonexistent")
        msg = Message(
            nonce=make_nonce(), id=make_id("alice"),
            digests=(fake_digest,), payload=b"orphan"
        )
        v = g.extend(msg)
        assert v is None  # Rejected

    def test_reject_broken_chain(self):
        """Second message from same id must reference a same-id vertex."""
        g = make_graph()
        id_a = make_id("alice")
        id_b = make_id("bob")

        m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"first")
        g.extend(m1)

        m_bob = Message(nonce=make_nonce(1), id=id_b, payload=b"bob's msg")
        g.extend(m_bob)

        # Alice's second message references bob but not herself -> rejected
        m2 = Message(
            nonce=make_nonce(2), id=id_a,
            digests=(m_bob.compute_digest(),),
            payload=b"broken chain"
        )
        v = g.extend(m2)
        assert v is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCausality:
    """Causal-order queries exercised on a small 3-message chain."""

    def _build_chain(self):
        """Build a simple 3-message chain: m1 <- m2 <- m3."""
        g = make_graph()
        id_a = make_id("alice")
        m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"m1")
        v1 = g.extend(m1)

        m2 = Message(nonce=make_nonce(1), id=id_a,
                     digests=(m1.compute_digest(),), payload=b"m2")
        v2 = g.extend(m2)

        m3 = Message(nonce=make_nonce(2), id=id_a,
                     digests=(m2.compute_digest(),), payload=b"m3")
        v3 = g.extend(m3)

        return g, v1, v2, v3

    def test_direct_causes(self):
        """Only the directly acknowledged vertex is a direct cause."""
        g, v1, v2, v3 = self._build_chain()
        causes_of_v3 = g.direct_causes(v3)
        assert v2 in causes_of_v3
        assert v1 not in causes_of_v3

    def test_direct_effects(self):
        """Only the immediately acknowledging vertex is a direct effect."""
        g, v1, v2, v3 = self._build_chain()
        effects_of_v1 = g.direct_effects(v1)
        assert v2 in effects_of_v1
        assert v3 not in effects_of_v1  # v3 is indirect

    def test_past(self):
        """G_v: the past of v contains all its causes."""
        g, v1, v2, v3 = self._build_chain()
        past_of_v3 = g.past(v3)
        assert v1 in past_of_v3
        assert v2 in past_of_v3
        assert v3 in past_of_v3  # reflexive

    def test_future(self):
        """The future of v contains all transitive effects of v."""
        g, v1, v2, v3 = self._build_chain()
        future_of_v1 = g.future(v1)
        assert v2 in future_of_v1
        assert v3 in future_of_v1
        assert v1 in future_of_v1  # reflexive

    def test_is_cause_of(self):
        """Causality is transitive and antisymmetric."""
        g, v1, v2, v3 = self._build_chain()
        assert g.is_cause_of(v1, v3)
        assert g.is_cause_of(v1, v2)
        assert not g.is_cause_of(v3, v1)

    def test_timelike(self):
        """Causally related vertices are timelike; the relation is symmetric."""
        g, v1, v2, v3 = self._build_chain()
        assert g.are_timelike(v1, v3)
        assert g.are_timelike(v3, v1)

    def test_spacelike(self):
        """Two independent vertices are spacelike."""
        g = make_graph()
        m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a")
        m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b")
        va = g.extend(m_a)
        vb = g.extend(m_b)
        assert g.are_spacelike(va, vb)
        assert not g.are_timelike(va, vb)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvarianceOfThePast:
    """Theorem 3.7: The past of equivalent vertices in two Lamport graphs
    have the same cardinality."""

    def test_past_invariance_simple(self):
        """Same message in two different graphs has same-size past."""
        g1 = make_graph()
        g2 = make_graph()
        id_a = make_id("alice")

        m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"genesis")
        m2 = Message(nonce=make_nonce(1), id=id_a,
                     digests=(m1.compute_digest(),), payload=b"second")

        # Add to both graphs
        g1.extend(m1)
        v1_in_g1 = g1.extend(m2)

        g2.extend(m1)
        v1_in_g2 = g2.extend(m2)

        # Past should be the same size
        assert len(g1.past(v1_in_g1)) == len(g2.past(v1_in_g2))
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageGeneration:
    """LamportGraph.generate_message (Algorithm 1): new messages must
    acknowledge the generating graph's current state."""

    def test_generate_first_message(self):
        """The very first message of a process is accepted as-is."""
        g = make_graph()
        msg = g.generate_message(make_id("alice"), b"hello")
        v = g.extend(msg)
        assert v is not None
        assert v.payload == b"hello"

    def test_generate_chain(self):
        """Consecutive messages from one id must chain via digests."""
        g = make_graph()
        pid = make_id("alice")
        m1 = g.generate_message(pid, b"first")
        g.extend(m1)

        m2 = g.generate_message(pid, b"second")
        v2 = g.extend(m2)
        assert v2 is not None
        # Second message should reference the first
        assert m1.compute_digest() in m2.digests

    def test_generate_cross_references(self):
        """Messages should reference vertices from other process ids."""
        g = make_graph()
        pid_a = make_id("alice")
        pid_b = make_id("bob")

        m_a = g.generate_message(pid_a, b"alice's msg")
        g.extend(m_a)

        m_b = g.generate_message(pid_b, b"bob's msg")
        g.extend(m_b)

        # Alice's second message should reference bob's message
        m_a2 = g.generate_message(pid_a, b"alice second")
        assert m_b.compute_digest() in m_a2.digests or m_a.compute_digest() in m_a2.digests
|
||||||
125
tests/test_message.py
Normal file
125
tests/test_message.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""Tests for the message and vertex data structures."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from crisis.crypto import digest, DIGEST_LENGTH
|
||||||
|
from crisis.message import (
|
||||||
|
Message, Vertex, Vote,
|
||||||
|
NONCE_LENGTH, ID_LENGTH, NUM_DIGESTS_LENGTH,
|
||||||
|
EMPTY_MESSAGE_DIGEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_id(name: str) -> bytes:
    """Hash a readable name down to a fixed-width process id."""
    return digest(name.encode())[0:ID_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def make_nonce(n: int = 0) -> bytes:
    """Big-endian, fixed-width nonce encoding of *n*."""
    return int.to_bytes(n, NONCE_LENGTH, "big")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessage:
    """Construction, validation, serialization, and hashing of Message."""

    def test_create_minimal_message(self):
        """An empty-payload, no-digest message is valid."""
        msg = Message(nonce=make_nonce(), id=make_id("test"), digests=(), payload=b"")
        assert msg.num_digests == 0

    def test_nonce_length_validation(self):
        """A nonce of the wrong length is rejected at construction."""
        with pytest.raises(ValueError, match="nonce"):
            Message(nonce=b"\x00", id=make_id("x"))

    def test_id_length_validation(self):
        """An id of the wrong length is rejected at construction."""
        with pytest.raises(ValueError, match="id"):
            Message(nonce=make_nonce(), id=b"\x00")

    def test_digest_length_validation(self):
        """An acknowledgement digest of the wrong length is rejected."""
        with pytest.raises(ValueError, match="digest"):
            Message(nonce=make_nonce(), id=make_id("x"),
                    digests=(b"\x00",))

    def test_serialize_roundtrip_deterministic(self):
        """serialize() is deterministic for a given message."""
        msg = Message(nonce=make_nonce(42), id=make_id("proc1"),
                      digests=(), payload=b"hello world")
        serialized = msg.serialize()
        assert isinstance(serialized, bytes)
        # Same message serializes the same way
        assert msg.serialize() == serialized

    def test_compute_digest_deterministic(self):
        """compute_digest() is stable and DIGEST_LENGTH bytes long."""
        msg = Message(nonce=make_nonce(), id=make_id("test"), payload=b"data")
        d1 = msg.compute_digest()
        d2 = msg.compute_digest()
        assert d1 == d2
        assert len(d1) == DIGEST_LENGTH

    def test_different_messages_different_digests(self):
        """Messages differing only in the nonce must hash differently."""
        m1 = Message(nonce=make_nonce(1), id=make_id("a"), payload=b"x")
        m2 = Message(nonce=make_nonce(2), id=make_id("a"), payload=b"x")
        assert m1.compute_digest() != m2.compute_digest()

    def test_message_with_digests(self):
        """Acknowledgements are stored and counted as given."""
        parent = Message(nonce=make_nonce(), id=make_id("a"), payload=b"parent")
        child = Message(
            nonce=make_nonce(1), id=make_id("a"),
            digests=(parent.compute_digest(),),
            payload=b"child"
        )
        assert child.num_digests == 1
        assert child.digests[0] == parent.compute_digest()

    def test_message_is_immutable(self):
        """Fields cannot be reassigned after construction."""
        msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"y")
        with pytest.raises(AttributeError):
            msg.nonce = b"\x00" * NONCE_LENGTH
|
||||||
|
|
||||||
|
|
||||||
|
class TestVertex:
    """Vertex: a Message wrapped with mutable per-graph protocol state."""

    def test_vertex_wraps_message(self):
        """A vertex exposes its message's fields transparently."""
        msg = Message(nonce=make_nonce(), id=make_id("proc"), payload=b"data")
        v = Vertex(m=msg)
        assert v.nonce == msg.nonce
        assert v.id == msg.id
        assert v.payload == msg.payload
        assert v.digests == msg.digests

    def test_vertex_default_state(self):
        """All protocol state starts unset on a fresh vertex."""
        msg = Message(nonce=make_nonce(), id=make_id("x"))
        v = Vertex(m=msg)
        assert v.round is None
        assert v.is_last is None
        assert v.svp == []
        assert v.vote == {}
        assert v.total_position is None

    def test_vertex_equivalence(self):
        """Definition 3.3: two vertices are equivalent if v.m = v_hat.m"""
        msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"same")
        v1 = Vertex(m=msg)
        v2 = Vertex(m=msg)
        assert v1.equivalent_to(v2)
        assert v1 == v2
        assert hash(v1) == hash(v2)

    def test_vertex_non_equivalence(self):
        """Different underlying messages -> non-equivalent vertices."""
        m1 = Message(nonce=make_nonce(1), id=make_id("x"))
        m2 = Message(nonce=make_nonce(2), id=make_id("x"))
        v1 = Vertex(m=m1)
        v2 = Vertex(m=m2)
        assert not v1.equivalent_to(v2)
        assert v1 != v2
|
||||||
|
|
||||||
|
|
||||||
|
class TestVote:
    """Vote value objects and the distinguished empty-message digest."""

    def test_vote_undecided(self):
        """An undecided vote renders with the ∅ / ⊥ symbols."""
        v = Vote(message=None, binary=None)
        assert "∅" in repr(v)
        assert "⊥" in repr(v)

    def test_vote_with_message(self):
        """A concrete (message, binary) vote stores its binary component."""
        msg = Message(nonce=make_nonce(), id=make_id("x"))
        v = Vote(message=msg, binary=1)
        assert v.binary == 1

    def test_empty_message_digest(self):
        """The empty-message digest constant equals H of the empty string."""
        assert EMPTY_MESSAGE_DIGEST == digest(b"")
|
||||||
126
tests/test_order.py
Normal file
126
tests/test_order.py
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
"""Tests for total order computation."""
|
||||||
|
|
||||||
|
from crisis.crypto import digest
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
|
||||||
|
from crisis.order import LeaderStream, compute_order, _kahns_total_order
|
||||||
|
from crisis.weight import ProofOfWorkWeight
|
||||||
|
|
||||||
|
|
||||||
|
def make_id(name: str) -> bytes:
    """Deterministic fixed-width process id derived from *name*."""
    hashed = digest(name.encode())
    return hashed[:ID_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def make_nonce(n: int = 0) -> bytes:
    """Fixed-width big-endian encoding of the integer *n*."""
    raw = n.to_bytes(NONCE_LENGTH, "big")
    return raw
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph() -> LamportGraph:
    """Fresh LamportGraph with a zero-difficulty PoW weight system."""
    pow_weights = ProofOfWorkWeight(min_leading_zeros=0)
    return LamportGraph(weight_system=pow_weights)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLeaderStream:
    """Longest chain rule over the leader stream (Algorithm 8).

    ``LeaderStream.update(round, deciding_round, msg)`` keeps, per leader
    round, only the candidate with the highest deciding round.
    """

    def test_empty_stream(self):
        """A fresh stream has no rounds and no leaders."""
        ls = LeaderStream()
        assert ls.max_round() == -1
        assert ls.get_leader(0) is None

    def test_add_leader(self):
        """The first candidate for a round is accepted."""
        ls = LeaderStream()
        msg = Message(nonce=make_nonce(), id=make_id("leader"), payload=b"L")
        updated = ls.update(0, 1, msg)
        assert updated is True
        assert ls.get_leader(0) is msg

    def test_higher_deciding_round_replaces(self):
        """A candidate decided later supersedes the earlier one."""
        ls = LeaderStream()
        m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"old")
        m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"new")

        ls.update(0, 1, m1)
        ls.update(0, 2, m2)

        assert ls.get_leader(0) is m2

    def test_lower_deciding_round_rejected(self):
        """A candidate decided earlier than the stored one is ignored."""
        ls = LeaderStream()
        m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"first")
        m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"late")

        ls.update(0, 5, m1)
        updated = ls.update(0, 3, m2)

        assert updated is False
        assert ls.get_leader(0) is m1

    def test_all_leaders_sorted(self):
        """all_leaders() yields leaders in ascending round order."""
        ls = LeaderStream()
        m0 = Message(nonce=make_nonce(0), id=make_id("l0"), payload=b"r0")
        m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"r1")
        m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"r2")

        # Insert out of order on purpose.
        ls.update(2, 3, m2)
        ls.update(0, 1, m0)
        ls.update(1, 2, m1)

        leaders = ls.all_leaders()
        rounds = [r for r, _ in leaders]
        assert rounds == sorted(rounds)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKahnsAlgorithm:
    """Kahn's topological sort used for total order (Algorithm 9)."""

    def test_empty_input(self):
        """No vertices -> empty order."""
        g = make_graph()
        result = _kahns_total_order([], g)
        assert result == []

    def test_single_vertex(self):
        """A single vertex orders trivially."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"only")
        v = g.extend(msg)
        result = _kahns_total_order([v], g)
        assert result == [v]

    def test_chain_order(self):
        """A chain should be ordered causes-first."""
        g = make_graph()
        pid = make_id("alice")
        m1 = Message(nonce=make_nonce(0), id=pid, payload=b"first")
        v1 = g.extend(m1)

        m2 = Message(nonce=make_nonce(1), id=pid,
                     digests=(m1.compute_digest(),), payload=b"second")
        v2 = g.extend(m2)

        m3 = Message(nonce=make_nonce(2), id=pid,
                     digests=(m2.compute_digest(),), payload=b"third")
        v3 = g.extend(m3)

        result = _kahns_total_order([v1, v2, v3], g)
        # Causes come first
        assert result.index(v1) < result.index(v2)
        assert result.index(v2) < result.index(v3)

    def test_respects_causality(self):
        """Total order must be consistent with causal order."""
        g = make_graph()
        m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a")
        va = g.extend(m_a)

        m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b")
        vb = g.extend(m_b)

        # Carol references both alice and bob
        m_c = Message(
            nonce=make_nonce(1), id=make_id("carol"),
            digests=(m_a.compute_digest(), m_b.compute_digest()),
            payload=b"c"
        )
        vc = g.extend(m_c)

        result = _kahns_total_order([va, vb, vc], g)
        # Carol must come after both alice and bob
        assert result.index(va) < result.index(vc)
        assert result.index(vb) < result.index(vc)
|
||||||
110
tests/test_rounds.py
Normal file
110
tests/test_rounds.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""Tests for virtual synchronous rounds."""
|
||||||
|
|
||||||
|
from crisis.crypto import digest
|
||||||
|
from crisis.graph import LamportGraph
|
||||||
|
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
|
||||||
|
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round, vertices_in_round
|
||||||
|
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||||
|
|
||||||
|
|
||||||
|
def make_id(name: str) -> bytes:
    """Fixed-width process id: the hash of *name*, truncated."""
    return digest(name.encode())[0:ID_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def make_nonce(n: int = 0) -> bytes:
    """Big-endian NONCE_LENGTH-byte encoding of *n*."""
    return int.to_bytes(n, NONCE_LENGTH, "big")
|
||||||
|
|
||||||
|
|
||||||
|
def make_graph() -> LamportGraph:
    """LamportGraph whose PoW weight system accepts every message."""
    trivial_pow = ProofOfWorkWeight(min_leading_zeros=0)
    return LamportGraph(weight_system=trivial_pow)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoundComputation:
    """Virtual synchronous round assignment (Section 5.4)."""

    def test_single_vertex_round_zero(self):
        """A single vertex with no causes is in round 0."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis")
        v = g.extend(msg)
        compute_rounds(g, DifficultyOracle(constant_difficulty=1))
        assert v.round == 0

    def test_single_vertex_is_last(self):
        """Round 0 vertices are always 'last' (bootstrapping)."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("alice"))
        v = g.extend(msg)
        compute_rounds(g, DifficultyOracle(constant_difficulty=1))
        assert v.is_last is True

    def test_chain_grows_rounds(self):
        """A chain of messages should produce increasing round numbers."""
        g = make_graph()
        pid = make_id("alice")
        difficulty = DifficultyOracle(constant_difficulty=0)  # Low difficulty

        # Create a chain
        prev_msg = None
        vertices = []
        for i in range(5):
            digests = (prev_msg.compute_digest(),) if prev_msg else ()
            msg = Message(nonce=make_nonce(i), id=pid, digests=digests, payload=f"msg{i}".encode())
            v = g.extend(msg)
            vertices.append(v)
            prev_msg = msg

        compute_rounds(g, difficulty, connectivity_k=0)

        # All should have round numbers assigned
        for v in vertices:
            assert v.round is not None

        # First vertex is round 0
        assert vertices[0].round == 0

    def test_max_round_empty_graph(self):
        """An empty graph reports max round 0."""
        g = make_graph()
        assert max_round(g) == 0

    def test_max_round_with_vertices(self):
        """A lone round-0 vertex keeps the max round at 0."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("x"))
        g.extend(msg)
        compute_rounds(g, DifficultyOracle(constant_difficulty=1))
        assert max_round(g) == 0

    def test_last_vertices_in_round(self):
        """The single round-0 vertex is also its only 'last' vertex."""
        g = make_graph()
        msg = Message(nonce=make_nonce(), id=make_id("alice"))
        g.extend(msg)
        compute_rounds(g, DifficultyOracle(constant_difficulty=1))
        lasts = last_vertices_in_round(g, 0)
        assert len(lasts) == 1

    def test_multiple_ids_same_round(self):
        """Multiple independent vertices are all in round 0."""
        g = make_graph()
        for name in ["alice", "bob", "carol"]:
            msg = Message(nonce=make_nonce(), id=make_id(name), payload=name.encode())
            g.extend(msg)

        compute_rounds(g, DifficultyOracle(constant_difficulty=1))

        r0 = vertices_in_round(g, 0)
        assert len(r0) == 3

    def test_round_invariance(self):
        """Proposition 5.3: equivalent vertices in different graphs have same round."""
        g1 = make_graph()
        g2 = make_graph()
        difficulty = DifficultyOracle(constant_difficulty=1)

        msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis")
        v1 = g1.extend(msg)
        v2 = g2.extend(msg)

        compute_rounds(g1, difficulty)
        compute_rounds(g2, difficulty)

        assert v1.round == v2.round
        assert v1.is_last == v2.is_last
|
||||||
55
tests/test_simulation.py
Normal file
55
tests/test_simulation.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""Integration test: run the full simulation and verify basic properties."""
|
||||||
|
|
||||||
|
from crisis.demo import Simulation
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimulation:
    """End-to-end checks over the demo Simulation harness."""

    def test_simulation_runs(self):
        """The simulation should complete without errors."""
        sim = Simulation(num_honest=3, num_byzantine=0, seed=42)
        outcome = sim.run(num_steps=5, verbose=False)
        assert len(outcome) == 5

    def test_graphs_grow(self):
        """Each step should add messages to the graphs."""
        sim = Simulation(num_honest=2, seed=42)
        sim.run(num_steps=3, verbose=False)
        assert all(node.graph.vertex_count() > 0 for node in sim.nodes)

    def test_honest_nodes_same_graph_size(self):
        """All honest nodes should have the same number of vertices
        (since all messages are delivered to all nodes)."""
        sim = Simulation(num_honest=3, seed=42)
        sim.run(num_steps=5, verbose=False)
        counts = [node.graph.vertex_count() for node in sim.nodes]
        assert len(set(counts)) <= 1

    def test_rounds_are_computed(self):
        """After running, vertices should have round numbers."""
        sim = Simulation(num_honest=3, seed=42)
        sim.run(num_steps=5, verbose=False)
        for node in sim.nodes:
            assert all(v.round is not None for v in node.graph.all_vertices())

    def test_with_byzantine_node(self):
        """Simulation should handle byzantine nodes without crashing."""
        sim = Simulation(num_honest=3, num_byzantine=1, seed=42)
        outcome = sim.run(num_steps=5, verbose=False)
        assert len(outcome) == 5

    def test_deterministic_with_seed(self):
        """Same seed should produce the same results."""
        first_run = Simulation(num_honest=3, seed=123).run(num_steps=3, verbose=False)
        second_run = Simulation(num_honest=3, seed=123).run(num_steps=3, verbose=False)

        # Step-by-step, both runs emit the same message counts and node states.
        for step_a, step_b in zip(first_run, second_run):
            assert len(step_a["new_messages"]) == len(step_b["new_messages"])
            for state_a, state_b in zip(step_a["node_states"], step_b["node_states"]):
                assert state_a["vertices"] == state_b["vertices"]
|
||||||
78
tests/test_weight.py
Normal file
78
tests/test_weight.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""Tests for the weight system and difficulty oracle."""
|
||||||
|
|
||||||
|
from crisis.crypto import digest
|
||||||
|
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
|
||||||
|
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||||
|
|
||||||
|
|
||||||
|
def make_id(name: str) -> bytes:
    """Derive a deterministic, fixed-length node ID from a readable name."""
    full_digest = digest(name.encode())
    return full_digest[:ID_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def make_nonce(n: int = 0) -> bytes:
    """Serialize *n* into a NONCE_LENGTH-byte big-endian nonce."""
    return int.to_bytes(n, NONCE_LENGTH, "big")
|
||||||
|
|
||||||
|
|
||||||
|
class TestProofOfWorkWeight:
    """Unit tests for the hash-based proof-of-work weight system."""

    def test_weight_is_non_negative(self):
        system = ProofOfWorkWeight(min_leading_zeros=0)
        message = Message(nonce=make_nonce(), id=make_id("x"), payload=b"test")
        assert system.weight(message) >= 0

    def test_weight_sum_is_additive(self):
        system = ProofOfWorkWeight()
        assert system.weight_sum(3, 5) == 8
        assert system.weight_sum(0, 0) == 0

    def test_threshold(self):
        system = ProofOfWorkWeight(min_leading_zeros=2)
        assert system.threshold == 2

    def test_is_valid_weight_with_zero_threshold(self):
        system = ProofOfWorkWeight(min_leading_zeros=0)
        message = Message(nonce=make_nonce(), id=make_id("x"))
        # A zero threshold accepts every message.
        assert system.is_valid_weight(message)

    def test_mine_nonce_finds_valid_message(self):
        system = ProofOfWorkWeight(min_leading_zeros=1)
        mined = system.mine_nonce(
            id_bytes=make_id("miner"),
            digests=(),
            payload=b"test payload",
            target_weight=1
        )
        assert system.weight(mined) >= 1
        assert system.is_valid_weight(mined)

    def test_different_nonces_different_weights(self):
        """Uniqueness property: different messages have different weights (w.h.p.)."""
        system = ProofOfWorkWeight()
        observed = {
            system.weight(Message(nonce=make_nonce(i), id=make_id("x"), payload=b"same"))
            for i in range(20)
        }
        # With overwhelming probability the 20 weights are not all identical.
        assert len(observed) > 1

    def test_tamper_proof(self):
        """Changing a message should change its weight (w.h.p.)."""
        system = ProofOfWorkWeight()
        original = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"original")
        altered = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"tampered")
        # Digest inequality is used as the proxy here: the weights derive from
        # the digests, but weight values themselves can collide (e.g. both
        # digests having zero leading zeros), so only the digests are compared.
        assert original.compute_digest() != altered.compute_digest()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDifficultyOracle:
|
||||||
|
|
||||||
|
def test_constant_difficulty(self):
|
||||||
|
d = DifficultyOracle(constant_difficulty=5)
|
||||||
|
assert d.difficulty(0) == 5
|
||||||
|
assert d.difficulty(100) == 5
|
||||||
|
assert d.difficulty(999) == 5
|
||||||
|
|
||||||
|
def test_default_difficulty(self):
|
||||||
|
d = DifficultyOracle()
|
||||||
|
assert d.difficulty(0) == 4
|
||||||
Loading…
Reference in a new issue