mirror of
https://github.com/saymrwulf/crisis.git
synced 2026-05-14 20:37:54 +00:00
Initial implementation of the Crisis protocol (Richter, 2019)
Complete Python PoC of "Probabilistically Self Organizing Total Order in Unstructured P2P Networks". Implements all 10 algorithms from the paper: message generation, integrity checks, Lamport graphs, virtual synchronous rounds, safe voting patterns, virtual leader election (BA*), longest chain rule, total order via Kahn's algorithm, and push/pull gossip. Includes simulation harness, full node binary, and 72 passing tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
1df4790fb4
22 changed files with 3987 additions and 0 deletions
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
*.egg
|
||||
|
||||
# Virtual environment
|
||||
.venv/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
BIN
Crisis.mirco-richter-2019.pdf
Normal file
BIN
Crisis.mirco-richter-2019.pdf
Normal file
Binary file not shown.
30
pyproject.toml
Normal file
30
pyproject.toml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
[project]
|
||||
name = "crisis"
|
||||
version = "0.1.0"
|
||||
description = "Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = "CC-BY-4.0"
|
||||
authors = [
|
||||
{ name = "Mirco Richter (paper)", email = "mirco.richter@mailbox.org" },
|
||||
]
|
||||
dependencies = [
|
||||
"networkx>=3.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"rich>=13.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
crisis-node = "crisis.node:main"
|
||||
crisis-demo = "crisis.demo:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
21
src/crisis/__init__.py
Normal file
21
src/crisis/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""
|
||||
Crisis: Probabilistically Self Organizing Total Order in Unstructured P2P Networks
|
||||
|
||||
A Python implementation of the Crisis protocol described by Mirco Richter (2019).
|
||||
|
||||
The protocol achieves total order on messages in fully open, unstructured
|
||||
Peer-to-Peer networks through virtual voting -- votes are never sent explicitly
|
||||
but are deduced from the causal relationships between messages encoded in
|
||||
Lamport graphs.
|
||||
|
||||
Key components:
|
||||
- crypto: Random oracle model (SHA-256 hash function)
|
||||
- message: Message and Vertex data structures
|
||||
- weight: Weight systems (PoW-based Sybil resistance)
|
||||
- graph: Lamport graphs with integrity checking
|
||||
- rounds: Virtual synchronous rounds
|
||||
- voting: Safe voting patterns and virtual leader election (BA*)
|
||||
- order: Total order via leader stream and topological sorting
|
||||
- gossip: Push/pull gossip for member discovery and message dissemination
|
||||
- node: Full Crisis node tying all components together
|
||||
"""
|
||||
104
src/crisis/crypto.py
Normal file
104
src/crisis/crypto.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""
|
||||
Random Oracle Model (Section 2.1)
|
||||
|
||||
We work in the random oracle model, assuming the existence of a cryptographic
|
||||
hash function that behaves like a random oracle:
|
||||
|
||||
H : {0,1}* -> {0,1}^p (Eq. 1)
|
||||
|
||||
We use SHA-256 as our concrete instantiation. H is assumed to be collision-,
|
||||
preimage-, and second-preimage-resistant.
|
||||
|
||||
We call H(b) the *digest* of the binary string b.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from typing import Union
|
||||
|
||||
# The digest length in bytes (SHA-256 produces 32 bytes = 256 bits).
|
||||
DIGEST_LENGTH = 32
|
||||
|
||||
|
||||
def digest(data: Union[bytes, bytearray]) -> bytes:
|
||||
"""Compute the SHA-256 digest of arbitrary binary data.
|
||||
|
||||
This is the core random oracle H used throughout the protocol.
|
||||
Every reference to "the digest of" a message or byte string in the
|
||||
paper maps to this function.
|
||||
|
||||
Returns:
|
||||
32-byte digest (256 bits).
|
||||
"""
|
||||
return hashlib.sha256(data).digest()
|
||||
|
||||
|
||||
def digest_hex(data: Union[bytes, bytearray]) -> str:
|
||||
"""Convenience: return the digest as a hex string for display."""
|
||||
return digest(data).hex()
|
||||
|
||||
|
||||
def verify_digest(data: bytes, expected: bytes) -> bool:
|
||||
"""Check that H(data) equals the expected digest."""
|
||||
return digest(data) == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Least significant bit helper (used in the virtual coin flip, Algorithm 7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def least_significant_bit(h: bytes) -> int:
|
||||
"""Return the least significant bit of a hash value.
|
||||
|
||||
Used in Algorithm 7 (virtual leader election) for the "genuine coin flip"
|
||||
stage, where the LSB of H(v_hat.m) determines the binary vote.
|
||||
|
||||
The paper defines:
|
||||
b_coin := lsb(H(x.m)) for max weight x in S
|
||||
"""
|
||||
return h[-1] & 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proof-of-Work helpers (used by the weight system, Section 3.1.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def count_leading_zero_bits(h: bytes) -> int:
|
||||
"""Count the number of leading zero bits in a hash value.
|
||||
|
||||
This is the standard measure of proof-of-work difficulty: a hash with
|
||||
k leading zero bits required roughly 2^k hash evaluations to find.
|
||||
"""
|
||||
count = 0
|
||||
for byte in h:
|
||||
if byte == 0:
|
||||
count += 8
|
||||
else:
|
||||
# Count leading zeros in this byte
|
||||
count += (byte ^ 0xFF).bit_length() - (255 - byte).bit_length()
|
||||
# Simpler: count leading zeros via bit tricks
|
||||
for bit_pos in range(7, -1, -1):
|
||||
if byte & (1 << bit_pos):
|
||||
return count
|
||||
count += 1
|
||||
break
|
||||
return count
|
||||
|
||||
|
||||
def count_leading_zero_bits(h: bytes) -> int:
|
||||
"""Count the number of leading zero bits in a hash value.
|
||||
|
||||
A hash with k leading zero bits required roughly 2^k evaluations to find.
|
||||
Used by the PoW weight function to assign weight to messages.
|
||||
"""
|
||||
count = 0
|
||||
for byte in h:
|
||||
if byte == 0:
|
||||
count += 8
|
||||
continue
|
||||
# Count leading zeros in this non-zero byte
|
||||
for bit_pos in range(7, -1, -1):
|
||||
if byte & (1 << bit_pos):
|
||||
return count
|
||||
count += 1
|
||||
break
|
||||
return count
|
||||
356
src/crisis/demo.py
Normal file
356
src/crisis/demo.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
"""
|
||||
Demonstration / Simulation Harness
|
||||
|
||||
This module provides a deterministic, single-process simulation of the Crisis
|
||||
protocol with N virtual nodes. It is designed as the foundation for a lecture
|
||||
series: each phase of the protocol can be observed step by step.
|
||||
|
||||
The simulation bypasses the network layer entirely -- messages are delivered
|
||||
directly between in-memory Lamport graphs. This makes the consensus mechanism
|
||||
visible without network noise.
|
||||
|
||||
Usage:
|
||||
python -m crisis.demo # Run the full demo
|
||||
python -m crisis.demo --nodes 5 # 5 honest nodes
|
||||
python -m crisis.demo --byzantine 1 # 1 byzantine node
|
||||
python -m crisis.demo --rounds 10 # Run for 10 message rounds
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from crisis.crypto import digest
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.order import LeaderStream, compute_order
|
||||
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round
|
||||
from crisis.voting import compute_safe_voting_pattern, compute_virtual_leader_election
|
||||
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simulated Node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class SimulatedNode:
|
||||
"""A simulated Crisis node running in-memory.
|
||||
|
||||
Each node has its own Lamport graph and process id. Messages are
|
||||
exchanged by directly sharing Message objects (no serialization needed).
|
||||
"""
|
||||
name: str
|
||||
process_id: bytes
|
||||
graph: LamportGraph
|
||||
leader_stream: LeaderStream = field(default_factory=LeaderStream)
|
||||
is_byzantine: bool = False
|
||||
messages_created: int = 0
|
||||
|
||||
def generate_message(self, payload: str) -> Message:
|
||||
"""Generate a new message from this node."""
|
||||
self.messages_created += 1
|
||||
return self.graph.generate_message(
|
||||
self.process_id,
|
||||
payload.encode(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simulation Engine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Simulation:
|
||||
"""Deterministic simulation of N Crisis nodes.
|
||||
|
||||
Runs the protocol in lock-step rounds:
|
||||
1. Each node generates a message
|
||||
2. Messages are gossiped (delivered to all nodes)
|
||||
3. Consensus is computed on each node
|
||||
4. State is displayed
|
||||
|
||||
This allows observing how the Lamport graph grows, rounds emerge,
|
||||
and total order converges.
|
||||
"""
|
||||
|
||||
def __init__(self, num_honest: int = 3, num_byzantine: int = 0,
|
||||
pow_zeros: int = 0, difficulty: int = 2,
|
||||
connectivity_k: int = 1, seed: int = 42):
|
||||
self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty)
|
||||
self.connectivity_k = connectivity_k
|
||||
self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros)
|
||||
self.seed = seed
|
||||
random.seed(seed)
|
||||
|
||||
# Create nodes
|
||||
self.nodes: list[SimulatedNode] = []
|
||||
for i in range(num_honest):
|
||||
name = f"honest-{i}"
|
||||
pid = digest(name.encode())[:ID_LENGTH]
|
||||
graph = LamportGraph(weight_system=self.weight_system)
|
||||
self.nodes.append(SimulatedNode(
|
||||
name=name, process_id=pid, graph=graph
|
||||
))
|
||||
|
||||
for i in range(num_byzantine):
|
||||
name = f"byzantine-{i}"
|
||||
pid = digest(name.encode())[:ID_LENGTH]
|
||||
graph = LamportGraph(weight_system=self.weight_system)
|
||||
self.nodes.append(SimulatedNode(
|
||||
name=name, process_id=pid, graph=graph, is_byzantine=True
|
||||
))
|
||||
|
||||
self.step_count = 0
|
||||
self.all_messages: list[Message] = []
|
||||
|
||||
def step(self) -> dict:
|
||||
"""Execute one simulation step.
|
||||
|
||||
Returns a dict with step results for display.
|
||||
"""
|
||||
self.step_count += 1
|
||||
step_results = {
|
||||
"step": self.step_count,
|
||||
"new_messages": [],
|
||||
"node_states": [],
|
||||
}
|
||||
|
||||
# Phase 1: Each node generates a message
|
||||
new_messages: list[tuple[SimulatedNode, Message]] = []
|
||||
for node in self.nodes:
|
||||
if node.is_byzantine:
|
||||
msg = self._byzantine_message(node)
|
||||
else:
|
||||
payload = f"step-{self.step_count}-{node.name}"
|
||||
msg = node.generate_message(payload)
|
||||
|
||||
if msg is not None:
|
||||
new_messages.append((node, msg))
|
||||
step_results["new_messages"].append({
|
||||
"from": node.name,
|
||||
"digest": msg.compute_digest().hex()[:12],
|
||||
"weight": self.weight_system.weight(msg),
|
||||
"payload": msg.payload.decode(errors="replace"),
|
||||
})
|
||||
|
||||
# Phase 2: Gossip -- deliver all messages to all nodes
|
||||
for source_node, msg in new_messages:
|
||||
self.all_messages.append(msg)
|
||||
for target_node in self.nodes:
|
||||
# Deliver to all nodes (including source, for consistency)
|
||||
target_node.graph.extend(msg)
|
||||
|
||||
# Also re-deliver older messages that nodes might be missing
|
||||
# (simulates pull gossip catching up)
|
||||
for msg in self.all_messages:
|
||||
for node in self.nodes:
|
||||
node.graph.extend(msg) # extend() is idempotent (integrity check)
|
||||
|
||||
# Phase 3: Compute consensus on each node
|
||||
for node in self.nodes:
|
||||
compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k)
|
||||
|
||||
for vertex in node.graph.all_vertices():
|
||||
if vertex.is_last:
|
||||
compute_safe_voting_pattern(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k
|
||||
)
|
||||
|
||||
leader_dict: dict[int, list[tuple[int, Message]]] = {}
|
||||
for vertex in node.graph.all_vertices():
|
||||
if vertex.svp:
|
||||
compute_virtual_leader_election(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k, leader_dict
|
||||
)
|
||||
|
||||
for round_num, entries in leader_dict.items():
|
||||
for deciding_round, leader_msg in entries:
|
||||
node.leader_stream.update(round_num, deciding_round, leader_msg)
|
||||
|
||||
ordered = compute_order(node.graph, node.leader_stream)
|
||||
|
||||
mr = max_round(node.graph)
|
||||
step_results["node_states"].append({
|
||||
"name": node.name,
|
||||
"vertices": node.graph.vertex_count(),
|
||||
"max_round": mr,
|
||||
"leaders": len(node.leader_stream.leaders),
|
||||
"ordered": len(ordered),
|
||||
"is_byzantine": node.is_byzantine,
|
||||
})
|
||||
|
||||
return step_results
|
||||
|
||||
def _byzantine_message(self, node: SimulatedNode) -> Optional[Message]:
|
||||
"""Generate a byzantine message.
|
||||
|
||||
Byzantine nodes can exhibit several faulty behaviors:
|
||||
- Mutations: same id, forking the causal chain
|
||||
- Strategic distribution: different messages to different peers
|
||||
- Time travel: referencing old rounds
|
||||
|
||||
For this demo, we generate a message with a random payload that
|
||||
may not reference the latest same-id message (creating a mutation).
|
||||
"""
|
||||
payload = f"byz-{self.step_count}-{node.name}-{random.randint(0, 999)}"
|
||||
|
||||
# 50% chance of creating a mutation (not referencing last same-id vertex)
|
||||
if random.random() < 0.5 and node.graph.vertex_count() > 0:
|
||||
# Pick random digests instead of following the chain
|
||||
available = list(node.graph.vertices.keys())
|
||||
num_refs = min(random.randint(1, 3), len(available))
|
||||
digests = tuple(random.sample(available, num_refs))
|
||||
nonce = os.urandom(NONCE_LENGTH)
|
||||
return Message(
|
||||
nonce=nonce, id=node.process_id,
|
||||
digests=digests, payload=payload.encode()
|
||||
)
|
||||
else:
|
||||
return node.generate_message(payload)
|
||||
|
||||
def run(self, num_steps: int = 10, verbose: bool = True) -> list[dict]:
|
||||
"""Run the simulation for a number of steps."""
|
||||
results = []
|
||||
for _ in range(num_steps):
|
||||
result = self.step()
|
||||
results.append(result)
|
||||
if verbose:
|
||||
_print_step(result)
|
||||
|
||||
if verbose:
|
||||
_print_convergence_summary(self)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Display functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _print_step(result: dict) -> None:
|
||||
"""Print the results of a simulation step."""
|
||||
print(f"\n{'='*70}")
|
||||
print(f" Step {result['step']}")
|
||||
print(f"{'='*70}")
|
||||
|
||||
print(f"\n New messages:")
|
||||
for msg in result["new_messages"]:
|
||||
print(f" {msg['from']:>15s} -> {msg['digest']} "
|
||||
f"w={msg['weight']} {msg['payload'][:40]}")
|
||||
|
||||
print(f"\n Node states:")
|
||||
print(f" {'Name':>15s} {'Vertices':>8s} {'Round':>5s} "
|
||||
f"{'Leaders':>7s} {'Ordered':>7s}")
|
||||
print(f" {'-'*15} {'-'*8} {'-'*5} {'-'*7} {'-'*7}")
|
||||
for ns in result["node_states"]:
|
||||
byz = " [BYZ]" if ns["is_byzantine"] else ""
|
||||
print(f" {ns['name']:>15s} {ns['vertices']:>8d} "
|
||||
f"{ns['max_round']:>5d} {ns['leaders']:>7d} "
|
||||
f"{ns['ordered']:>7d}{byz}")
|
||||
|
||||
|
||||
def _print_convergence_summary(sim: Simulation) -> None:
|
||||
"""Print a summary showing whether honest nodes have converged."""
|
||||
print(f"\n{'='*70}")
|
||||
print(f" Convergence Summary")
|
||||
print(f"{'='*70}")
|
||||
|
||||
honest_nodes = [n for n in sim.nodes if not n.is_byzantine]
|
||||
|
||||
# Check if all honest nodes have the same total order
|
||||
orders = []
|
||||
for node in honest_nodes:
|
||||
ordered = compute_order(node.graph, node.leader_stream)
|
||||
order_digests = [v.message_digest.hex()[:12] for v in ordered]
|
||||
orders.append(order_digests)
|
||||
|
||||
if len(orders) >= 2:
|
||||
# Compare pairwise
|
||||
all_agree = all(o == orders[0] for o in orders[1:])
|
||||
if all_agree:
|
||||
print(f"\n All {len(honest_nodes)} honest nodes AGREE "
|
||||
f"on total order ({len(orders[0])} messages)")
|
||||
else:
|
||||
print(f"\n Honest nodes have DIVERGENT total orders "
|
||||
f"(convergence in progress)")
|
||||
for i, (node, order) in enumerate(zip(honest_nodes, orders)):
|
||||
print(f" {node.name}: {len(order)} ordered messages")
|
||||
|
||||
# Show the total order from the first honest node
|
||||
if orders and orders[0]:
|
||||
print(f"\n Total order (from {honest_nodes[0].name}):")
|
||||
first_node = honest_nodes[0]
|
||||
ordered = compute_order(first_node.graph, first_node.leader_stream)
|
||||
for v in ordered[:20]: # Show first 20
|
||||
print(f" pos={v.total_position:>3d} "
|
||||
f"hash={v.message_digest.hex()[:12]} "
|
||||
f"r={v.round} "
|
||||
f"payload={v.payload.decode(errors='replace')[:40]}")
|
||||
if len(ordered) > 20:
|
||||
print(f" ... and {len(ordered) - 20} more")
|
||||
|
||||
# Show leader stream
|
||||
if honest_nodes:
|
||||
ls = honest_nodes[0].leader_stream
|
||||
if ls.leaders:
|
||||
print(f"\n Leader stream ({honest_nodes[0].name}):")
|
||||
for round_num, (dec_round, msg) in sorted(ls.leaders.items()):
|
||||
print(f" round {round_num}: leader={msg.compute_digest().hex()[:12]} "
|
||||
f"decided_in_round={dec_round}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Crisis Protocol Simulation",
|
||||
epilog="Demonstrates probabilistic total order convergence"
|
||||
)
|
||||
parser.add_argument("--nodes", type=int, default=3,
|
||||
help="Number of honest nodes (default: 3)")
|
||||
parser.add_argument("--byzantine", type=int, default=0,
|
||||
help="Number of byzantine nodes (default: 0)")
|
||||
parser.add_argument("--steps", type=int, default=10,
|
||||
help="Number of simulation steps (default: 10)")
|
||||
parser.add_argument("--pow-zeros", type=int, default=0,
|
||||
help="Min PoW leading zeros (default: 0 = no PoW)")
|
||||
parser.add_argument("--difficulty", type=int, default=2,
|
||||
help="Difficulty oracle constant (default: 2)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed for reproducibility (default: 42)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Crisis Protocol Simulation")
|
||||
print(f" Honest nodes: {args.nodes}")
|
||||
print(f" Byzantine nodes: {args.byzantine}")
|
||||
print(f" Steps: {args.steps}")
|
||||
print(f" PoW zeros: {args.pow_zeros}")
|
||||
print(f" Difficulty: {args.difficulty}")
|
||||
print(f" Seed: {args.seed}")
|
||||
|
||||
sim = Simulation(
|
||||
num_honest=args.nodes,
|
||||
num_byzantine=args.byzantine,
|
||||
pow_zeros=args.pow_zeros,
|
||||
difficulty=args.difficulty,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
sim.run(num_steps=args.steps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
414
src/crisis/gossip.py
Normal file
414
src/crisis/gossip.py
Normal file
|
|
@ -0,0 +1,414 @@
|
|||
"""
|
||||
Communication (Section 4)
|
||||
|
||||
Crisis is built on top of two simple push & pull gossip protocols:
|
||||
1. Member discovery gossip (Algorithm 3)
|
||||
2. Message gossip (Algorithm 4)
|
||||
|
||||
These are well suited for communication in unstructured P2P networks.
|
||||
All the system needs is a way to distribute messages in a byzantine-prone
|
||||
environment.
|
||||
|
||||
4.3 Member Discovery Gossip (Algorithm 3):
|
||||
Each process maintains a partial view Π_j(t) of the network.
|
||||
Periodically, a process pushes its neighbor list to a random peer
|
||||
and pulls neighbor lists from other peers.
|
||||
|
||||
4.4 Message Gossip (Algorithm 4):
|
||||
Processes push unordered messages to random peers and pull missing
|
||||
messages. Already ordered messages are pushed only as responses
|
||||
to pull requests (stop criterion for push gossip).
|
||||
|
||||
This module implements both gossip protocols using asyncio for the
|
||||
"run in parallel forever" loops described in the paper.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, Vertex, NONCE_LENGTH, ID_LENGTH
|
||||
from crisis.crypto import DIGEST_LENGTH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Peer identity and network view
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class PeerInfo:
|
||||
"""Information about a known peer in the network."""
|
||||
host: str
|
||||
port: int
|
||||
process_id: bytes = b"" # The peer's virtual process id, if known
|
||||
|
||||
@property
|
||||
def address(self) -> tuple[str, int]:
|
||||
return (self.host, self.port)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.host, self.port))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, PeerInfo):
|
||||
return NotImplemented
|
||||
return self.host == other.host and self.port == other.port
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkView:
|
||||
"""Π_j(t): a process's partial view of the network at time t.
|
||||
|
||||
"No process must know the entire system and each j ∈ Π(t) might
|
||||
have a partial view Π_j(t) only." (Section 4.3)
|
||||
"""
|
||||
peers: set[PeerInfo] = field(default_factory=set)
|
||||
max_peers: int = 50 # Limit to prevent unbounded growth
|
||||
|
||||
def add_peer(self, peer: PeerInfo) -> None:
|
||||
if len(self.peers) < self.max_peers:
|
||||
self.peers.add(peer)
|
||||
|
||||
def remove_peer(self, peer: PeerInfo) -> None:
|
||||
self.peers.discard(peer)
|
||||
|
||||
def random_peer(self) -> Optional[PeerInfo]:
|
||||
if not self.peers:
|
||||
return None
|
||||
return random.choice(list(self.peers))
|
||||
|
||||
def random_subset(self, k: int) -> list[PeerInfo]:
|
||||
peers_list = list(self.peers)
|
||||
return random.sample(peers_list, min(k, len(peers_list)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message serialization for network transport
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serialize_message(message: Message) -> bytes:
|
||||
"""Serialize a Message for network transmission.
|
||||
|
||||
Format: [total_length:4][nonce:8][id:32][num_digests:2][digests...][payload]
|
||||
"""
|
||||
body = message.serialize()
|
||||
length = len(body)
|
||||
return struct.pack("!I", length) + body
|
||||
|
||||
|
||||
def deserialize_message(data: bytes) -> Message:
|
||||
"""Deserialize a Message from network bytes.
|
||||
|
||||
Parses the fixed-size fields and reconstructs the Message object.
|
||||
"""
|
||||
offset = 0
|
||||
nonce = data[offset:offset + NONCE_LENGTH]
|
||||
offset += NONCE_LENGTH
|
||||
|
||||
id_bytes = data[offset:offset + ID_LENGTH]
|
||||
offset += ID_LENGTH
|
||||
|
||||
num_digests = int.from_bytes(data[offset:offset + 2], "big")
|
||||
offset += 2
|
||||
|
||||
digests = []
|
||||
for _ in range(num_digests):
|
||||
d = data[offset:offset + DIGEST_LENGTH]
|
||||
digests.append(d)
|
||||
offset += DIGEST_LENGTH
|
||||
|
||||
payload = data[offset:]
|
||||
|
||||
return Message(
|
||||
nonce=nonce,
|
||||
id=id_bytes,
|
||||
digests=tuple(digests),
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Protocol message types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Simple protocol: 1-byte type prefix
|
||||
MSG_TYPE_PUSH_MESSAGE = b"\x01" # Push a crisis message
|
||||
MSG_TYPE_PULL_REQUEST = b"\x02" # Request missing messages
|
||||
MSG_TYPE_PULL_RESPONSE = b"\x03" # Response with requested messages
|
||||
MSG_TYPE_PEER_PUSH = b"\x04" # Push peer list
|
||||
MSG_TYPE_PEER_PULL = b"\x05" # Request peer list
|
||||
MSG_TYPE_PEER_RESPONSE = b"\x06" # Response with peer list
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gossip Server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GossipServer:
|
||||
"""Asyncio-based gossip server implementing Algorithms 3 and 4.
|
||||
|
||||
Runs two parallel loops:
|
||||
1. Member discovery push & pull (Algorithm 3)
|
||||
2. Message push & pull (Algorithm 4)
|
||||
|
||||
Plus a listener that handles incoming connections.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, port: int, graph: LamportGraph,
|
||||
network_view: NetworkView,
|
||||
push_interval: float = 2.0,
|
||||
discovery_interval: float = 5.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.graph = graph
|
||||
self.network_view = network_view
|
||||
self.push_interval = push_interval
|
||||
self.discovery_interval = discovery_interval
|
||||
self._server: Optional[asyncio.Server] = None
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the gossip server and all gossip loops."""
|
||||
self._running = True
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle_connection, self.host, self.port
|
||||
)
|
||||
logger.info(f"Gossip server listening on {self.host}:{self.port}")
|
||||
|
||||
# Run the gossip loops concurrently (paper: "run in parallel forever")
|
||||
await asyncio.gather(
|
||||
self._server.serve_forever(),
|
||||
self._discovery_push_loop(),
|
||||
self._message_push_loop(),
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the gossip server."""
|
||||
self._running = False
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Algorithm 3: Member discovery push & pull
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _discovery_push_loop(self) -> None:
|
||||
"""Algorithm 3, lines 1-5: periodically push peer list to random peers."""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.discovery_interval)
|
||||
|
||||
peer = self.network_view.random_peer()
|
||||
if peer is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
await self._send_peer_push(peer)
|
||||
await self._send_peer_pull(peer)
|
||||
except (ConnectionError, OSError) as e:
|
||||
logger.debug(f"Discovery push to {peer.address} failed: {e}")
|
||||
self.network_view.remove_peer(peer)
|
||||
|
||||
async def _send_peer_push(self, peer: PeerInfo) -> None:
|
||||
"""Push our peer list to a remote peer."""
|
||||
peer_data = self._encode_peer_list(list(self.network_view.peers))
|
||||
await self._send_to_peer(peer, MSG_TYPE_PEER_PUSH + peer_data)
|
||||
|
||||
async def _send_peer_pull(self, peer: PeerInfo) -> None:
|
||||
"""Request a peer list from a remote peer."""
|
||||
response = await self._send_and_receive(peer, MSG_TYPE_PEER_PULL)
|
||||
if response and response[0:1] == MSG_TYPE_PEER_RESPONSE:
|
||||
new_peers = self._decode_peer_list(response[1:])
|
||||
for p in new_peers:
|
||||
if p.host != self.host or p.port != self.port:
|
||||
self.network_view.add_peer(p)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Algorithm 4: Message gossip push & pull
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _message_push_loop(self) -> None:
|
||||
"""Algorithm 4, lines 1-5: push unordered messages to random peers.
|
||||
|
||||
"Messages are retransmitted via push gossip, only if they don't
|
||||
have a total order yet." (Section 4.4)
|
||||
"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.push_interval)
|
||||
|
||||
peer = self.network_view.random_peer()
|
||||
if peer is None:
|
||||
continue
|
||||
|
||||
# Push messages that don't have total_position yet
|
||||
unordered = [
|
||||
v for v in self.graph.all_vertices()
|
||||
if v.total_position is None
|
||||
]
|
||||
|
||||
if not unordered:
|
||||
continue
|
||||
|
||||
try:
|
||||
for vertex in unordered:
|
||||
msg_bytes = serialize_message(vertex.m)
|
||||
await self._send_to_peer(
|
||||
peer, MSG_TYPE_PUSH_MESSAGE + msg_bytes
|
||||
)
|
||||
except (ConnectionError, OSError) as e:
|
||||
logger.debug(f"Message push to {peer.address} failed: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection handler (incoming)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_connection(self, reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter) -> None:
|
||||
"""Handle an incoming gossip connection.
|
||||
|
||||
Algorithm 3, lines 6-13 (peer data) and Algorithm 4, lines 6-13
|
||||
(message data).
|
||||
"""
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.read(65536), timeout=10.0)
|
||||
if not data:
|
||||
return
|
||||
|
||||
msg_type = data[0:1]
|
||||
payload = data[1:]
|
||||
|
||||
if msg_type == MSG_TYPE_PUSH_MESSAGE:
|
||||
# Received a message: try to extend our Lamport graph
|
||||
self._handle_push_message(payload)
|
||||
|
||||
elif msg_type == MSG_TYPE_PULL_REQUEST:
|
||||
# Someone wants messages: send what we have
|
||||
response = self._handle_pull_request(payload)
|
||||
writer.write(response)
|
||||
await writer.drain()
|
||||
|
||||
elif msg_type == MSG_TYPE_PEER_PUSH:
|
||||
# Received a peer list: update our view
|
||||
new_peers = self._decode_peer_list(payload)
|
||||
for p in new_peers:
|
||||
if p.host != self.host or p.port != self.port:
|
||||
self.network_view.add_peer(p)
|
||||
|
||||
elif msg_type == MSG_TYPE_PEER_PULL:
|
||||
# Someone wants our peer list
|
||||
response = MSG_TYPE_PEER_RESPONSE + self._encode_peer_list(
|
||||
list(self.network_view.peers)
|
||||
)
|
||||
writer.write(response)
|
||||
await writer.drain()
|
||||
|
||||
except (asyncio.TimeoutError, ConnectionError):
|
||||
pass
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
def _handle_push_message(self, data: bytes) -> Optional[Vertex]:
|
||||
"""Process a pushed message: validate and extend graph if valid.
|
||||
|
||||
Algorithm 4, lines 7-8: "if MESSAGE_INTEGRITY(m, G) then
|
||||
expand G with vertex v, such that v.m = m"
|
||||
"""
|
||||
try:
|
||||
# Parse length prefix
|
||||
if len(data) < 4:
|
||||
return None
|
||||
length = struct.unpack("!I", data[:4])[0]
|
||||
msg_data = data[4:4 + length]
|
||||
|
||||
message = deserialize_message(msg_data)
|
||||
return self.graph.extend(message)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to process pushed message: {e}")
|
||||
return None
|
||||
|
||||
def _handle_pull_request(self, data: bytes) -> bytes:
|
||||
"""Respond to a pull request with messages the requester is missing.
|
||||
|
||||
Algorithm 4, lines 10-11: "respond with appropriate set of messages"
|
||||
"""
|
||||
# Data contains a list of digests the requester already has
|
||||
known_digests = set()
|
||||
offset = 0
|
||||
while offset + DIGEST_LENGTH <= len(data):
|
||||
known_digests.add(data[offset:offset + DIGEST_LENGTH])
|
||||
offset += DIGEST_LENGTH
|
||||
|
||||
# Send messages the requester doesn't have
|
||||
response_parts = [MSG_TYPE_PULL_RESPONSE]
|
||||
for d, vertex in self.graph.vertices.items():
|
||||
if d not in known_digests:
|
||||
response_parts.append(serialize_message(vertex.m))
|
||||
return b"".join(response_parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Network I/O helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_to_peer(self, peer: PeerInfo, data: bytes) -> None:
|
||||
"""Send data to a peer (fire-and-forget)."""
|
||||
reader, writer = await asyncio.open_connection(peer.host, peer.port)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
async def _send_and_receive(self, peer: PeerInfo, data: bytes) -> Optional[bytes]:
|
||||
"""Send data and wait for a response."""
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(peer.host, peer.port)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
response = await asyncio.wait_for(reader.read(65536), timeout=5.0)
|
||||
writer.close()
|
||||
return response
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Peer list encoding
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _encode_peer_list(peers: list[PeerInfo]) -> bytes:
|
||||
"""Encode a list of peers as bytes: [count:2][host_len:1][host][port:2]..."""
|
||||
parts = [struct.pack("!H", len(peers))]
|
||||
for peer in peers:
|
||||
host_bytes = peer.host.encode("utf-8")
|
||||
parts.append(struct.pack("!B", len(host_bytes)))
|
||||
parts.append(host_bytes)
|
||||
parts.append(struct.pack("!H", peer.port))
|
||||
return b"".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _decode_peer_list(data: bytes) -> list[PeerInfo]:
|
||||
"""Decode a peer list from bytes."""
|
||||
if len(data) < 2:
|
||||
return []
|
||||
count = struct.unpack("!H", data[:2])[0]
|
||||
offset = 2
|
||||
peers = []
|
||||
for _ in range(count):
|
||||
if offset >= len(data):
|
||||
break
|
||||
host_len = data[offset]
|
||||
offset += 1
|
||||
host = data[offset:offset + host_len].decode("utf-8")
|
||||
offset += host_len
|
||||
port = struct.unpack("!H", data[offset:offset + 2])[0]
|
||||
offset += 2
|
||||
peers.append(PeerInfo(host=host, port=port))
|
||||
return peers
|
||||
479
src/crisis/graph.py
Normal file
479
src/crisis/graph.py
Normal file
|
|
@ -0,0 +1,479 @@
|
|||
"""
|
||||
Lamport Graphs (Section 3.2)
|
||||
|
||||
Lamport graphs represent the causal partial order between messages as a
|
||||
directed acyclic graph. They are the central data structure of the Crisis
|
||||
protocol -- all consensus state is derived from the graph structure.
|
||||
|
||||
Definition 3.5 (Lamport Graph):
|
||||
Let V ⊂ VERTEX be a finite set of vertices, such that all vertices
|
||||
v_hat with v_hat ≤ v for all v ∈ V are in V, but no two vertices in V
|
||||
are equivalent. Then the graph G = (V, A) with (v, v_hat) ∈ A if and
|
||||
only if v -> v_hat is called a *Lamport graph*.
|
||||
|
||||
Key properties:
|
||||
- Directed and acyclic (Proposition 3.6)
|
||||
- The past of a vertex is invariant across Lamport graphs (Theorem 3.7)
|
||||
- No two equivalent vertices exist in the same graph
|
||||
|
||||
This module implements:
|
||||
- Algorithm 1: Message generation
|
||||
- Algorithm 2: Message integrity checking and graph extension
|
||||
- Causality queries (past, future, timelike, spacelike)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from crisis.crypto import digest
|
||||
from crisis.message import Message, Vertex, Vote, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.weight import ProofOfWorkWeight, WeightSystem
|
||||
|
||||
|
||||
class LamportGraph:
|
||||
"""A Lamport graph: a DAG of vertices connected by causal acknowledgement.
|
||||
|
||||
The graph is stored as:
|
||||
- vertices: dict mapping message digest -> Vertex
|
||||
- edges: dict mapping digest -> set of digests it references
|
||||
(i.e. v -> v_hat means v acknowledges v_hat)
|
||||
|
||||
Invariants maintained:
|
||||
- No two vertices have the same underlying message (no equivalence)
|
||||
- All referenced digests either exist in the graph or are the empty digest
|
||||
- The graph is acyclic (guaranteed by hash function properties)
|
||||
"""
|
||||
|
||||
def __init__(self, weight_system: WeightSystem | None = None):
|
||||
self.weight_system: WeightSystem = weight_system or ProofOfWorkWeight(min_leading_zeros=0)
|
||||
|
||||
# digest -> Vertex
|
||||
self.vertices: dict[bytes, Vertex] = {}
|
||||
|
||||
# digest -> set of digests this vertex references (outgoing causal edges)
|
||||
# An edge v -> v_hat means "v acknowledges v_hat" i.e. H(v_hat.m) ∈ v.m.digests
|
||||
self.edges: dict[bytes, set[bytes]] = {}
|
||||
|
||||
# Reverse edges for efficient "future" queries
|
||||
# digest -> set of digests that reference this vertex
|
||||
self.reverse_edges: dict[bytes, set[bytes]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.vertices)
|
||||
|
||||
def __contains__(self, digest_or_vertex) -> bool:
|
||||
if isinstance(digest_or_vertex, Vertex):
|
||||
return digest_or_vertex.message_digest in self.vertices
|
||||
return digest_or_vertex in self.vertices
|
||||
|
||||
def get_vertex(self, msg_digest: bytes) -> Optional[Vertex]:
|
||||
return self.vertices.get(msg_digest)
|
||||
|
||||
def all_vertices(self) -> list[Vertex]:
|
||||
return list(self.vertices.values())
|
||||
|
||||
def vertex_count(self) -> int:
|
||||
return len(self.vertices)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Causality (Definition 3.2)
|
||||
# ------------------------------------------------------------------
|
||||
# m -> m_hat (m happens before m_hat) iff:
|
||||
# - m = m_hat, OR
|
||||
# - there is a chain m -> m1 -> ... -> mk -> m_hat
|
||||
# In our DAG: v has an edge to v_hat means v acknowledges v_hat.
|
||||
# So v is in the *future* of v_hat, and v_hat is in the *past* of v.
|
||||
|
||||
def direct_causes(self, v: Vertex) -> list[Vertex]:
|
||||
"""Return the direct causes of v (vertices that v acknowledges).
|
||||
|
||||
These are the vertices whose digests appear in v.m.digests.
|
||||
In graph terms: the outgoing neighbors of v.
|
||||
"""
|
||||
result = []
|
||||
for d in self.edges.get(v.message_digest, set()):
|
||||
vertex = self.vertices.get(d)
|
||||
if vertex is not None:
|
||||
result.append(vertex)
|
||||
return result
|
||||
|
||||
def direct_effects(self, v: Vertex) -> list[Vertex]:
|
||||
"""Return the direct effects of v (vertices that acknowledge v).
|
||||
|
||||
In graph terms: the incoming neighbors of v (who references v).
|
||||
"""
|
||||
result = []
|
||||
for d in self.reverse_edges.get(v.message_digest, set()):
|
||||
vertex = self.vertices.get(d)
|
||||
if vertex is not None:
|
||||
result.append(vertex)
|
||||
return result
|
||||
|
||||
def past(self, v: Vertex) -> set[Vertex]:
|
||||
"""G_v: the subgraph of G containing all causes of v.
|
||||
|
||||
Definition 3.5: "the subgraph G_v of G that contains all causes
|
||||
of v is called the *past* of v".
|
||||
|
||||
Returns the set of all vertices that are causally before v
|
||||
(including v itself -- reflexivity).
|
||||
"""
|
||||
visited: set[bytes] = set()
|
||||
stack = [v.message_digest]
|
||||
|
||||
while stack:
|
||||
current = stack.pop()
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
for neighbor in self.edges.get(current, set()):
|
||||
if neighbor in self.vertices and neighbor not in visited:
|
||||
stack.append(neighbor)
|
||||
|
||||
return {self.vertices[d] for d in visited if d in self.vertices}
|
||||
|
||||
def future(self, v: Vertex) -> set[Vertex]:
|
||||
"""All vertices that are causally after v (including v itself)."""
|
||||
visited: set[bytes] = set()
|
||||
stack = [v.message_digest]
|
||||
|
||||
while stack:
|
||||
current = stack.pop()
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
for neighbor in self.reverse_edges.get(current, set()):
|
||||
if neighbor in self.vertices and neighbor not in visited:
|
||||
stack.append(neighbor)
|
||||
|
||||
return {self.vertices[d] for d in visited if d in self.vertices}
|
||||
|
||||
def is_cause_of(self, v: Vertex, v_hat: Vertex) -> bool:
|
||||
"""Check if v ≤ v_hat (v is in the past of v_hat).
|
||||
|
||||
Definition 3.4: v is said to happen before v_hat (v ≤ v_hat)
|
||||
if there is a causality chain from v to v_hat.
|
||||
"""
|
||||
if v == v_hat:
|
||||
return True
|
||||
return v in self.past(v_hat)
|
||||
|
||||
def are_timelike(self, v: Vertex, v_hat: Vertex) -> bool:
|
||||
"""Check if v and v_hat are timelike (comparable / causally related)."""
|
||||
return self.is_cause_of(v, v_hat) or self.is_cause_of(v_hat, v)
|
||||
|
||||
def are_spacelike(self, v: Vertex, v_hat: Vertex) -> bool:
|
||||
"""Check if v and v_hat are spacelike (incomparable / no causal relation).
|
||||
|
||||
Spacelike vertices are the ones that need the total order protocol
|
||||
to become comparable. The protocol extends the timelike partial
|
||||
order to cover spacelike vertices as well.
|
||||
"""
|
||||
return not self.are_timelike(v, v_hat)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mutations (Definition 4.2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def find_mutations(self, vertex_id: bytes) -> list[list[Vertex]]:
|
||||
"""Find mutations: vertices with the same id that are spacelike.
|
||||
|
||||
Definition 4.2: Two vertices v and v_hat in G are called a *mutation*
|
||||
of a virtual process if they have the same id and are spacelike,
|
||||
i.e. neither v ≤ v_hat nor v_hat ≤ v holds.
|
||||
|
||||
Mutations are the virtual voting equivalent of equivocation -- a
|
||||
byzantine actor sending different votes to different processes.
|
||||
|
||||
Returns a list of groups of mutually spacelike same-id vertices.
|
||||
"""
|
||||
# Group vertices by id
|
||||
by_id: dict[bytes, list[Vertex]] = {}
|
||||
for v in self.vertices.values():
|
||||
by_id.setdefault(v.id, []).append(v)
|
||||
|
||||
mutations = []
|
||||
for vid, group in by_id.items():
|
||||
if vid != vertex_id:
|
||||
continue
|
||||
# Find spacelike pairs within the group
|
||||
spacelike_group = []
|
||||
for i, v1 in enumerate(group):
|
||||
for v2 in group[i + 1:]:
|
||||
if self.are_spacelike(v1, v2):
|
||||
if v1 not in spacelike_group:
|
||||
spacelike_group.append(v1)
|
||||
if v2 not in spacelike_group:
|
||||
spacelike_group.append(v2)
|
||||
if spacelike_group:
|
||||
mutations.append(spacelike_group)
|
||||
return mutations
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Byte-level correctness (part of Algorithm 2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _bytelevel_correctness(self, message: Message) -> bool:
|
||||
"""BYTELEVEL_CORRECTNESS: basic structural validation of a message.
|
||||
|
||||
Checks that the message has valid field lengths and is well-formed.
|
||||
"""
|
||||
if len(message.nonce) != NONCE_LENGTH:
|
||||
return False
|
||||
if len(message.id) != ID_LENGTH:
|
||||
return False
|
||||
for d in message.digests:
|
||||
if len(d) != 32: # DIGEST_LENGTH
|
||||
return False
|
||||
return True
|
||||
|
||||
def _payload_correctness(self, message: Message) -> bool:
|
||||
"""PAYLOAD_CORRECTNESS: validate the payload against system rules.
|
||||
|
||||
In this PoC, any payload is accepted. A real system would enforce
|
||||
application-specific validation here.
|
||||
"""
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Algorithm 2: Message integrity (Section 4.2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def message_integrity(self, message: Message) -> bool:
|
||||
"""Check whether a message can be validly added to this Lamport graph.
|
||||
|
||||
Algorithm 2 from the paper:
|
||||
|
||||
1. Check BYTELEVEL_CORRECTNESS(m)
|
||||
2. Check w(m) > c_min (weight threshold)
|
||||
3. Check PAYLOAD_CORRECTNESS(m.payload)
|
||||
4. Check no equivalent vertex exists (no vertex with same digest)
|
||||
5. For each digest in m.digests:
|
||||
- It must reference a vertex in G
|
||||
- All referenced vertices must have different id's
|
||||
6. If there is a vertex v in G with v.id = m.id:
|
||||
- One of m.digests must reference v (or a vertex in v's past)
|
||||
Ensures the virtual process forms a chain, not a tree.
|
||||
|
||||
Returns True if the message passes integrity checks.
|
||||
"""
|
||||
# Step 1: byte-level structure
|
||||
if not self._bytelevel_correctness(message):
|
||||
return False
|
||||
|
||||
# Step 2: weight threshold
|
||||
if not self.weight_system.is_valid_weight(message):
|
||||
return False
|
||||
|
||||
# Step 3: payload rules
|
||||
if not self._payload_correctness(message):
|
||||
return False
|
||||
|
||||
msg_digest = message.compute_digest()
|
||||
|
||||
# Step 4: no duplicate (no equivalent vertex)
|
||||
if msg_digest in self.vertices:
|
||||
return False
|
||||
|
||||
# Step 5: all referenced digests must exist in G
|
||||
# and all referenced vertices must have different ids
|
||||
referenced_ids: set[bytes] = set()
|
||||
for ref_digest in message.digests:
|
||||
if ref_digest not in self.vertices:
|
||||
return False
|
||||
ref_vertex = self.vertices[ref_digest]
|
||||
if ref_vertex.id in referenced_ids:
|
||||
return False # Two references to same id
|
||||
referenced_ids.add(ref_vertex.id)
|
||||
|
||||
# Step 6: if same id exists, must reference it (chain constraint)
|
||||
# Find the "last vertex" with this id (not referenced by any other
|
||||
# vertex with the same id)
|
||||
same_id_vertices = [v for v in self.vertices.values() if v.id == message.id]
|
||||
if same_id_vertices:
|
||||
# Check that at least one digest references a same-id vertex
|
||||
referenced_digests = set(message.digests)
|
||||
found_chain_link = False
|
||||
for v in same_id_vertices:
|
||||
if v.message_digest in referenced_digests:
|
||||
found_chain_link = True
|
||||
break
|
||||
if not found_chain_link:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lamport graph extension (Section 4.2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extend(self, message: Message) -> Optional[Vertex]:
|
||||
"""Attempt to extend the Lamport graph with a new message.
|
||||
|
||||
If the message passes integrity checks (Algorithm 2), create a new
|
||||
vertex and add it to the graph with appropriate edges.
|
||||
|
||||
Proposition 4.1 guarantees that the extension of a Lamport graph
|
||||
by a valid message is itself a Lamport graph.
|
||||
|
||||
Returns the new Vertex if successful, None if integrity check fails.
|
||||
"""
|
||||
if not self.message_integrity(message):
|
||||
return None
|
||||
|
||||
vertex = Vertex(m=message)
|
||||
msg_digest = message.compute_digest()
|
||||
|
||||
# Add vertex
|
||||
self.vertices[msg_digest] = vertex
|
||||
|
||||
# Add edges: this vertex -> each referenced vertex
|
||||
self.edges[msg_digest] = set()
|
||||
for ref_digest in message.digests:
|
||||
self.edges[msg_digest].add(ref_digest)
|
||||
# Reverse edge
|
||||
if ref_digest not in self.reverse_edges:
|
||||
self.reverse_edges[ref_digest] = set()
|
||||
self.reverse_edges[ref_digest].add(msg_digest)
|
||||
|
||||
# Initialize reverse_edges entry for this vertex
|
||||
if msg_digest not in self.reverse_edges:
|
||||
self.reverse_edges[msg_digest] = set()
|
||||
|
||||
return vertex
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Algorithm 1: Message generation (Section 4.1)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_message(self, process_id: bytes, payload: bytes,
|
||||
weight_system: WeightSystem | None = None) -> Message:
|
||||
"""Generate a valid message for a given virtual process id.
|
||||
|
||||
Algorithm 1 from the paper:
|
||||
1. Find the last vertex v with v.id = id in G
|
||||
2. Choose S ⊂ {v.m | v ∈ G ∧ v ∉ G_v} such that all have different ids
|
||||
3. Return message with digests = {H(v.m)} ∪ {H(m) | m ∈ S ∪ {v.m}}
|
||||
|
||||
The nonce is chosen so that w(m) > c_min (via mining if PoW).
|
||||
"""
|
||||
ws = weight_system or self.weight_system
|
||||
|
||||
# Find the last vertex with this process id
|
||||
last_vertex = self._find_last_vertex(process_id)
|
||||
|
||||
# Collect digests: last same-id vertex + a sample of other vertices
|
||||
digests_list: list[bytes] = []
|
||||
|
||||
if last_vertex is not None:
|
||||
# Must reference the last vertex with same id
|
||||
digests_list.append(last_vertex.message_digest)
|
||||
|
||||
# Add cross-references to vertices NOT in last_vertex's past
|
||||
past_digests = {v.message_digest for v in self.past(last_vertex)}
|
||||
candidates = [
|
||||
v for d, v in self.vertices.items()
|
||||
if d not in past_digests
|
||||
and v.id != process_id
|
||||
and d != last_vertex.message_digest
|
||||
]
|
||||
|
||||
# Include candidates with different ids
|
||||
seen_ids: set[bytes] = {process_id}
|
||||
for candidate in candidates:
|
||||
if candidate.id not in seen_ids:
|
||||
digests_list.append(candidate.message_digest)
|
||||
seen_ids.add(candidate.id)
|
||||
else:
|
||||
# First message for this id: reference a sample of existing vertices
|
||||
seen_ids = {process_id}
|
||||
for v in self.vertices.values():
|
||||
if v.id not in seen_ids:
|
||||
digests_list.append(v.message_digest)
|
||||
seen_ids.add(v.id)
|
||||
|
||||
digests_tuple = tuple(digests_list)
|
||||
|
||||
# Mine a valid nonce (or just find one that meets threshold)
|
||||
if isinstance(ws, ProofOfWorkWeight):
|
||||
message = ws.mine_nonce(process_id, digests_tuple, payload)
|
||||
else:
|
||||
# For non-PoW systems, use a random nonce
|
||||
nonce = os.urandom(NONCE_LENGTH)
|
||||
message = Message(nonce=nonce, id=process_id, digests=digests_tuple, payload=payload)
|
||||
|
||||
return message
|
||||
|
||||
def _find_last_vertex(self, process_id: bytes) -> Optional[Vertex]:
|
||||
"""Find the last vertex with a given process id.
|
||||
|
||||
A vertex is "last" for an id if no other vertex with the same id
|
||||
references it (i.e. it has no same-id successor).
|
||||
"""
|
||||
same_id = [v for v in self.vertices.values() if v.id == process_id]
|
||||
if not same_id:
|
||||
return None
|
||||
|
||||
# Find the one that is not referenced by any other same-id vertex
|
||||
referenced_by_same_id: set[bytes] = set()
|
||||
for v in same_id:
|
||||
for d in v.digests:
|
||||
ref = self.vertices.get(d)
|
||||
if ref is not None and ref.id == process_id:
|
||||
referenced_by_same_id.add(d)
|
||||
|
||||
for v in same_id:
|
||||
if v.message_digest not in referenced_by_same_id:
|
||||
return v
|
||||
|
||||
# Fallback: return the one added most recently (by convention)
|
||||
return same_id[-1]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Vertices by id (for virtual process queries)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def vertices_by_id(self, process_id: bytes) -> list[Vertex]:
|
||||
"""Return all vertices belonging to a given virtual process id."""
|
||||
return [v for v in self.vertices.values() if v.id == process_id]
|
||||
|
||||
def all_process_ids(self) -> set[bytes]:
|
||||
"""Return all unique virtual process ids in this graph."""
|
||||
return {v.id for v in self.vertices.values()}
|
||||
|
||||
def last_vertices_by_id(self) -> dict[bytes, Vertex]:
|
||||
"""Return the last vertex for each virtual process id."""
|
||||
result = {}
|
||||
for pid in self.all_process_ids():
|
||||
last = self._find_last_vertex(pid)
|
||||
if last is not None:
|
||||
result[pid] = last
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Weight queries
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def vertex_weight(self, v: Vertex) -> int:
|
||||
"""w(v) = w(v.m): the weight of a vertex is the weight of its message."""
|
||||
return self.weight_system.weight(v.m)
|
||||
|
||||
def set_weight(self, vertices: set[Vertex] | list[Vertex]) -> int:
|
||||
"""w(M) := ⊕_{m ∈ M} w(m): the combined weight of a set of vertices."""
|
||||
total = 0
|
||||
for v in vertices:
|
||||
total = self.weight_system.weight_sum(total, self.vertex_weight(v))
|
||||
return total
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Display
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LamportGraph(vertices={len(self.vertices)}, ids={len(self.all_process_ids())})"
|
||||
250
src/crisis/message.py
Normal file
250
src/crisis/message.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""
|
||||
Data Structures (Section 3)
|
||||
|
||||
3.1 Messages
|
||||
-------------
|
||||
Messages distribute payload across the network. The purpose of the protocol
|
||||
is to establish a total order on those messages that respects causality.
|
||||
|
||||
A message is a byte string of variable length with the following structure
|
||||
(paper, page 3):
|
||||
|
||||
struct Message {
|
||||
byte[c1] nonce,
|
||||
byte[c2] id,
|
||||
byte[c3] num_digests,
|
||||
byte[p * num_digests] digests,
|
||||
byte[] payload
|
||||
}
|
||||
|
||||
Where c1, c2, c3 are fixed protocol constants and p is the digest length.
|
||||
|
||||
The *nonce* is used by the weight function (e.g. PoW grinding).
|
||||
The *id* groups messages into virtual processes.
|
||||
The *digests* field encodes causal acknowledgement of other messages.
|
||||
|
||||
Key insight: a message that acknowledges other messages defines an inherent
|
||||
natural causality -- this is the Lamport "happens-before" relation (1978).
|
||||
|
||||
m -> m_hat iff H(m_hat) is contained in m.digests (Eq. 2)
|
||||
|
||||
3.1.3 Vertices
|
||||
---------------
|
||||
To establish total order, messages are extended by local voting data that is
|
||||
NOT transmitted. Votes are deduced from the causal relation between messages.
|
||||
This is the key characteristic of virtual voting (Moser & Melliar-Smith).
|
||||
|
||||
struct Vertex {
|
||||
Message m,
|
||||
Option<uint> round,
|
||||
Option<boolean> is_last,
|
||||
Option<TotalOrderSet<uint>> svp, # safe voting pattern
|
||||
Option<(Message, Option<bool>)> vote,
|
||||
Option<uint> total_position
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from crisis.crypto import digest, DIGEST_LENGTH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Protocol constants (c1, c2, c3 from the paper)
|
||||
# ---------------------------------------------------------------------------
|
||||
# These define the byte-lengths of the fixed-size fields in a message.
|
||||
# Chosen for a practical PoC: generous enough for real use, compact enough
|
||||
# for clarity.
|
||||
|
||||
NONCE_LENGTH = 8 # c1: 8 bytes of nonce (plenty for PoW search space)
|
||||
ID_LENGTH = 32 # c2: 32 bytes for virtual process id (a hash)
|
||||
NUM_DIGESTS_LENGTH = 2 # c3: 2 bytes => up to 65535 referenced digests
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Message:
|
||||
"""An immutable Crisis message as defined in Section 3.1.
|
||||
|
||||
A message is the atomic unit of communication in the Crisis protocol.
|
||||
It carries a payload and encodes causal history through its digests field.
|
||||
|
||||
Attributes:
|
||||
nonce: Used by the weight function (e.g. PoW nonce grinding).
|
||||
id: Groups this message into a virtual process.
|
||||
digests: Tuple of digests of causally prior messages (H values).
|
||||
payload: The actual application data being ordered.
|
||||
"""
|
||||
nonce: bytes
|
||||
id: bytes
|
||||
digests: tuple[bytes, ...] = ()
|
||||
payload: bytes = b""
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.nonce) != NONCE_LENGTH:
|
||||
raise ValueError(f"nonce must be {NONCE_LENGTH} bytes, got {len(self.nonce)}")
|
||||
if len(self.id) != ID_LENGTH:
|
||||
raise ValueError(f"id must be {ID_LENGTH} bytes, got {len(self.id)}")
|
||||
for i, d in enumerate(self.digests):
|
||||
if len(d) != DIGEST_LENGTH:
|
||||
raise ValueError(f"digest[{i}] must be {DIGEST_LENGTH} bytes")
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize this message to a canonical byte string.
|
||||
|
||||
The serialized form is what gets hashed to produce the message's digest.
|
||||
Format: nonce | id | num_digests (2 bytes big-endian) | digests... | payload
|
||||
"""
|
||||
num = len(self.digests)
|
||||
parts = [
|
||||
self.nonce,
|
||||
self.id,
|
||||
num.to_bytes(NUM_DIGESTS_LENGTH, "big"),
|
||||
]
|
||||
for d in self.digests:
|
||||
parts.append(d)
|
||||
parts.append(self.payload)
|
||||
return b"".join(parts)
|
||||
|
||||
def compute_digest(self) -> bytes:
|
||||
"""Compute H(m) -- the digest of this message.
|
||||
|
||||
This is the value other messages include in their digests field
|
||||
to acknowledge this message (establishing causality, Eq. 2).
|
||||
"""
|
||||
return digest(self.serialize())
|
||||
|
||||
@property
|
||||
def num_digests(self) -> int:
|
||||
return len(self.digests)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
h = self.compute_digest().hex()[:12]
|
||||
return f"Message(id={self.id.hex()[:8]}..., digests={self.num_digests}, hash={h}...)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# The empty message (paper: ∅ ∈ MESSAGE)
|
||||
# ---------------------------------------------------------------------------
|
||||
# "We postulate a special non-message ∅ ∈ MESSAGE" (Section 3.1)
|
||||
# Acknowledgement of ∅ is defined as H(empty string).
|
||||
|
||||
EMPTY_MESSAGE_DIGEST = digest(b"")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vote
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Vote:
|
||||
"""A virtual vote as computed locally by each vertex.
|
||||
|
||||
From the paper (Algorithm 7): v.vote(r) = (l, b) describes v's vote
|
||||
on some message l, together with a possibly undecided binary value
|
||||
b ∈ {⊥, 0, 1} in a round r.
|
||||
|
||||
Attributes:
|
||||
message: The message l being voted on (None = ∅, the non-leader).
|
||||
binary: The binary part of the vote: None=⊥ (undecided), 0, or 1.
|
||||
"""
|
||||
message: Optional[Message] = None
|
||||
binary: Optional[int] = None # None = ⊥, 0, or 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg_str = "∅" if self.message is None else self.message.compute_digest().hex()[:8]
|
||||
bin_str = "⊥" if self.binary is None else str(self.binary)
|
||||
return f"Vote({msg_str}, {bin_str})"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vertex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class Vertex:
|
||||
"""A vertex in a Lamport graph (Section 3.1.3).
|
||||
|
||||
A vertex wraps a message and adds locally-computed consensus state.
|
||||
The additional fields (round, is_last, svp, vote, total_position) are
|
||||
never transmitted -- they are deduced from the causal structure.
|
||||
|
||||
From the paper (page 5, Eq. 6):
|
||||
w(v) <- w(v.m)
|
||||
v.nonce <- v.m.nonce
|
||||
v.id <- v.m.id
|
||||
v.num_digests <- v.m.num_digests
|
||||
v.digests <- v.m.digests
|
||||
v.payload <- v.m.payload
|
||||
|
||||
Attributes:
|
||||
m: The underlying message.
|
||||
round: The virtual round number (Algorithm 5).
|
||||
is_last: Whether this is a "last vertex" of its round (Alg 5).
|
||||
svp: Safe voting pattern -- ordered set of round numbers.
|
||||
vote: Per-round votes: round -> Vote.
|
||||
total_position: Final position in the total order (Algorithm 9/10).
|
||||
"""
|
||||
m: Message
|
||||
|
||||
# Locally computed consensus state (initialized to None / ⊥)
|
||||
round: Optional[int] = None
|
||||
is_last: Optional[bool] = None
|
||||
svp: list[int] = field(default_factory=list)
|
||||
vote: dict[int, Vote] = field(default_factory=dict)
|
||||
total_position: Optional[int] = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Convenience accessors that delegate to the underlying message
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def nonce(self) -> bytes:
|
||||
return self.m.nonce
|
||||
|
||||
@property
|
||||
def id(self) -> bytes:
|
||||
return self.m.id
|
||||
|
||||
@property
|
||||
def digests(self) -> tuple[bytes, ...]:
|
||||
return self.m.digests
|
||||
|
||||
@property
|
||||
def payload(self) -> bytes:
|
||||
return self.m.payload
|
||||
|
||||
@property
|
||||
def message_digest(self) -> bytes:
|
||||
"""H(v.m) -- the digest that uniquely identifies this vertex's message."""
|
||||
return self.m.compute_digest()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Equivalence (Definition 3.3)
|
||||
# ------------------------------------------------------------------
|
||||
# "Two vertices v and v_hat are equivalent if v.m = v_hat.m"
|
||||
# i.e. they wrap the same underlying message.
|
||||
|
||||
def equivalent_to(self, other: Vertex) -> bool:
|
||||
"""Check vertex equivalence: same underlying message."""
|
||||
return self.message_digest == other.message_digest
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Vertex):
|
||||
return NotImplemented
|
||||
return self.message_digest == other.message_digest
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.message_digest)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
h = self.message_digest.hex()[:12]
|
||||
round_str = str(self.round) if self.round is not None else "?"
|
||||
last_str = "*" if self.is_last else ""
|
||||
return f"Vertex({h}..., r={round_str}{last_str})"
|
||||
336
src/crisis/node.py
Normal file
336
src/crisis/node.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
"""
|
||||
Crisis Node (Section 5.9 -- The Crisis Protocol)
|
||||
|
||||
This module ties all components together into a full Crisis node.
|
||||
|
||||
From the paper (Section 5.9):
|
||||
"The overall algorithm works as follows: Member discovery (3) and
|
||||
message gossip (4) are executed in infinite loops, concurrently to
|
||||
the rest of the system. Ideally the message sending loop is executed
|
||||
on as many parallel threads as possible. This implies that an overall
|
||||
unbounded amount of new messages arrive over time due to our liveness
|
||||
assumption. In addition each process may generate messages and write
|
||||
them into its own Lamport graph."
|
||||
|
||||
The full node runs these concurrent loops:
|
||||
1. Gossip: member discovery + message dissemination
|
||||
2. Message generation: create new messages with PoW
|
||||
3. Consensus: compute rounds, voting patterns, leader elections, order
|
||||
|
||||
Each loop runs independently and they communicate through the shared
|
||||
Lamport graph.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from crisis.crypto import digest
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.gossip import GossipServer, NetworkView, PeerInfo
|
||||
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.order import LeaderStream, compute_order
|
||||
from crisis.rounds import compute_rounds
|
||||
from crisis.voting import compute_virtual_leader_election, compute_safe_voting_pattern
|
||||
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrisisNode:
|
||||
"""A full Crisis protocol node.
|
||||
|
||||
Combines all protocol components into a single running process:
|
||||
- Lamport graph (the shared DAG)
|
||||
- Weight system (PoW)
|
||||
- Difficulty oracle
|
||||
- Gossip server (member discovery + message dissemination)
|
||||
- Consensus engine (rounds, voting, ordering)
|
||||
|
||||
Attributes:
|
||||
process_id: This node's virtual process identity.
|
||||
graph: The local Lamport graph.
|
||||
leader_stream: The evolving total order leader stream.
|
||||
network_view: Known peers in the network.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 9000,
|
||||
min_pow_zeros: int = 1,
|
||||
difficulty_constant: int = 4,
|
||||
connectivity_k: int = 2,
|
||||
message_interval: float = 3.0,
|
||||
consensus_interval: float = 5.0,
|
||||
seed_peers: list[tuple[str, int]] | None = None):
|
||||
# Identity: use a hash of host:port as this node's virtual process id
|
||||
self.process_id = digest(f"{host}:{port}".encode())[:ID_LENGTH]
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
# Protocol components
|
||||
self.weight_system = ProofOfWorkWeight(min_leading_zeros=min_pow_zeros)
|
||||
self.difficulty = DifficultyOracle(constant_difficulty=difficulty_constant)
|
||||
self.connectivity_k = connectivity_k
|
||||
self.graph = LamportGraph(weight_system=self.weight_system)
|
||||
self.leader_stream = LeaderStream()
|
||||
|
||||
# Timing
|
||||
self.message_interval = message_interval
|
||||
self.consensus_interval = consensus_interval
|
||||
|
||||
# Network
|
||||
self.network_view = NetworkView()
|
||||
if seed_peers:
|
||||
for h, p in seed_peers:
|
||||
self.network_view.add_peer(PeerInfo(host=h, port=p))
|
||||
|
||||
self.gossip = GossipServer(
|
||||
host=host, port=port,
|
||||
graph=self.graph,
|
||||
network_view=self.network_view,
|
||||
)
|
||||
|
||||
# State
|
||||
self._running = False
|
||||
self._message_count = 0
|
||||
|
||||
# Callbacks for monitoring
|
||||
self.on_new_vertex: Optional[callable] = None
|
||||
self.on_round_update: Optional[callable] = None
|
||||
self.on_order_update: Optional[callable] = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Start all protocol loops concurrently.
|
||||
|
||||
This is the Crisis protocol (Section 5.9): three concurrent loops.
|
||||
"""
|
||||
self._running = True
|
||||
logger.info(
|
||||
f"Crisis node starting on {self.host}:{self.port} "
|
||||
f"(id={self.process_id.hex()[:16]}...)"
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.gather(
|
||||
self._gossip_loop(),
|
||||
self._message_generation_loop(),
|
||||
self._consensus_loop(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Crisis node shutting down")
|
||||
finally:
|
||||
self._running = False
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
await self.gossip.stop()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loop 1: Gossip (Algorithms 3 & 4)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _gossip_loop(self) -> None:
|
||||
"""Run the gossip server (member discovery + message dissemination)."""
|
||||
try:
|
||||
await self.gossip.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Gossip loop error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loop 2: Message generation (Algorithm 1)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _message_generation_loop(self) -> None:
|
||||
"""Periodically generate new messages and add them to the graph.
|
||||
|
||||
Each message:
|
||||
1. References the last same-id message (chain constraint)
|
||||
2. References a sample of other vertices (cross-links for connectivity)
|
||||
3. Has a PoW nonce meeting the weight threshold
|
||||
4. Carries an application payload
|
||||
"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.message_interval)
|
||||
|
||||
try:
|
||||
payload = self._generate_payload()
|
||||
message = self.graph.generate_message(
|
||||
self.process_id, payload, self.weight_system
|
||||
)
|
||||
vertex = self.graph.extend(message)
|
||||
|
||||
if vertex is not None:
|
||||
self._message_count += 1
|
||||
logger.debug(
|
||||
f"Generated message #{self._message_count}: {vertex}"
|
||||
)
|
||||
if self.on_new_vertex:
|
||||
self.on_new_vertex(vertex)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Message generation error: {e}")
|
||||
|
||||
def _generate_payload(self) -> bytes:
|
||||
"""Generate a payload for a new message.
|
||||
|
||||
In this PoC, payloads are simple timestamped entries.
|
||||
A real application would put actual data here.
|
||||
"""
|
||||
self._message_count += 1
|
||||
return f"msg-{self._message_count}-{time.time():.3f}".encode()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loop 3: Consensus (Algorithms 5, 6, 7, 9, 10)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _consensus_loop(self) -> None:
|
||||
"""Periodically recompute consensus state.
|
||||
|
||||
From Section 5.9 and the proof section (Section 6):
|
||||
"algorithms (5), (6) and (7) are executed in that order concurrently
|
||||
on each vertex from V... the total order loop (9) runs concurrently
|
||||
and waits for updates of the leader stream."
|
||||
"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.consensus_interval)
|
||||
|
||||
if self.graph.vertex_count() == 0:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Step 1: Compute rounds (Algorithm 5)
|
||||
compute_rounds(self.graph, self.difficulty, self.connectivity_k)
|
||||
|
||||
if self.on_round_update:
|
||||
self.on_round_update(self.graph)
|
||||
|
||||
# Step 2: Compute safe voting patterns (Algorithm 6)
|
||||
for vertex in self.graph.all_vertices():
|
||||
if vertex.is_last:
|
||||
compute_safe_voting_pattern(
|
||||
vertex, self.graph, self.difficulty,
|
||||
self.connectivity_k
|
||||
)
|
||||
|
||||
# Step 3: Virtual leader election (Algorithm 7)
|
||||
leader_dict: dict[int, list[tuple[int, Message]]] = {}
|
||||
for vertex in self.graph.all_vertices():
|
||||
if vertex.svp:
|
||||
compute_virtual_leader_election(
|
||||
vertex, self.graph, self.difficulty,
|
||||
self.connectivity_k, leader_dict
|
||||
)
|
||||
|
||||
# Update leader stream from election results
|
||||
for round_num, entries in leader_dict.items():
|
||||
for deciding_round, leader_msg in entries:
|
||||
self.leader_stream.update(
|
||||
round_num, deciding_round, leader_msg
|
||||
)
|
||||
|
||||
# Step 4: Compute total order (Algorithms 9 & 10)
|
||||
ordered = compute_order(self.graph, self.leader_stream)
|
||||
|
||||
if ordered and self.on_order_update:
|
||||
self.on_order_update(ordered)
|
||||
|
||||
logger.debug(
|
||||
f"Consensus: {self.graph.vertex_count()} vertices, "
|
||||
f"{len(self.leader_stream.leaders)} leaders, "
|
||||
f"{len(ordered)} ordered"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Consensus loop error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def submit_message(self, payload: bytes) -> Optional[Vertex]:
|
||||
"""Submit an application message to be ordered by the protocol."""
|
||||
message = self.graph.generate_message(
|
||||
self.process_id, payload, self.weight_system
|
||||
)
|
||||
return self.graph.extend(message)
|
||||
|
||||
def get_total_order(self) -> list[tuple[int, bytes]]:
|
||||
"""Get the current total order as (position, payload) pairs."""
|
||||
ordered = compute_order(self.graph, self.leader_stream)
|
||||
return [
|
||||
(v.total_position, v.payload)
|
||||
for v in ordered
|
||||
if v.total_position is not None
|
||||
]
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Return a summary of this node's current state."""
|
||||
from crisis.rounds import max_round as get_max_round
|
||||
return {
|
||||
"process_id": self.process_id.hex()[:16],
|
||||
"address": f"{self.host}:{self.port}",
|
||||
"vertices": self.graph.vertex_count(),
|
||||
"process_ids": len(self.graph.all_process_ids()),
|
||||
"max_round": get_max_round(self.graph),
|
||||
"leaders": len(self.leader_stream.leaders),
|
||||
"peers": len(self.network_view.peers),
|
||||
"messages_generated": self._message_count,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
"""Run a Crisis node from the command line."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Crisis Protocol Node",
|
||||
epilog="Probabilistically self-organizing total order in P2P networks"
|
||||
)
|
||||
parser.add_argument("--host", default="127.0.0.1", help="Listen address")
|
||||
parser.add_argument("--port", type=int, default=9000, help="Listen port")
|
||||
parser.add_argument("--pow-zeros", type=int, default=1,
|
||||
help="Min PoW leading zeros (weight threshold)")
|
||||
parser.add_argument("--difficulty", type=int, default=4,
|
||||
help="Difficulty oracle constant")
|
||||
parser.add_argument("--msg-interval", type=float, default=3.0,
|
||||
help="Seconds between message generation")
|
||||
parser.add_argument("--peers", nargs="*", default=[],
|
||||
help="Seed peers as host:port")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
seed_peers = []
|
||||
for peer_str in args.peers:
|
||||
h, p = peer_str.rsplit(":", 1)
|
||||
seed_peers.append((h, int(p)))
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s"
|
||||
)
|
||||
|
||||
node = CrisisNode(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
min_pow_zeros=args.pow_zeros,
|
||||
difficulty_constant=args.difficulty,
|
||||
seed_peers=seed_peers,
|
||||
message_interval=args.msg_interval,
|
||||
)
|
||||
|
||||
asyncio.run(node.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
254
src/crisis/order.py
Normal file
254
src/crisis/order.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""
|
||||
Total Order (Section 5.8)
|
||||
|
||||
As time goes by and the Lamport graph grows, more and more round leaders
|
||||
are computed and incorporated into the global leader stream LEADER_G(·).
|
||||
|
||||
Algorithm 9 (Order loop): watches for leader stream updates and recomputes
|
||||
total order. Total order is achieved by topological sorting on the past
|
||||
of appropriate vertices.
|
||||
|
||||
Algorithm 10 (Total order using Kahn's algorithm): generates total order
|
||||
in linear runtime by processing vertices without outgoing causal edges first,
|
||||
using voting weight to break ties among spacelike vertices.
|
||||
|
||||
The total order converges probabilistically: any two non-byzantine processes
|
||||
will eventually compute the same total order (Proposition 6.21).
|
||||
|
||||
Definition 5.17 (Leader Stream):
|
||||
LEADER_G : N -> Option<(uint, MESSAGE)>
|
||||
is called the *global leader stream* of the Lamport graph.
|
||||
|
||||
Corollary 6.19 (Leader stream convergence):
|
||||
If the probability for new rounds and safe voting pattern is not zero,
|
||||
the leader streams of any two honest processes will converge.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, Vertex
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Leader Stream (Definition 5.17)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LeaderStream:
|
||||
"""The global leader stream of a Lamport graph.
|
||||
|
||||
Maps round numbers to (deciding_round, leader_message) pairs.
|
||||
Uses the Nakamoto longest chain rule: when a new leader is decided
|
||||
in a later round, it may replace leaders decided in earlier rounds.
|
||||
|
||||
The leader stream converges to contain a single element per round
|
||||
(Theorem 6.18), and honest processes' leader streams converge to
|
||||
the same values (Corollary 6.19).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# round_number -> (deciding_round, leader_message)
|
||||
self.leaders: dict[int, tuple[int, Message]] = {}
|
||||
|
||||
def update(self, round_number: int, deciding_round: int,
|
||||
leader_message: Message) -> bool:
|
||||
"""Update the leader for a round via the Nakamoto longest chain rule.
|
||||
|
||||
Algorithm 8 (LONG_CHAIN): keep only the leader decided in the
|
||||
highest round. Delete leaders from previous rounds that have
|
||||
lower deciding rounds.
|
||||
|
||||
Returns True if the leader stream was modified.
|
||||
"""
|
||||
current = self.leaders.get(round_number)
|
||||
|
||||
if current is not None:
|
||||
existing_deciding_round, _ = current
|
||||
if existing_deciding_round >= deciding_round:
|
||||
return False # Already have a leader from a higher round
|
||||
|
||||
self.leaders[round_number] = (deciding_round, leader_message)
|
||||
|
||||
# Prune: remove leaders with lower deciding rounds
|
||||
# (longest chain rule -- keep only the longest)
|
||||
max_deciding = max(dr for dr, _ in self.leaders.values())
|
||||
to_remove = []
|
||||
for r, (dr, _) in self.leaders.items():
|
||||
if dr < max_deciding and r < round_number:
|
||||
to_remove.append(r)
|
||||
for r in to_remove:
|
||||
del self.leaders[r]
|
||||
|
||||
return True
|
||||
|
||||
def get_leader(self, round_number: int) -> Optional[Message]:
|
||||
"""Get the current leader message for a round, if any."""
|
||||
entry = self.leaders.get(round_number)
|
||||
return entry[1] if entry else None
|
||||
|
||||
def max_round(self) -> int:
|
||||
"""Highest round with a decided leader."""
|
||||
return max(self.leaders.keys()) if self.leaders else -1
|
||||
|
||||
def all_leaders(self) -> list[tuple[int, Message]]:
|
||||
"""Return all leaders ordered by round number."""
|
||||
return [(r, msg) for r, (_, msg) in sorted(self.leaders.items())]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rounds = sorted(self.leaders.keys())
|
||||
return f"LeaderStream(rounds={rounds})"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Algorithm 9: Order Loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_order(graph: LamportGraph, leader_stream: LeaderStream) -> list[Vertex]:
|
||||
"""Algorithm 9: compute total order from the leader stream.
|
||||
|
||||
Pseudocode:
|
||||
1: loop order update loop
|
||||
2: wait for LEADER_G(·) to change
|
||||
3: s <- min round of all changed LEADER_G(t)
|
||||
4: r <- max round of all LEADER_G(t) ≠ ∅
|
||||
5: v_{l_r} <- leader in highest round, smallest s in G
|
||||
6: n <- max(v.total_position | v ∈ Ord_G(v_{l_{r-1}}))
|
||||
7: for x ≤ t ≤ r do
|
||||
8: randomly choose (p, l_t) ∈ LEADER_G(t)
|
||||
9: if l_t ≠ ∅ then
|
||||
10: ORDER(Ord_G(v_t), n) ▷ v_t.m = l_t
|
||||
11: end if
|
||||
12: end for
|
||||
13: end loop
|
||||
|
||||
For this PoC, we compute the order in a single pass over the current
|
||||
leader stream state.
|
||||
"""
|
||||
if not leader_stream.leaders:
|
||||
return []
|
||||
|
||||
ordered: list[Vertex] = []
|
||||
position = 0
|
||||
|
||||
# Process leaders in round order
|
||||
for round_number, leader_message in leader_stream.all_leaders():
|
||||
# Find the vertex corresponding to this leader message
|
||||
leader_digest = leader_message.compute_digest()
|
||||
leader_vertex = graph.get_vertex(leader_digest)
|
||||
|
||||
if leader_vertex is None:
|
||||
continue
|
||||
|
||||
# Order the past of this leader vertex (excluding already-ordered)
|
||||
past_vertices = graph.past(leader_vertex)
|
||||
already_ordered = {v.message_digest for v in ordered}
|
||||
new_vertices = [
|
||||
v for v in past_vertices
|
||||
if v.message_digest not in already_ordered
|
||||
]
|
||||
|
||||
# Sort new vertices using Kahn's algorithm (Algorithm 10)
|
||||
sorted_new = _kahns_total_order(new_vertices, graph)
|
||||
|
||||
for v in sorted_new:
|
||||
v.total_position = position
|
||||
ordered.append(v)
|
||||
position += 1
|
||||
|
||||
return ordered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Algorithm 10: Total Order using Kahn's Algorithm
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _kahns_total_order(vertices: list[Vertex], graph: LamportGraph) -> list[Vertex]:
|
||||
"""Algorithm 10: generate total order using Kahn's algorithm.
|
||||
|
||||
Kahn's algorithm in its "arrow reversed" incarnation: we want to order
|
||||
the past before the future in our Lamport graph.
|
||||
|
||||
Pseudocode from the paper:
|
||||
1: procedure ORDER(dag:Ord(v), uint:last)
|
||||
2: n <- last + 1
|
||||
3: S <- set of all elements of Ord(v) with no outgoing edges
|
||||
4: while S ≠ ∅ do
|
||||
5: remove x with highest weight w(x) from S
|
||||
6: x.total_position <- n
|
||||
7: n <- n + 1
|
||||
8: for each vertex y ∈ Ord(v) with edge e : y -> x do
|
||||
9: remove edge e from Ord(v)
|
||||
10: if y has no other outgoing edge then
|
||||
11: S <- S ∪ {y}
|
||||
12: end if
|
||||
13: end for
|
||||
14: end while
|
||||
15: end procedure
|
||||
|
||||
Tie-breaking by voting weight ensures that all honest processes produce
|
||||
the same total order from equivalent Lamport graphs.
|
||||
"""
|
||||
if not vertices:
|
||||
return []
|
||||
|
||||
# Build a local subgraph for just these vertices
|
||||
vertex_set = {v.message_digest for v in vertices}
|
||||
|
||||
# out_degree: for each vertex, count edges to other vertices in this set
|
||||
out_edges: dict[bytes, set[bytes]] = {}
|
||||
in_edges: dict[bytes, set[bytes]] = {}
|
||||
|
||||
for v in vertices:
|
||||
d = v.message_digest
|
||||
out_edges[d] = set()
|
||||
in_edges[d] = set()
|
||||
|
||||
for v in vertices:
|
||||
d = v.message_digest
|
||||
for cause_d in graph.edges.get(d, set()):
|
||||
if cause_d in vertex_set:
|
||||
out_edges[d].add(cause_d)
|
||||
in_edges[cause_d].add(d)
|
||||
|
||||
# Start with vertices that have no outgoing edges (sinks = earliest causes)
|
||||
result: list[Vertex] = []
|
||||
available = [
|
||||
v for v in vertices
|
||||
if len(out_edges[v.message_digest]) == 0
|
||||
]
|
||||
|
||||
while available:
|
||||
# Remove the vertex with highest weight (deterministic tie-breaking)
|
||||
available.sort(key=lambda v: graph.vertex_weight(v), reverse=True)
|
||||
chosen = available.pop(0)
|
||||
result.append(chosen)
|
||||
|
||||
# Remove edges pointing to chosen
|
||||
chosen_d = chosen.message_digest
|
||||
for referrer_d in list(in_edges.get(chosen_d, set())):
|
||||
out_edges[referrer_d].discard(chosen_d)
|
||||
if len(out_edges[referrer_d]) == 0:
|
||||
referrer_vertex = graph.get_vertex(referrer_d)
|
||||
if referrer_vertex is not None and referrer_vertex not in result:
|
||||
available.append(referrer_vertex)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience: full pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def total_order_positions(graph: LamportGraph,
|
||||
leader_stream: LeaderStream) -> dict[bytes, int]:
|
||||
"""Return a mapping of message digest -> total order position.
|
||||
|
||||
This is the final output of the Crisis protocol: a total order on
|
||||
messages that respects causality and is probabilistically invariant
|
||||
among all honest participants.
|
||||
"""
|
||||
ordered = compute_order(graph, leader_stream)
|
||||
return {v.message_digest: v.total_position for v in ordered
|
||||
if v.total_position is not None}
|
||||
231
src/crisis/rounds.py
Normal file
231
src/crisis/rounds.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""
|
||||
Virtual Synchronous Rounds (Section 5.3)
|
||||
|
||||
Lamport graphs represent a timelike order between vertices that we interpret
|
||||
as virtual communication channels. Going one step further, we can think from
|
||||
inside the Lamport graph to define a virtual clock tick as a transition from
|
||||
one vertex to another.
|
||||
|
||||
This simple idea allows for internal synchronism that enables us to execute
|
||||
strongly synchronous agreement protocols like Feldman & Micali's BA*
|
||||
virtually, but without any compromise in external asynchronism.
|
||||
|
||||
Algorithm 5 (Virtual synchronous rounds):
|
||||
The algorithm computes *round numbers* and the *is_last* property
|
||||
of any vertex.
|
||||
|
||||
- The round number is computed by taking the largest round of all
|
||||
direct causes.
|
||||
- If the vertex is a direct effect of a current round vertex with
|
||||
the is_last property, a new round begins.
|
||||
- If the vertex has enough last vertices of the previous round in its
|
||||
past and it is k-reachable from all of them, the vertex becomes a
|
||||
last vertex in its own round.
|
||||
|
||||
Definition 5.1 (k-reachability):
|
||||
v_hat is said to be k-reachable from v, if the overall weight of all
|
||||
vertices in all paths from v to v_hat is greater than k.
|
||||
|
||||
Proposition 5.3 (Round invariance):
|
||||
The round number and is_last property do not depend on the actual
|
||||
Lamport graph, but are the same for equivalent vertices.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Vertex
|
||||
from crisis.weight import DifficultyOracle
|
||||
|
||||
|
||||
def compute_rounds(graph: LamportGraph, difficulty: DifficultyOracle,
|
||||
connectivity_k: int = 2) -> None:
|
||||
"""Execute Algorithm 5 on all vertices in the graph.
|
||||
|
||||
This computes v.round and v.is_last for every vertex v in the graph.
|
||||
The algorithm processes vertices in causal order (causes before effects)
|
||||
to ensure dependencies are resolved before they are needed.
|
||||
|
||||
Args:
|
||||
graph: The Lamport graph to process.
|
||||
difficulty: The difficulty oracle d : N -> W.
|
||||
connectivity_k: The connectivity parameter k for k-reachability.
|
||||
"""
|
||||
# Process vertices in topological order (causes first)
|
||||
ordered = _topological_sort(graph)
|
||||
|
||||
for vertex in ordered:
|
||||
_compute_round_for_vertex(vertex, graph, difficulty, connectivity_k)
|
||||
|
||||
|
||||
def _compute_round_for_vertex(vertex: Vertex, graph: LamportGraph,
|
||||
difficulty: DifficultyOracle,
|
||||
connectivity_k: int) -> None:
|
||||
"""Algorithm 5: compute round number and is_last for a single vertex.
|
||||
|
||||
Pseudocode from the paper:
|
||||
1: procedure ROUND(vertex:v, lamport_graph:G)
|
||||
2: N_v <- {v_hat ∈ G | v -> v_hat} # direct causes
|
||||
3: r <- max({v_hat.round | v_hat ∈ N_v} ∪ {0})
|
||||
4: if there is a v_hat ∈ N_v with v_hat.is_last and v_hat.round = r then
|
||||
5: v.round <- r + 1
|
||||
6: else
|
||||
7: v.round <- r
|
||||
8: end if
|
||||
9: S_r <- {v_hat ∈ G | v_hat.round = v.round - 1, v_hat.is_last, v_hat ≤_k v}
|
||||
10: if w(S_r) > 3 * d_r then
|
||||
11: v.is_last <- true
|
||||
12: else
|
||||
13: v.is_last <- (r = 0)
|
||||
14: end if
|
||||
15: end procedure
|
||||
"""
|
||||
# Step 2: direct causes
|
||||
direct_causes = graph.direct_causes(vertex)
|
||||
|
||||
# Step 3: max round of direct causes (default 0 if no causes)
|
||||
if direct_causes:
|
||||
max_round = max(
|
||||
(dc.round if dc.round is not None else 0) for dc in direct_causes
|
||||
)
|
||||
else:
|
||||
max_round = 0
|
||||
|
||||
# Steps 4-8: determine this vertex's round
|
||||
# If any direct cause is a "last vertex" of the current max round,
|
||||
# this vertex starts a new round.
|
||||
has_last_cause_in_max_round = any(
|
||||
dc.is_last and dc.round == max_round
|
||||
for dc in direct_causes
|
||||
if dc.round is not None and dc.is_last is not None
|
||||
)
|
||||
|
||||
if has_last_cause_in_max_round:
|
||||
vertex.round = max_round + 1
|
||||
else:
|
||||
vertex.round = max_round
|
||||
|
||||
# Steps 9-14: determine is_last
|
||||
r = vertex.round
|
||||
if r == 0:
|
||||
# All round-0 vertices are "last" (bootstrapping)
|
||||
vertex.is_last = True
|
||||
return
|
||||
|
||||
# Find last vertices of the previous round that are k-reachable from v
|
||||
d_r = difficulty.difficulty(r)
|
||||
|
||||
previous_round_lasts = [
|
||||
v_hat for v_hat in graph.all_vertices()
|
||||
if v_hat.round == r - 1
|
||||
and v_hat.is_last
|
||||
and _is_k_reachable(v_hat, vertex, graph, connectivity_k)
|
||||
]
|
||||
|
||||
# Weight of k-reachable last vertices from previous round
|
||||
weight_of_previous_lasts = graph.set_weight(previous_round_lasts)
|
||||
|
||||
if weight_of_previous_lasts > 3 * d_r:
|
||||
vertex.is_last = True
|
||||
else:
|
||||
vertex.is_last = False
|
||||
|
||||
|
||||
def _is_k_reachable(v_from: Vertex, v_to: Vertex,
|
||||
graph: LamportGraph, k: int) -> bool:
|
||||
"""Check k-reachability (Definition 5.1).
|
||||
|
||||
v_hat is k-reachable from v if the overall weight of all vertices in
|
||||
all paths from v to v_hat is greater than k.
|
||||
|
||||
For simplicity in this PoC, we approximate this by checking if v_from
|
||||
is in the past of v_to and the total weight along the path exceeds k.
|
||||
|
||||
The paper notes (page 11): "counting disjoint paths is computationally
|
||||
expensive and not really necessary in our setting... all we need is some
|
||||
insurance that information flows through enough real world processes."
|
||||
We use total path weight as a simpler proxy.
|
||||
"""
|
||||
if v_from not in graph.past(v_to):
|
||||
return False
|
||||
|
||||
# Compute the weight of all vertices in the path from v_from to v_to
|
||||
# (all vertices that are in both the future of v_from and the past of v_to)
|
||||
past_of_to = graph.past(v_to)
|
||||
future_of_from = graph.future(v_from)
|
||||
|
||||
path_vertices = past_of_to & future_of_from
|
||||
total_weight = graph.set_weight(path_vertices)
|
||||
|
||||
return total_weight > k
|
||||
|
||||
|
||||
def _topological_sort(graph: LamportGraph) -> list[Vertex]:
|
||||
"""Sort vertices in causal order: causes come before their effects.
|
||||
|
||||
Uses Kahn's algorithm. Vertices with no causes (sources) come first.
|
||||
This ensures that when we process a vertex, all its causes already
|
||||
have their round numbers computed.
|
||||
"""
|
||||
# Compute in-degree (number of causes each vertex has within the graph)
|
||||
in_degree: dict[bytes, int] = {}
|
||||
for d, v in graph.vertices.items():
|
||||
in_degree[d] = 0
|
||||
|
||||
for d, v in graph.vertices.items():
|
||||
for ref_d in graph.edges.get(d, set()):
|
||||
if ref_d in graph.vertices:
|
||||
# ref_d is a cause of d, so d has an additional in-edge
|
||||
# But we want causal order: causes first
|
||||
# edges go from effect -> cause, so we need reverse
|
||||
pass
|
||||
|
||||
# Actually: edges[d] contains the causes of d (d -> cause).
|
||||
# For topological sort where causes come first, we need:
|
||||
# in_degree[d] = number of digests in edges[d] that are in the graph
|
||||
for d in graph.vertices:
|
||||
count = 0
|
||||
for cause_d in graph.edges.get(d, set()):
|
||||
if cause_d in graph.vertices:
|
||||
count += 1
|
||||
in_degree[d] = count
|
||||
|
||||
# Start with vertices that have no causes (in_degree = 0)
|
||||
queue = [d for d, deg in in_degree.items() if deg == 0]
|
||||
result = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
result.append(graph.vertices[current])
|
||||
|
||||
# For each vertex that current is a cause of (reverse edges)
|
||||
for effect_d in graph.reverse_edges.get(current, set()):
|
||||
if effect_d in in_degree:
|
||||
in_degree[effect_d] -= 1
|
||||
if in_degree[effect_d] == 0:
|
||||
queue.append(effect_d)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queries on computed rounds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def last_vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]:
|
||||
"""Return all last vertices in a given round."""
|
||||
return [
|
||||
v for v in graph.all_vertices()
|
||||
if v.round == round_number and v.is_last
|
||||
]
|
||||
|
||||
|
||||
def max_round(graph: LamportGraph) -> int:
|
||||
"""Return the highest round number in the graph."""
|
||||
rounds = [v.round for v in graph.all_vertices() if v.round is not None]
|
||||
return max(rounds) if rounds else 0
|
||||
|
||||
|
||||
def vertices_in_round(graph: LamportGraph, round_number: int) -> list[Vertex]:
|
||||
"""Return all vertices in a given round."""
|
||||
return [v for v in graph.all_vertices() if v.round == round_number]
|
||||
527
src/crisis/voting.py
Normal file
527
src/crisis/voting.py
Normal file
|
|
@ -0,0 +1,527 @@
|
|||
"""
|
||||
Virtual Voting, Safe Voting Patterns, and Leader Election (Section 5)
|
||||
|
||||
This module implements the heart of the Crisis protocol: the virtual voting
|
||||
mechanism that achieves total order without ever sending explicit vote messages.
|
||||
|
||||
Key concepts:
|
||||
|
||||
5.5 Virtual Process Sortition & Knowledge Graphs
|
||||
- Knowledge graph (Def 5.8): quotient graph projecting vertices to virtual
|
||||
processes, representing what each process "knows" about others.
|
||||
- Quorum selector (Def 5.11): deterministically chooses a subset of virtual
|
||||
processes for each round -- the quorum that participates in agreement.
|
||||
|
||||
5.6 Safe Voting Pattern
|
||||
- Voting sets (Def 5.12): the set of vertices participating in round s
|
||||
agreement, reachable with connectivity k from vertex v.
|
||||
- Algorithm 6: computes the safe voting pattern -- a nested sequence of
|
||||
rounds where voting took place with appropriately bounded byzantine weight.
|
||||
|
||||
5.7 Local Leader Election
|
||||
- Algorithm 7: virtual leader elections -- an adaptation of Chen, Feldman
|
||||
& Micali's BA* to virtual voting on Lamport graphs.
|
||||
- Three stage types: initial proposal (δ=0), presorting/gradecast (δ∈{1,2}),
|
||||
and BBA* binary agreement (δ≥3) with "coin fixed to 0/1" and "genuine
|
||||
coin flip" sub-stages.
|
||||
|
||||
5.8 Longest Chain Rule
|
||||
- Algorithm 8: maintains the leader stream by keeping only the longest
|
||||
chain of round leaders (similar to Nakamoto's longest chain rule).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from crisis.crypto import digest, least_significant_bit
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, Vertex, Vote, EMPTY_MESSAGE_DIGEST
|
||||
from crisis.rounds import last_vertices_in_round, max_round
|
||||
from crisis.weight import DifficultyOracle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Knowledge Graph (Definition 5.8)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class KnowledgeGraph:
|
||||
"""The round s knowledge graph of vertex v (Definition 5.8).
|
||||
|
||||
Given rounds s < r, a Lamport graph G, and v a last message in round r,
|
||||
the knowledge graph Π^s_v is the quotient graph G^s_v / ≃_id.
|
||||
|
||||
Each node in the knowledge graph represents a virtual process (identified
|
||||
by its id). An edge from process id to id' means that some vertex with
|
||||
v.id = id in round s has a vertex with v_hat.id = id' in its past.
|
||||
|
||||
This represents what each virtual process "knows" about others.
|
||||
"""
|
||||
# id -> set of ids that this process has edges to
|
||||
edges: dict[bytes, set[bytes]] = field(default_factory=dict)
|
||||
# id -> total weight of vertices in this equivalence class
|
||||
weights: dict[bytes, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
def build_knowledge_graph(vertex: Vertex, round_s: int,
|
||||
graph: LamportGraph) -> KnowledgeGraph:
|
||||
"""Build the round s knowledge graph for vertex v.
|
||||
|
||||
Collects all round-s vertices in v's past, groups them by id,
|
||||
and builds the quotient graph.
|
||||
"""
|
||||
kg = KnowledgeGraph()
|
||||
past = graph.past(vertex)
|
||||
|
||||
# Find all round-s vertices in v's past
|
||||
round_s_vertices = [v for v in past if v.round == round_s]
|
||||
|
||||
# Group by id and compute edges
|
||||
for v_s in round_s_vertices:
|
||||
vid = v_s.id
|
||||
if vid not in kg.edges:
|
||||
kg.edges[vid] = set()
|
||||
if vid not in kg.weights:
|
||||
kg.weights[vid] = 0
|
||||
|
||||
kg.weights[vid] = graph.weight_system.weight_sum(
|
||||
kg.weights[vid], graph.vertex_weight(v_s)
|
||||
)
|
||||
|
||||
# Add edges based on what this vertex references
|
||||
for cause in graph.direct_causes(v_s):
|
||||
if cause.round is not None and cause.round == round_s:
|
||||
kg.edges[vid].add(cause.id)
|
||||
|
||||
return kg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quorum Selector (Definition 5.11)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def select_quorum(knowledge_graph: KnowledgeGraph, n: int = 3) -> set[bytes]:
|
||||
"""Select a quorum from a knowledge graph (Definition 5.11).
|
||||
|
||||
Example 3 (Highest voting weight quorum):
|
||||
Choose the weakly connected component with the highest combined voting
|
||||
weight, then take the heaviest n virtual processes from it.
|
||||
|
||||
The quorum selector serves as a filter to reduce byzantine noise that
|
||||
might appear in the voting process. By restricting to a heavily
|
||||
connected component, faulty behavior based on graph partition is reduced.
|
||||
"""
|
||||
if not knowledge_graph.edges:
|
||||
return set()
|
||||
|
||||
# Find weakly connected components using simple BFS
|
||||
all_ids = set(knowledge_graph.edges.keys())
|
||||
visited: set[bytes] = set()
|
||||
components: list[set[bytes]] = []
|
||||
|
||||
for start_id in all_ids:
|
||||
if start_id in visited:
|
||||
continue
|
||||
component: set[bytes] = set()
|
||||
queue = [start_id]
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
component.add(current)
|
||||
# Follow edges in both directions (weakly connected)
|
||||
for neighbor in knowledge_graph.edges.get(current, set()):
|
||||
if neighbor not in visited and neighbor in all_ids:
|
||||
queue.append(neighbor)
|
||||
# Reverse edges
|
||||
for other_id, neighbors in knowledge_graph.edges.items():
|
||||
if current in neighbors and other_id not in visited:
|
||||
queue.append(other_id)
|
||||
components.append(component)
|
||||
|
||||
# Choose the component with highest total weight
|
||||
def component_weight(comp: set[bytes]) -> int:
|
||||
return sum(knowledge_graph.weights.get(pid, 0) for pid in comp)
|
||||
|
||||
best_component = max(components, key=component_weight)
|
||||
|
||||
# Take the n heaviest processes from this component
|
||||
sorted_by_weight = sorted(
|
||||
best_component,
|
||||
key=lambda pid: knowledge_graph.weights.get(pid, 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
return set(sorted_by_weight[:n])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Voting Sets (Definition 5.12)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def voting_set(vertex: Vertex, round_s: int, connectivity_k: int,
|
||||
graph: LamportGraph) -> set[Vertex]:
|
||||
"""Compute S_v(s,k): v's round s voting set (Definition 5.12).
|
||||
|
||||
S_v(s,k) := { x | x.id ∈ Q(v,s) ∧ x ≤_{(r-s)*k} v
|
||||
∧ x.round = s ∧ x.is_last = true }
|
||||
|
||||
The voting set consists of all last vertices in round s that:
|
||||
1. Belong to a quorum-selected virtual process
|
||||
2. Are k-reachable from v (with distance scaled by round gap)
|
||||
3. Are in v's past
|
||||
"""
|
||||
if vertex.round is None:
|
||||
return set()
|
||||
|
||||
r = vertex.round
|
||||
if round_s >= r:
|
||||
return set()
|
||||
|
||||
# Build knowledge graph and select quorum
|
||||
kg = build_knowledge_graph(vertex, round_s, graph)
|
||||
quorum = select_quorum(kg)
|
||||
|
||||
past_of_v = graph.past(vertex)
|
||||
|
||||
result = set()
|
||||
for v_hat in past_of_v:
|
||||
if (v_hat.round == round_s
|
||||
and v_hat.is_last
|
||||
and v_hat.id in quorum):
|
||||
result.add(v_hat)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Algorithm 6: Safe Voting Pattern (Section 5.6)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_safe_voting_pattern(vertex: Vertex, graph: LamportGraph,
|
||||
difficulty: DifficultyOracle,
|
||||
connectivity_k: int = 2) -> None:
|
||||
"""Algorithm 6: compute the safe voting pattern for a vertex.
|
||||
|
||||
The safe voting pattern v.svp is a totally ordered set of round numbers
|
||||
where "safe" voting took place. Safe means:
|
||||
- The voting set has enough overall weight
|
||||
- The svp of all members agree
|
||||
- Byzantine weight is bounded
|
||||
|
||||
Pseudocode from the paper:
|
||||
1: procedure SVP(vertex:v, lamport_graph:G)
|
||||
2: v.svp <- ∅
|
||||
3: if v.is_last and [safe voting pattern conditions are met] then
|
||||
4: s <- maximum of all such k
|
||||
5: v.svp <- v.svp ∪ {s} for all t ≤ s
|
||||
6: end if
|
||||
7: end procedure
|
||||
|
||||
The procedure checks if the current vertex's round qualifies as a new
|
||||
entry in the safe voting pattern by verifying weight and agreement
|
||||
conditions from its voting set.
|
||||
"""
|
||||
vertex.svp = []
|
||||
|
||||
if not vertex.is_last or vertex.round is None or vertex.round == 0:
|
||||
return
|
||||
|
||||
r = vertex.round
|
||||
|
||||
# Check each previous round for safe voting pattern membership
|
||||
for s in range(r):
|
||||
d_s = difficulty.difficulty(s)
|
||||
|
||||
# Get voting set for round s
|
||||
vs = voting_set(vertex, s, connectivity_k, graph)
|
||||
if not vs:
|
||||
continue
|
||||
|
||||
total_weight = graph.set_weight(vs)
|
||||
|
||||
# Check if voting weight exceeds threshold (6 * d_s from Eq. 8)
|
||||
if total_weight <= 6 * d_s:
|
||||
continue
|
||||
|
||||
# Check that all members of the voting set have compatible svp
|
||||
svps_agree = True
|
||||
for x in vs:
|
||||
for y in vs:
|
||||
if x.svp != y.svp:
|
||||
# Allow prefix agreement
|
||||
min_len = min(len(x.svp), len(y.svp))
|
||||
if x.svp[:min_len] != y.svp[:min_len]:
|
||||
svps_agree = False
|
||||
break
|
||||
if not svps_agree:
|
||||
break
|
||||
|
||||
if svps_agree:
|
||||
vertex.svp.append(s)
|
||||
|
||||
# svp is a nested sequence: add current round
|
||||
if vertex.svp:
|
||||
vertex.svp.append(r)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initial Vote Function (Definition 5.16, Example 4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def initial_vote(vertices: set[Vertex], graph: LamportGraph) -> Optional[Message]:
|
||||
"""INITIAL_VOTE: deterministically choose a leader proposal (Def 5.16).
|
||||
|
||||
Example 4 (Highest weight): Choose the underlying message of the highest
|
||||
voting weight vertex. Since we assume it is infeasible to have different
|
||||
vertices of equal weight, this is practically deterministic.
|
||||
|
||||
The initial vote function is a system parameter. Different choices lead
|
||||
to different long-term behavior. Ideally all members of a safe voting
|
||||
pattern would compute the same initial vote.
|
||||
"""
|
||||
if not vertices:
|
||||
return None
|
||||
|
||||
best_vertex = max(vertices, key=lambda v: graph.vertex_weight(v))
|
||||
return best_vertex.m
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Algorithm 7: Virtual Leader Elections (Section 5.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_virtual_leader_election(vertex: Vertex, graph: LamportGraph,
|
||||
difficulty: DifficultyOracle,
|
||||
connectivity_k: int,
|
||||
leader_stream: dict[int, list[tuple[int, Message]]]) -> None:
|
||||
"""Algorithm 7: compute votes for all rounds in v's safe voting pattern.
|
||||
|
||||
This is the core virtual BA* protocol. For each element t in v.svp,
|
||||
the vertex computes a vote v.vote(t) = (l, b) based on the stage δ
|
||||
(the position of that round in the svp).
|
||||
|
||||
Stage types (determined by δ = d_{v.svp}(s, t)):
|
||||
δ = 0: Initial leader proposal
|
||||
δ = 1: Leader presorting (gradecast step)
|
||||
δ = 2: BBA* initialization (gradecast step)
|
||||
δ ≥ 3: Binary agreement rounds
|
||||
δ mod 3 = 0: Coin fixed to 0
|
||||
δ mod 3 = 1: Coin fixed to 1
|
||||
δ mod 3 = 2: Genuine coin flip
|
||||
|
||||
The paper notes: "every step is entirely virtual and no votes are
|
||||
actually sent to other real world processes."
|
||||
"""
|
||||
if not vertex.svp:
|
||||
return
|
||||
|
||||
s = max(vertex.svp) if vertex.svp else None
|
||||
if s is None:
|
||||
return
|
||||
|
||||
for t_idx, t in enumerate(vertex.svp):
|
||||
delta = t_idx # stage = position in svp
|
||||
_compute_vote_for_stage(vertex, t, delta, s, graph, difficulty,
|
||||
connectivity_k, leader_stream)
|
||||
|
||||
|
||||
def _compute_vote_for_stage(vertex: Vertex, t: int, delta: int, s: int,
|
||||
graph: LamportGraph, difficulty: DifficultyOracle,
|
||||
connectivity_k: int,
|
||||
leader_stream: dict[int, list[tuple[int, Message]]]) -> None:
|
||||
"""Compute vertex's vote for a specific stage of the virtual leader election.
|
||||
|
||||
Implements the branching logic of Algorithm 7 (pages 19-20 of the paper).
|
||||
"""
|
||||
d_s = difficulty.difficulty(s)
|
||||
vs = voting_set(vertex, t, connectivity_k, graph)
|
||||
n = graph.set_weight(vs)
|
||||
|
||||
NON_LEADER = None # ∅ in the paper
|
||||
|
||||
if delta == 0:
|
||||
# Stage 0: Initial leader proposal
|
||||
l = initial_vote(vs, graph)
|
||||
vertex.vote[t] = Vote(message=l, binary=None) # (INITIAL_VOTE(S), ⊥)
|
||||
|
||||
elif delta == 1:
|
||||
# Stage 1: Leader presorting
|
||||
# Find message with highest round-t voting weight in S
|
||||
l = _highest_weight_message(vs, graph)
|
||||
|
||||
if l is not None:
|
||||
# Check if l has super majority weight
|
||||
l_weight = _vote_weight_for(vs, t, l, None, graph) # votes for (l, ⊥)
|
||||
if l_weight > n - d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=None) # (l, ⊥)
|
||||
else:
|
||||
vertex.vote[t] = Vote(message=NON_LEADER, binary=None) # (∅, ⊥)
|
||||
else:
|
||||
vertex.vote[t] = Vote(message=NON_LEADER, binary=None)
|
||||
|
||||
elif delta == 2:
|
||||
# Stage 2: BBA* initialization (gradecast)
|
||||
l = _highest_weight_message(vs, graph)
|
||||
|
||||
if l is not None:
|
||||
l_weight_undecided = _vote_weight_for(vs, t, l, None, graph)
|
||||
if l_weight_undecided > n - d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=0)
|
||||
else:
|
||||
l_weight_1 = _vote_weight_for(vs, t, l, 1, graph)
|
||||
if l_weight_1 > d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=1)
|
||||
else:
|
||||
vertex.vote[t] = Vote(message=NON_LEADER, binary=1)
|
||||
else:
|
||||
vertex.vote[t] = Vote(message=NON_LEADER, binary=1)
|
||||
|
||||
else:
|
||||
# Stage δ ≥ 3: Binary agreement (BBA*)
|
||||
coin_stage = delta % 3
|
||||
l = _highest_weight_message(vs, graph)
|
||||
|
||||
if coin_stage == 0:
|
||||
# Coin fixed to 0
|
||||
_bba_coin_fixed(vertex, t, vs, l, n, d_s, graph,
|
||||
leader_stream, s, fixed_value=0)
|
||||
elif coin_stage == 1:
|
||||
# Coin fixed to 1
|
||||
_bba_coin_fixed(vertex, t, vs, l, n, d_s, graph,
|
||||
leader_stream, s, fixed_value=1)
|
||||
else:
|
||||
# Genuine coin flip (coin_stage == 2)
|
||||
_bba_genuine_coin(vertex, t, vs, l, n, d_s, graph)
|
||||
|
||||
|
||||
def _bba_coin_fixed(vertex: Vertex, t: int, vs: set[Vertex],
|
||||
l: Optional[Message], n: int, d_s: int,
|
||||
graph: LamportGraph,
|
||||
leader_stream: dict[int, list[tuple[int, Message]]],
|
||||
s: int, fixed_value: int) -> None:
|
||||
"""BBA* stage with coin fixed to 0 or 1."""
|
||||
other_value = 1 - fixed_value
|
||||
|
||||
if l is not None:
|
||||
weight_for_fixed = _vote_weight_for_binary(vs, t, fixed_value, graph)
|
||||
if weight_for_fixed > n - d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=fixed_value)
|
||||
# If weight = n, we have agreement: update leader stream
|
||||
if weight_for_fixed == n:
|
||||
_update_leader_stream(leader_stream, l, s)
|
||||
return
|
||||
|
||||
weight_for_other = _vote_weight_for_binary(vs, t, other_value, graph)
|
||||
if weight_for_other > n - d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=other_value)
|
||||
return
|
||||
|
||||
vertex.vote[t] = Vote(message=l, binary=fixed_value)
|
||||
|
||||
|
||||
def _bba_genuine_coin(vertex: Vertex, t: int, vs: set[Vertex],
|
||||
l: Optional[Message], n: int, d_s: int,
|
||||
graph: LamportGraph) -> None:
|
||||
"""BBA* stage with genuine coin flip."""
|
||||
if l is not None:
|
||||
weight_0 = _vote_weight_for_binary(vs, t, 0, graph)
|
||||
if weight_0 > n - d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=0)
|
||||
return
|
||||
|
||||
weight_1 = _vote_weight_for_binary(vs, t, 1, graph)
|
||||
if weight_1 > n - d_s:
|
||||
vertex.vote[t] = Vote(message=l, binary=1)
|
||||
return
|
||||
|
||||
# Genuine coin flip: use LSB of hash of heaviest vertex's message
|
||||
if vs:
|
||||
heaviest = max(vs, key=lambda v: graph.vertex_weight(v))
|
||||
h = heaviest.m.compute_digest()
|
||||
b_coin = least_significant_bit(h)
|
||||
vertex.vote[t] = Vote(message=l, binary=b_coin)
|
||||
else:
|
||||
vertex.vote[t] = Vote(message=l, binary=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper functions for vote weight computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _highest_weight_message(vs: set[Vertex], graph: LamportGraph) -> Optional[Message]:
|
||||
"""Find the message with the highest voting weight in a set."""
|
||||
if not vs:
|
||||
return None
|
||||
best = max(vs, key=lambda v: graph.vertex_weight(v))
|
||||
return best.m
|
||||
|
||||
|
||||
def _vote_weight_for(vs: set[Vertex], round_t: int,
|
||||
target_msg: Optional[Message], target_binary: Optional[int],
|
||||
graph: LamportGraph) -> int:
|
||||
"""Compute total voting weight for a specific vote (l, b) in a voting set."""
|
||||
total = 0
|
||||
for v in vs:
|
||||
vote = v.vote.get(round_t)
|
||||
if vote is None:
|
||||
continue
|
||||
msg_match = (vote.message is None and target_msg is None) or \
|
||||
(vote.message is not None and target_msg is not None and
|
||||
vote.message.compute_digest() == target_msg.compute_digest())
|
||||
bin_match = vote.binary == target_binary
|
||||
if msg_match and bin_match:
|
||||
total = graph.weight_system.weight_sum(total, graph.vertex_weight(v))
|
||||
return total
|
||||
|
||||
|
||||
def _vote_weight_for_binary(vs: set[Vertex], round_t: int,
|
||||
target_binary: int,
|
||||
graph: LamportGraph) -> int:
|
||||
"""Compute total voting weight for a specific binary value in a voting set."""
|
||||
total = 0
|
||||
for v in vs:
|
||||
vote = v.vote.get(round_t)
|
||||
if vote is not None and vote.binary == target_binary:
|
||||
total = graph.weight_system.weight_sum(total, graph.vertex_weight(v))
|
||||
return total
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Algorithm 8: Longest Chain Rule (Section 5.8)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _update_leader_stream(leader_stream: dict[int, list[tuple[int, Message]]],
|
||||
message: Message, round_number: int) -> None:
|
||||
"""Algorithm 8: update the leader stream with a new leader candidate.
|
||||
|
||||
The longest chain rule keeps only the chain with the highest deciding
|
||||
round for each round leader. When a new round leader is decided at a
|
||||
higher deciding round, previous entries with lower deciding rounds are
|
||||
replaced.
|
||||
|
||||
Pseudocode:
|
||||
1: procedure LONG_CHAIN(set{(uint,MESSAGE)}:S, MESSAGE:m, uint:s)
|
||||
2: if there is no (l, t) ∈ S with t > s then
|
||||
3: S <- {(l, t) ∈ S | t < s} ∪ (s, m)
|
||||
4: end if
|
||||
5: return S
|
||||
6: end procedure
|
||||
"""
|
||||
if round_number not in leader_stream:
|
||||
leader_stream[round_number] = []
|
||||
|
||||
entries = leader_stream[round_number]
|
||||
|
||||
# Check if there's already an entry with a higher deciding round
|
||||
has_higher = any(t > round_number for (t, _) in entries)
|
||||
if has_higher:
|
||||
return
|
||||
|
||||
# Remove entries with lower deciding rounds, add new one
|
||||
leader_stream[round_number] = [
|
||||
(t, m) for (t, m) in entries if t < round_number
|
||||
] + [(round_number, message)]
|
||||
190
src/crisis/weight.py
Normal file
190
src/crisis/weight.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
Weight Systems (Section 3.1.1)
|
||||
|
||||
Definition 3.1 (Weight system): Let MESSAGE be the metric space of all
|
||||
messages and (W, ≤) a totally ordered set. Then the tuple (W, w, ⊕, c_min)
|
||||
is a *weight system* if w is a function
|
||||
|
||||
w : MESSAGE -> W (Eq. 3)
|
||||
|
||||
that assigns an element of W to any message, c_min ∈ W is a constant called
|
||||
the *weight threshold*, and ⊕ is a function
|
||||
|
||||
⊕ : W × W -> W (Eq. 4)
|
||||
|
||||
called the *weight sum*, such that:
|
||||
|
||||
- Tamper proof: w(m) >= c_min and m_hat ≠ m implies w(m_hat) < c_min
|
||||
with high probability.
|
||||
- Uniqueness: m ≠ m_hat implies w(m) ≠ w(m_hat) with high probability.
|
||||
- Summability: (W, ⊕) is a totally ordered, abelian group.
|
||||
|
||||
The weight w(m) is interpreted as the amount of voting power m holds to
|
||||
influence total order generation.
|
||||
|
||||
This module provides:
|
||||
1. An abstract WeightSystem protocol
|
||||
2. A concrete Proof-of-Work implementation (Hashcash-style)
|
||||
|
||||
The PoW weight function counts leading zero bits of H(m), similar to Bitcoin's
|
||||
difficulty mechanism (Nakamoto, 2009; Beck, 2002).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from crisis.crypto import digest, count_leading_zero_bits
|
||||
from crisis.message import Message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Abstract weight system
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WeightSystem(Protocol):
|
||||
"""Protocol defining the weight system interface (Definition 3.1).
|
||||
|
||||
Any concrete weight system must provide:
|
||||
- weight(): Compute w(m) for a message
|
||||
- threshold: The minimum weight c_min
|
||||
- weight_sum(): Compute ⊕ for two weights
|
||||
"""
|
||||
|
||||
@property
|
||||
def threshold(self) -> int:
|
||||
"""c_min: the minimum weight threshold.
|
||||
|
||||
Messages with weight below this are rejected. This prevents Sybil
|
||||
attacks by ensuring every message requires a minimum investment.
|
||||
"""
|
||||
...
|
||||
|
||||
def weight(self, message: Message) -> int:
|
||||
"""w(m): compute the weight of a message.
|
||||
|
||||
The weight represents the voting power of this message in the
|
||||
consensus protocol.
|
||||
"""
|
||||
...
|
||||
|
||||
def weight_sum(self, a: int, b: int) -> int:
|
||||
"""⊕: combine two weights.
|
||||
|
||||
Must form a totally ordered abelian group.
|
||||
For our purposes, ordinary integer addition suffices.
|
||||
"""
|
||||
...
|
||||
|
||||
def is_valid_weight(self, message: Message) -> bool:
|
||||
"""Check whether w(m) >= c_min."""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proof-of-Work weight system
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class ProofOfWorkWeight:
|
||||
"""A Hashcash-style Proof-of-Work weight system.
|
||||
|
||||
The weight of a message is the number of leading zero bits in H(m).
|
||||
This is similar to Bitcoin's mining: finding a message whose hash starts
|
||||
with k zero bits requires approximately 2^k hash evaluations on average.
|
||||
|
||||
The nonce field of the message is used to search for a valid hash,
|
||||
analogous to Bitcoin's block header nonce.
|
||||
|
||||
Attributes:
|
||||
min_leading_zeros: c_min -- minimum leading zero bits required.
|
||||
A value of 1 means every message needs at least
|
||||
1 leading zero bit (50% of hashes qualify).
|
||||
"""
|
||||
min_leading_zeros: int = 1
|
||||
|
||||
@property
|
||||
def threshold(self) -> int:
|
||||
return self.min_leading_zeros
|
||||
|
||||
def weight(self, message: Message) -> int:
|
||||
"""Count leading zero bits in H(m).
|
||||
|
||||
More leading zeros = more work performed = higher voting weight.
|
||||
"""
|
||||
h = message.compute_digest()
|
||||
return count_leading_zero_bits(h)
|
||||
|
||||
def weight_sum(self, a: int, b: int) -> int:
|
||||
"""Simple integer addition for combining weights.
|
||||
|
||||
This satisfies the abelian group requirement: (Z, +) is a totally
|
||||
ordered abelian group with identity 0.
|
||||
"""
|
||||
return a + b
|
||||
|
||||
def is_valid_weight(self, message: Message) -> bool:
|
||||
"""Check w(m) >= c_min."""
|
||||
return self.weight(message) >= self.threshold
|
||||
|
||||
def mine_nonce(self, id_bytes: bytes, digests: tuple[bytes, ...],
|
||||
payload: bytes, target_weight: int | None = None) -> Message:
|
||||
"""Search for a nonce that produces a message meeting the weight target.
|
||||
|
||||
This is the "nonce grinding" step: try successive nonce values until
|
||||
H(m) has enough leading zero bits.
|
||||
|
||||
Args:
|
||||
id_bytes: The virtual process id for this message.
|
||||
digests: Causal acknowledgements (digests of prior messages).
|
||||
payload: The application payload.
|
||||
target_weight: Minimum weight to achieve. Defaults to c_min.
|
||||
|
||||
Returns:
|
||||
A Message with a valid nonce.
|
||||
"""
|
||||
if target_weight is None:
|
||||
target_weight = self.threshold
|
||||
|
||||
nonce_int = 0
|
||||
while True:
|
||||
nonce = nonce_int.to_bytes(8, "big")
|
||||
msg = Message(nonce=nonce, id=id_bytes, digests=digests, payload=payload)
|
||||
if self.weight(msg) >= target_weight:
|
||||
return msg
|
||||
nonce_int += 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Difficulty Oracle (Section 5.4, Definition 5.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class DifficultyOracle:
|
||||
"""Maps round numbers to difficulty values (Definition 5.2).
|
||||
|
||||
The difficulty oracle d : N -> W maps natural numbers (rounds) onto
|
||||
weights. The value d_r := d(r) is called the *round r difficulty*.
|
||||
|
||||
The difficulty is designed so that the overall voting weight per round
|
||||
is bounded:
|
||||
|
||||
lim sum(w_s^G / d_s) <= 6 (Eq. 8)
|
||||
|
||||
for all time parameters t, where w_s^G is the overall voting weight of
|
||||
last vertices in round s.
|
||||
|
||||
Example 1 (paper): A fixed constant that does not change over time.
|
||||
This is the simplest starting point for a PoC.
|
||||
"""
|
||||
constant_difficulty: int = 4
|
||||
|
||||
def difficulty(self, round_number: int) -> int:
|
||||
"""d(r): return the difficulty for round r.
|
||||
|
||||
For this PoC we use a fixed constant (paper Example 1).
|
||||
A production system might adapt this based on observed voting
|
||||
weight, similar to Bitcoin's difficulty adjustment.
|
||||
"""
|
||||
return self.constant_difficulty
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
57
tests/test_crypto.py
Normal file
57
tests/test_crypto.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""Tests for the crypto module (random oracle model)."""
|
||||
|
||||
from crisis.crypto import (
|
||||
digest, digest_hex, verify_digest,
|
||||
least_significant_bit, count_leading_zero_bits,
|
||||
DIGEST_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
def test_digest_returns_32_bytes():
|
||||
h = digest(b"hello")
|
||||
assert len(h) == DIGEST_LENGTH == 32
|
||||
|
||||
|
||||
def test_digest_is_deterministic():
|
||||
assert digest(b"test") == digest(b"test")
|
||||
|
||||
|
||||
def test_digest_different_inputs_different_outputs():
|
||||
assert digest(b"a") != digest(b"b")
|
||||
|
||||
|
||||
def test_digest_hex_matches():
|
||||
h = digest(b"hello")
|
||||
assert digest_hex(b"hello") == h.hex()
|
||||
|
||||
|
||||
def test_verify_digest():
|
||||
h = digest(b"data")
|
||||
assert verify_digest(b"data", h)
|
||||
assert not verify_digest(b"other", h)
|
||||
|
||||
|
||||
def test_least_significant_bit():
|
||||
# 0x00 -> LSB = 0, 0x01 -> LSB = 1
|
||||
assert least_significant_bit(b"\x00") == 0
|
||||
assert least_significant_bit(b"\x01") == 1
|
||||
assert least_significant_bit(b"\x02") == 0
|
||||
assert least_significant_bit(b"\x03") == 1
|
||||
assert least_significant_bit(b"\xff") == 1
|
||||
assert least_significant_bit(b"\xfe") == 0
|
||||
|
||||
|
||||
def test_count_leading_zero_bits():
|
||||
assert count_leading_zero_bits(b"\xff") == 0
|
||||
assert count_leading_zero_bits(b"\x7f") == 1
|
||||
assert count_leading_zero_bits(b"\x3f") == 2
|
||||
assert count_leading_zero_bits(b"\x00\xff") == 8
|
||||
assert count_leading_zero_bits(b"\x00\x00\x01") == 23
|
||||
assert count_leading_zero_bits(b"\x00") == 8
|
||||
|
||||
|
||||
def test_empty_digest_is_well_defined():
|
||||
"""Paper: 'Acknowledgement of the empty string is defined as H(∅)'."""
|
||||
h = digest(b"")
|
||||
assert len(h) == 32
|
||||
assert h == digest(b"") # deterministic
|
||||
218
tests/test_graph.py
Normal file
218
tests/test_graph.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""Tests for the Lamport graph with integrity checks."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from crisis.crypto import digest
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.weight import ProofOfWorkWeight
|
||||
|
||||
|
||||
def make_id(name: str) -> bytes:
|
||||
return digest(name.encode())[:ID_LENGTH]
|
||||
|
||||
|
||||
def make_nonce(n: int = 0) -> bytes:
|
||||
return n.to_bytes(NONCE_LENGTH, "big")
|
||||
|
||||
|
||||
def make_graph(pow_zeros: int = 0) -> LamportGraph:
|
||||
return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=pow_zeros))
|
||||
|
||||
|
||||
class TestLamportGraphExtension:
|
||||
|
||||
def test_extend_single_message(self):
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"hello")
|
||||
v = g.extend(msg)
|
||||
assert v is not None
|
||||
assert g.vertex_count() == 1
|
||||
|
||||
def test_extend_chain(self):
|
||||
"""Messages from the same id must form a chain."""
|
||||
g = make_graph()
|
||||
m1 = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"first")
|
||||
v1 = g.extend(m1)
|
||||
assert v1 is not None
|
||||
|
||||
m2 = Message(
|
||||
nonce=make_nonce(1), id=make_id("alice"),
|
||||
digests=(m1.compute_digest(),),
|
||||
payload=b"second"
|
||||
)
|
||||
v2 = g.extend(m2)
|
||||
assert v2 is not None
|
||||
assert g.vertex_count() == 2
|
||||
|
||||
def test_reject_duplicate(self):
|
||||
"""No two equivalent vertices in the same graph."""
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"x")
|
||||
g.extend(msg)
|
||||
v2 = g.extend(msg)
|
||||
assert v2 is None # Rejected: duplicate
|
||||
assert g.vertex_count() == 1
|
||||
|
||||
def test_reject_missing_reference(self):
|
||||
"""Digests must reference existing vertices."""
|
||||
g = make_graph()
|
||||
fake_digest = digest(b"nonexistent")
|
||||
msg = Message(
|
||||
nonce=make_nonce(), id=make_id("alice"),
|
||||
digests=(fake_digest,), payload=b"orphan"
|
||||
)
|
||||
v = g.extend(msg)
|
||||
assert v is None # Rejected
|
||||
|
||||
def test_reject_broken_chain(self):
|
||||
"""Second message from same id must reference a same-id vertex."""
|
||||
g = make_graph()
|
||||
id_a = make_id("alice")
|
||||
id_b = make_id("bob")
|
||||
|
||||
m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"first")
|
||||
g.extend(m1)
|
||||
|
||||
m_bob = Message(nonce=make_nonce(1), id=id_b, payload=b"bob's msg")
|
||||
g.extend(m_bob)
|
||||
|
||||
# Alice's second message references bob but not herself -> rejected
|
||||
m2 = Message(
|
||||
nonce=make_nonce(2), id=id_a,
|
||||
digests=(m_bob.compute_digest(),),
|
||||
payload=b"broken chain"
|
||||
)
|
||||
v = g.extend(m2)
|
||||
assert v is None
|
||||
|
||||
|
||||
class TestCausality:
|
||||
|
||||
def _build_chain(self):
|
||||
"""Build a simple 3-message chain: m1 <- m2 <- m3."""
|
||||
g = make_graph()
|
||||
id_a = make_id("alice")
|
||||
m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"m1")
|
||||
v1 = g.extend(m1)
|
||||
|
||||
m2 = Message(nonce=make_nonce(1), id=id_a,
|
||||
digests=(m1.compute_digest(),), payload=b"m2")
|
||||
v2 = g.extend(m2)
|
||||
|
||||
m3 = Message(nonce=make_nonce(2), id=id_a,
|
||||
digests=(m2.compute_digest(),), payload=b"m3")
|
||||
v3 = g.extend(m3)
|
||||
|
||||
return g, v1, v2, v3
|
||||
|
||||
def test_direct_causes(self):
|
||||
g, v1, v2, v3 = self._build_chain()
|
||||
causes_of_v3 = g.direct_causes(v3)
|
||||
assert v2 in causes_of_v3
|
||||
assert v1 not in causes_of_v3
|
||||
|
||||
def test_direct_effects(self):
|
||||
g, v1, v2, v3 = self._build_chain()
|
||||
effects_of_v1 = g.direct_effects(v1)
|
||||
assert v2 in effects_of_v1
|
||||
assert v3 not in effects_of_v1 # v3 is indirect
|
||||
|
||||
def test_past(self):
|
||||
"""G_v: the past of v contains all its causes."""
|
||||
g, v1, v2, v3 = self._build_chain()
|
||||
past_of_v3 = g.past(v3)
|
||||
assert v1 in past_of_v3
|
||||
assert v2 in past_of_v3
|
||||
assert v3 in past_of_v3 # reflexive
|
||||
|
||||
def test_future(self):
|
||||
g, v1, v2, v3 = self._build_chain()
|
||||
future_of_v1 = g.future(v1)
|
||||
assert v2 in future_of_v1
|
||||
assert v3 in future_of_v1
|
||||
assert v1 in future_of_v1 # reflexive
|
||||
|
||||
def test_is_cause_of(self):
|
||||
g, v1, v2, v3 = self._build_chain()
|
||||
assert g.is_cause_of(v1, v3)
|
||||
assert g.is_cause_of(v1, v2)
|
||||
assert not g.is_cause_of(v3, v1)
|
||||
|
||||
def test_timelike(self):
|
||||
g, v1, v2, v3 = self._build_chain()
|
||||
assert g.are_timelike(v1, v3)
|
||||
assert g.are_timelike(v3, v1)
|
||||
|
||||
def test_spacelike(self):
|
||||
"""Two independent vertices are spacelike."""
|
||||
g = make_graph()
|
||||
m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a")
|
||||
m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b")
|
||||
va = g.extend(m_a)
|
||||
vb = g.extend(m_b)
|
||||
assert g.are_spacelike(va, vb)
|
||||
assert not g.are_timelike(va, vb)
|
||||
|
||||
|
||||
class TestInvarianceOfThePast:
|
||||
"""Theorem 3.7: The past of equivalent vertices in two Lamport graphs
|
||||
have the same cardinality."""
|
||||
|
||||
def test_past_invariance_simple(self):
|
||||
"""Same message in two different graphs has same-size past."""
|
||||
g1 = make_graph()
|
||||
g2 = make_graph()
|
||||
id_a = make_id("alice")
|
||||
|
||||
m1 = Message(nonce=make_nonce(0), id=id_a, payload=b"genesis")
|
||||
m2 = Message(nonce=make_nonce(1), id=id_a,
|
||||
digests=(m1.compute_digest(),), payload=b"second")
|
||||
|
||||
# Add to both graphs
|
||||
g1.extend(m1)
|
||||
v1_in_g1 = g1.extend(m2)
|
||||
|
||||
g2.extend(m1)
|
||||
v1_in_g2 = g2.extend(m2)
|
||||
|
||||
# Past should be the same size
|
||||
assert len(g1.past(v1_in_g1)) == len(g2.past(v1_in_g2))
|
||||
|
||||
|
||||
class TestMessageGeneration:
|
||||
|
||||
def test_generate_first_message(self):
|
||||
g = make_graph()
|
||||
msg = g.generate_message(make_id("alice"), b"hello")
|
||||
v = g.extend(msg)
|
||||
assert v is not None
|
||||
assert v.payload == b"hello"
|
||||
|
||||
def test_generate_chain(self):
|
||||
g = make_graph()
|
||||
pid = make_id("alice")
|
||||
m1 = g.generate_message(pid, b"first")
|
||||
g.extend(m1)
|
||||
|
||||
m2 = g.generate_message(pid, b"second")
|
||||
v2 = g.extend(m2)
|
||||
assert v2 is not None
|
||||
# Second message should reference the first
|
||||
assert m1.compute_digest() in m2.digests
|
||||
|
||||
def test_generate_cross_references(self):
|
||||
"""Messages should reference vertices from other process ids."""
|
||||
g = make_graph()
|
||||
pid_a = make_id("alice")
|
||||
pid_b = make_id("bob")
|
||||
|
||||
m_a = g.generate_message(pid_a, b"alice's msg")
|
||||
g.extend(m_a)
|
||||
|
||||
m_b = g.generate_message(pid_b, b"bob's msg")
|
||||
g.extend(m_b)
|
||||
|
||||
# Alice's second message should reference bob's message
|
||||
m_a2 = g.generate_message(pid_a, b"alice second")
|
||||
assert m_b.compute_digest() in m_a2.digests or m_a.compute_digest() in m_a2.digests
|
||||
125
tests/test_message.py
Normal file
125
tests/test_message.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Tests for the message and vertex data structures."""
|
||||
|
||||
import pytest
|
||||
from crisis.crypto import digest, DIGEST_LENGTH
|
||||
from crisis.message import (
|
||||
Message, Vertex, Vote,
|
||||
NONCE_LENGTH, ID_LENGTH, NUM_DIGESTS_LENGTH,
|
||||
EMPTY_MESSAGE_DIGEST,
|
||||
)
|
||||
|
||||
|
||||
def make_id(name: str) -> bytes:
|
||||
return digest(name.encode())[:ID_LENGTH]
|
||||
|
||||
|
||||
def make_nonce(n: int = 0) -> bytes:
|
||||
return n.to_bytes(NONCE_LENGTH, "big")
|
||||
|
||||
|
||||
class TestMessage:
|
||||
|
||||
def test_create_minimal_message(self):
|
||||
msg = Message(nonce=make_nonce(), id=make_id("test"), digests=(), payload=b"")
|
||||
assert msg.num_digests == 0
|
||||
|
||||
def test_nonce_length_validation(self):
|
||||
with pytest.raises(ValueError, match="nonce"):
|
||||
Message(nonce=b"\x00", id=make_id("x"))
|
||||
|
||||
def test_id_length_validation(self):
|
||||
with pytest.raises(ValueError, match="id"):
|
||||
Message(nonce=make_nonce(), id=b"\x00")
|
||||
|
||||
def test_digest_length_validation(self):
|
||||
with pytest.raises(ValueError, match="digest"):
|
||||
Message(nonce=make_nonce(), id=make_id("x"),
|
||||
digests=(b"\x00",))
|
||||
|
||||
def test_serialize_roundtrip_deterministic(self):
|
||||
msg = Message(nonce=make_nonce(42), id=make_id("proc1"),
|
||||
digests=(), payload=b"hello world")
|
||||
serialized = msg.serialize()
|
||||
assert isinstance(serialized, bytes)
|
||||
# Same message serializes the same way
|
||||
assert msg.serialize() == serialized
|
||||
|
||||
def test_compute_digest_deterministic(self):
|
||||
msg = Message(nonce=make_nonce(), id=make_id("test"), payload=b"data")
|
||||
d1 = msg.compute_digest()
|
||||
d2 = msg.compute_digest()
|
||||
assert d1 == d2
|
||||
assert len(d1) == DIGEST_LENGTH
|
||||
|
||||
def test_different_messages_different_digests(self):
|
||||
m1 = Message(nonce=make_nonce(1), id=make_id("a"), payload=b"x")
|
||||
m2 = Message(nonce=make_nonce(2), id=make_id("a"), payload=b"x")
|
||||
assert m1.compute_digest() != m2.compute_digest()
|
||||
|
||||
def test_message_with_digests(self):
|
||||
parent = Message(nonce=make_nonce(), id=make_id("a"), payload=b"parent")
|
||||
child = Message(
|
||||
nonce=make_nonce(1), id=make_id("a"),
|
||||
digests=(parent.compute_digest(),),
|
||||
payload=b"child"
|
||||
)
|
||||
assert child.num_digests == 1
|
||||
assert child.digests[0] == parent.compute_digest()
|
||||
|
||||
def test_message_is_immutable(self):
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"y")
|
||||
with pytest.raises(AttributeError):
|
||||
msg.nonce = b"\x00" * NONCE_LENGTH
|
||||
|
||||
|
||||
class TestVertex:
|
||||
|
||||
def test_vertex_wraps_message(self):
|
||||
msg = Message(nonce=make_nonce(), id=make_id("proc"), payload=b"data")
|
||||
v = Vertex(m=msg)
|
||||
assert v.nonce == msg.nonce
|
||||
assert v.id == msg.id
|
||||
assert v.payload == msg.payload
|
||||
assert v.digests == msg.digests
|
||||
|
||||
def test_vertex_default_state(self):
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"))
|
||||
v = Vertex(m=msg)
|
||||
assert v.round is None
|
||||
assert v.is_last is None
|
||||
assert v.svp == []
|
||||
assert v.vote == {}
|
||||
assert v.total_position is None
|
||||
|
||||
def test_vertex_equivalence(self):
|
||||
"""Definition 3.3: two vertices are equivalent if v.m = v_hat.m"""
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"same")
|
||||
v1 = Vertex(m=msg)
|
||||
v2 = Vertex(m=msg)
|
||||
assert v1.equivalent_to(v2)
|
||||
assert v1 == v2
|
||||
assert hash(v1) == hash(v2)
|
||||
|
||||
def test_vertex_non_equivalence(self):
|
||||
m1 = Message(nonce=make_nonce(1), id=make_id("x"))
|
||||
m2 = Message(nonce=make_nonce(2), id=make_id("x"))
|
||||
v1 = Vertex(m=m1)
|
||||
v2 = Vertex(m=m2)
|
||||
assert not v1.equivalent_to(v2)
|
||||
assert v1 != v2
|
||||
|
||||
|
||||
class TestVote:
|
||||
|
||||
def test_vote_undecided(self):
|
||||
v = Vote(message=None, binary=None)
|
||||
assert "∅" in repr(v)
|
||||
assert "⊥" in repr(v)
|
||||
|
||||
def test_vote_with_message(self):
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"))
|
||||
v = Vote(message=msg, binary=1)
|
||||
assert v.binary == 1
|
||||
|
||||
def test_empty_message_digest(self):
|
||||
assert EMPTY_MESSAGE_DIGEST == digest(b"")
|
||||
126
tests/test_order.py
Normal file
126
tests/test_order.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""Tests for total order computation."""
|
||||
|
||||
from crisis.crypto import digest
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.order import LeaderStream, compute_order, _kahns_total_order
|
||||
from crisis.weight import ProofOfWorkWeight
|
||||
|
||||
|
||||
def make_id(name: str) -> bytes:
|
||||
return digest(name.encode())[:ID_LENGTH]
|
||||
|
||||
|
||||
def make_nonce(n: int = 0) -> bytes:
|
||||
return n.to_bytes(NONCE_LENGTH, "big")
|
||||
|
||||
|
||||
def make_graph() -> LamportGraph:
|
||||
return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=0))
|
||||
|
||||
|
||||
class TestLeaderStream:
|
||||
|
||||
def test_empty_stream(self):
|
||||
ls = LeaderStream()
|
||||
assert ls.max_round() == -1
|
||||
assert ls.get_leader(0) is None
|
||||
|
||||
def test_add_leader(self):
|
||||
ls = LeaderStream()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("leader"), payload=b"L")
|
||||
updated = ls.update(0, 1, msg)
|
||||
assert updated is True
|
||||
assert ls.get_leader(0) is msg
|
||||
|
||||
def test_higher_deciding_round_replaces(self):
|
||||
ls = LeaderStream()
|
||||
m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"old")
|
||||
m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"new")
|
||||
|
||||
ls.update(0, 1, m1)
|
||||
ls.update(0, 2, m2)
|
||||
|
||||
assert ls.get_leader(0) is m2
|
||||
|
||||
def test_lower_deciding_round_rejected(self):
|
||||
ls = LeaderStream()
|
||||
m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"first")
|
||||
m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"late")
|
||||
|
||||
ls.update(0, 5, m1)
|
||||
updated = ls.update(0, 3, m2)
|
||||
|
||||
assert updated is False
|
||||
assert ls.get_leader(0) is m1
|
||||
|
||||
def test_all_leaders_sorted(self):
|
||||
ls = LeaderStream()
|
||||
m0 = Message(nonce=make_nonce(0), id=make_id("l0"), payload=b"r0")
|
||||
m1 = Message(nonce=make_nonce(1), id=make_id("l1"), payload=b"r1")
|
||||
m2 = Message(nonce=make_nonce(2), id=make_id("l2"), payload=b"r2")
|
||||
|
||||
ls.update(2, 3, m2)
|
||||
ls.update(0, 1, m0)
|
||||
ls.update(1, 2, m1)
|
||||
|
||||
leaders = ls.all_leaders()
|
||||
rounds = [r for r, _ in leaders]
|
||||
assert rounds == sorted(rounds)
|
||||
|
||||
|
||||
class TestKahnsAlgorithm:
|
||||
|
||||
def test_empty_input(self):
|
||||
g = make_graph()
|
||||
result = _kahns_total_order([], g)
|
||||
assert result == []
|
||||
|
||||
def test_single_vertex(self):
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"only")
|
||||
v = g.extend(msg)
|
||||
result = _kahns_total_order([v], g)
|
||||
assert result == [v]
|
||||
|
||||
def test_chain_order(self):
|
||||
"""A chain should be ordered causes-first."""
|
||||
g = make_graph()
|
||||
pid = make_id("alice")
|
||||
m1 = Message(nonce=make_nonce(0), id=pid, payload=b"first")
|
||||
v1 = g.extend(m1)
|
||||
|
||||
m2 = Message(nonce=make_nonce(1), id=pid,
|
||||
digests=(m1.compute_digest(),), payload=b"second")
|
||||
v2 = g.extend(m2)
|
||||
|
||||
m3 = Message(nonce=make_nonce(2), id=pid,
|
||||
digests=(m2.compute_digest(),), payload=b"third")
|
||||
v3 = g.extend(m3)
|
||||
|
||||
result = _kahns_total_order([v1, v2, v3], g)
|
||||
# Causes come first
|
||||
assert result.index(v1) < result.index(v2)
|
||||
assert result.index(v2) < result.index(v3)
|
||||
|
||||
def test_respects_causality(self):
|
||||
"""Total order must be consistent with causal order."""
|
||||
g = make_graph()
|
||||
m_a = Message(nonce=make_nonce(0), id=make_id("alice"), payload=b"a")
|
||||
va = g.extend(m_a)
|
||||
|
||||
m_b = Message(nonce=make_nonce(0), id=make_id("bob"), payload=b"b")
|
||||
vb = g.extend(m_b)
|
||||
|
||||
# Carol references both alice and bob
|
||||
m_c = Message(
|
||||
nonce=make_nonce(1), id=make_id("carol"),
|
||||
digests=(m_a.compute_digest(), m_b.compute_digest()),
|
||||
payload=b"c"
|
||||
)
|
||||
vc = g.extend(m_c)
|
||||
|
||||
result = _kahns_total_order([va, vb, vc], g)
|
||||
# Carol must come after both alice and bob
|
||||
assert result.index(va) < result.index(vc)
|
||||
assert result.index(vb) < result.index(vc)
|
||||
110
tests/test_rounds.py
Normal file
110
tests/test_rounds.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Tests for virtual synchronous rounds."""
|
||||
|
||||
from crisis.crypto import digest
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round, vertices_in_round
|
||||
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||
|
||||
|
||||
def make_id(name: str) -> bytes:
|
||||
return digest(name.encode())[:ID_LENGTH]
|
||||
|
||||
|
||||
def make_nonce(n: int = 0) -> bytes:
|
||||
return n.to_bytes(NONCE_LENGTH, "big")
|
||||
|
||||
|
||||
def make_graph() -> LamportGraph:
|
||||
return LamportGraph(weight_system=ProofOfWorkWeight(min_leading_zeros=0))
|
||||
|
||||
|
||||
class TestRoundComputation:
|
||||
|
||||
def test_single_vertex_round_zero(self):
|
||||
"""A single vertex with no causes is in round 0."""
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis")
|
||||
v = g.extend(msg)
|
||||
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
|
||||
assert v.round == 0
|
||||
|
||||
def test_single_vertex_is_last(self):
|
||||
"""Round 0 vertices are always 'last' (bootstrapping)."""
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("alice"))
|
||||
v = g.extend(msg)
|
||||
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
|
||||
assert v.is_last is True
|
||||
|
||||
def test_chain_grows_rounds(self):
|
||||
"""A chain of messages should produce increasing round numbers."""
|
||||
g = make_graph()
|
||||
pid = make_id("alice")
|
||||
difficulty = DifficultyOracle(constant_difficulty=0) # Low difficulty
|
||||
|
||||
# Create a chain
|
||||
prev_msg = None
|
||||
vertices = []
|
||||
for i in range(5):
|
||||
digests = (prev_msg.compute_digest(),) if prev_msg else ()
|
||||
msg = Message(nonce=make_nonce(i), id=pid, digests=digests, payload=f"msg{i}".encode())
|
||||
v = g.extend(msg)
|
||||
vertices.append(v)
|
||||
prev_msg = msg
|
||||
|
||||
compute_rounds(g, difficulty, connectivity_k=0)
|
||||
|
||||
# All should have round numbers assigned
|
||||
for v in vertices:
|
||||
assert v.round is not None
|
||||
|
||||
# First vertex is round 0
|
||||
assert vertices[0].round == 0
|
||||
|
||||
def test_max_round_empty_graph(self):
|
||||
g = make_graph()
|
||||
assert max_round(g) == 0
|
||||
|
||||
def test_max_round_with_vertices(self):
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"))
|
||||
g.extend(msg)
|
||||
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
|
||||
assert max_round(g) == 0
|
||||
|
||||
def test_last_vertices_in_round(self):
|
||||
g = make_graph()
|
||||
msg = Message(nonce=make_nonce(), id=make_id("alice"))
|
||||
g.extend(msg)
|
||||
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
|
||||
lasts = last_vertices_in_round(g, 0)
|
||||
assert len(lasts) == 1
|
||||
|
||||
def test_multiple_ids_same_round(self):
|
||||
"""Multiple independent vertices are all in round 0."""
|
||||
g = make_graph()
|
||||
for name in ["alice", "bob", "carol"]:
|
||||
msg = Message(nonce=make_nonce(), id=make_id(name), payload=name.encode())
|
||||
g.extend(msg)
|
||||
|
||||
compute_rounds(g, DifficultyOracle(constant_difficulty=1))
|
||||
|
||||
r0 = vertices_in_round(g, 0)
|
||||
assert len(r0) == 3
|
||||
|
||||
def test_round_invariance(self):
|
||||
"""Proposition 5.3: equivalent vertices in different graphs have same round."""
|
||||
g1 = make_graph()
|
||||
g2 = make_graph()
|
||||
difficulty = DifficultyOracle(constant_difficulty=1)
|
||||
|
||||
msg = Message(nonce=make_nonce(), id=make_id("alice"), payload=b"genesis")
|
||||
v1 = g1.extend(msg)
|
||||
v2 = g2.extend(msg)
|
||||
|
||||
compute_rounds(g1, difficulty)
|
||||
compute_rounds(g2, difficulty)
|
||||
|
||||
assert v1.round == v2.round
|
||||
assert v1.is_last == v2.is_last
|
||||
55
tests/test_simulation.py
Normal file
55
tests/test_simulation.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""Integration test: run the full simulation and verify basic properties."""
|
||||
|
||||
from crisis.demo import Simulation
|
||||
|
||||
|
||||
class TestSimulation:
|
||||
|
||||
def test_simulation_runs(self):
|
||||
"""The simulation should complete without errors."""
|
||||
sim = Simulation(num_honest=3, num_byzantine=0, seed=42)
|
||||
results = sim.run(num_steps=5, verbose=False)
|
||||
assert len(results) == 5
|
||||
|
||||
def test_graphs_grow(self):
|
||||
"""Each step should add messages to the graphs."""
|
||||
sim = Simulation(num_honest=2, seed=42)
|
||||
sim.run(num_steps=3, verbose=False)
|
||||
for node in sim.nodes:
|
||||
assert node.graph.vertex_count() > 0
|
||||
|
||||
def test_honest_nodes_same_graph_size(self):
|
||||
"""All honest nodes should have the same number of vertices
|
||||
(since all messages are delivered to all nodes)."""
|
||||
sim = Simulation(num_honest=3, seed=42)
|
||||
sim.run(num_steps=5, verbose=False)
|
||||
sizes = [n.graph.vertex_count() for n in sim.nodes]
|
||||
assert all(s == sizes[0] for s in sizes)
|
||||
|
||||
def test_rounds_are_computed(self):
|
||||
"""After running, vertices should have round numbers."""
|
||||
sim = Simulation(num_honest=3, seed=42)
|
||||
sim.run(num_steps=5, verbose=False)
|
||||
for node in sim.nodes:
|
||||
for v in node.graph.all_vertices():
|
||||
assert v.round is not None
|
||||
|
||||
def test_with_byzantine_node(self):
|
||||
"""Simulation should handle byzantine nodes without crashing."""
|
||||
sim = Simulation(num_honest=3, num_byzantine=1, seed=42)
|
||||
results = sim.run(num_steps=5, verbose=False)
|
||||
assert len(results) == 5
|
||||
|
||||
def test_deterministic_with_seed(self):
|
||||
"""Same seed should produce the same results."""
|
||||
sim1 = Simulation(num_honest=3, seed=123)
|
||||
r1 = sim1.run(num_steps=3, verbose=False)
|
||||
|
||||
sim2 = Simulation(num_honest=3, seed=123)
|
||||
r2 = sim2.run(num_steps=3, verbose=False)
|
||||
|
||||
# Same number of messages at each step
|
||||
for s1, s2 in zip(r1, r2):
|
||||
assert len(s1["new_messages"]) == len(s2["new_messages"])
|
||||
for ns1, ns2 in zip(s1["node_states"], s2["node_states"]):
|
||||
assert ns1["vertices"] == ns2["vertices"]
|
||||
78
tests/test_weight.py
Normal file
78
tests/test_weight.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Tests for the weight system and difficulty oracle."""
|
||||
|
||||
from crisis.crypto import digest
|
||||
from crisis.message import Message, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||
|
||||
|
||||
def make_id(name: str) -> bytes:
|
||||
return digest(name.encode())[:ID_LENGTH]
|
||||
|
||||
|
||||
def make_nonce(n: int = 0) -> bytes:
|
||||
return n.to_bytes(NONCE_LENGTH, "big")
|
||||
|
||||
|
||||
class TestProofOfWorkWeight:
|
||||
|
||||
def test_weight_is_non_negative(self):
|
||||
ws = ProofOfWorkWeight(min_leading_zeros=0)
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"), payload=b"test")
|
||||
assert ws.weight(msg) >= 0
|
||||
|
||||
def test_weight_sum_is_additive(self):
|
||||
ws = ProofOfWorkWeight()
|
||||
assert ws.weight_sum(3, 5) == 8
|
||||
assert ws.weight_sum(0, 0) == 0
|
||||
|
||||
def test_threshold(self):
|
||||
ws = ProofOfWorkWeight(min_leading_zeros=2)
|
||||
assert ws.threshold == 2
|
||||
|
||||
def test_is_valid_weight_with_zero_threshold(self):
|
||||
ws = ProofOfWorkWeight(min_leading_zeros=0)
|
||||
msg = Message(nonce=make_nonce(), id=make_id("x"))
|
||||
assert ws.is_valid_weight(msg) # Everything passes with 0
|
||||
|
||||
def test_mine_nonce_finds_valid_message(self):
|
||||
ws = ProofOfWorkWeight(min_leading_zeros=1)
|
||||
msg = ws.mine_nonce(
|
||||
id_bytes=make_id("miner"),
|
||||
digests=(),
|
||||
payload=b"test payload",
|
||||
target_weight=1
|
||||
)
|
||||
assert ws.weight(msg) >= 1
|
||||
assert ws.is_valid_weight(msg)
|
||||
|
||||
def test_different_nonces_different_weights(self):
|
||||
"""Uniqueness property: different messages have different weights (w.h.p.)."""
|
||||
ws = ProofOfWorkWeight()
|
||||
weights = set()
|
||||
for i in range(20):
|
||||
msg = Message(nonce=make_nonce(i), id=make_id("x"), payload=b"same")
|
||||
weights.add(ws.weight(msg))
|
||||
# Not all the same (with overwhelming probability)
|
||||
assert len(weights) > 1
|
||||
|
||||
def test_tamper_proof(self):
|
||||
"""Changing a message should change its weight (w.h.p.)."""
|
||||
ws = ProofOfWorkWeight()
|
||||
msg1 = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"original")
|
||||
msg2 = Message(nonce=make_nonce(42), id=make_id("x"), payload=b"tampered")
|
||||
# Weights differ because digests differ
|
||||
# (this is probabilistic, but extremely likely)
|
||||
assert msg1.compute_digest() != msg2.compute_digest()
|
||||
|
||||
|
||||
class TestDifficultyOracle:
|
||||
|
||||
def test_constant_difficulty(self):
|
||||
d = DifficultyOracle(constant_difficulty=5)
|
||||
assert d.difficulty(0) == 5
|
||||
assert d.difficulty(100) == 5
|
||||
assert d.difficulty(999) == 5
|
||||
|
||||
def test_default_difficulty(self):
|
||||
d = DifficultyOracle()
|
||||
assert d.difficulty(0) == 4
|
||||
Loading…
Reference in a new issue