Add JSON export pipeline + event recorder for visualization

The simulation now optionally records structured events (message creation,
delivery, round computation, voting, leader election) via EventRecorder and
exports a complete simulation dump to JSON via the new export_json module.
crisis_data.json captures a 10-step run that the SwiftUI visualizer consumes.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
saymrwulf 2026-04-30 20:06:21 +02:00
parent 37e9f26204
commit 1491422527
5 changed files with 186369 additions and 25 deletions

185766
crisis_data.json Normal file

File diff suppressed because it is too large Load diff

View file

@ -28,6 +28,10 @@ from crisis.crypto import digest
from crisis.graph import LamportGraph
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
from crisis.order import LeaderStream, compute_order
from crisis.recorder import (
EventRecorder, EventType, capture_snapshot,
record_rounds, record_voting, record_leader_election,
)
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round
from crisis.voting import compute_safe_voting_pattern, compute_virtual_leader_election
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
@ -79,11 +83,13 @@ class Simulation:
def __init__(self, num_honest: int = 3, num_byzantine: int = 0,
pow_zeros: int = 2, difficulty: int = 1,
connectivity_k: int = 0, seed: int = 42):
connectivity_k: int = 0, seed: int = 42,
recorder: Optional[EventRecorder] = None):
self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty)
self.connectivity_k = connectivity_k
self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros)
self.seed = seed
self.recorder = recorder
random.seed(seed)
# Create nodes
@ -106,6 +112,7 @@ class Simulation:
self.step_count = 0
self.all_messages: list[Message] = []
self.snapshots: list[capture_snapshot.__class__] = [] # type: ignore
def step(self) -> dict:
"""Execute one simulation step.
@ -113,6 +120,12 @@ class Simulation:
Returns a dict with step results for display.
"""
self.step_count += 1
rec = self.recorder
if rec:
rec.record(self.step_count, EventType.STEP_BEGIN, "",
sim_step=self.step_count)
step_results = {
"step": self.step_count,
"new_messages": [],
@ -130,57 +143,99 @@ class Simulation:
if msg is not None:
new_messages.append((node, msg))
msg_digest = msg.compute_digest().hex()[:12]
msg_weight = self.weight_system.weight(msg)
step_results["new_messages"].append({
"from": node.name,
"digest": msg.compute_digest().hex()[:12],
"weight": self.weight_system.weight(msg),
"digest": msg_digest,
"weight": msg_weight,
"payload": msg.payload.decode(errors="replace"),
})
if rec:
evt = EventType.BYZANTINE_MUTATION if node.is_byzantine else EventType.MESSAGE_CREATED
rec.record(
self.step_count, evt, node.name,
digest_hex=msg_digest,
process_id_hex=msg.id.hex()[:8],
payload_str=msg.payload.decode(errors="replace")[:60],
weight=msg_weight,
num_refs=len(msg.digests),
)
# Phase 2: Gossip -- deliver all messages to all nodes
for source_node, msg in new_messages:
self.all_messages.append(msg)
for target_node in self.nodes:
# Deliver to all nodes (including source, for consistency)
target_node.graph.extend(msg)
result = target_node.graph.extend(msg)
if rec and result is not None:
rec.record(
self.step_count, EventType.MESSAGE_DELIVERED, target_node.name,
digest_hex=msg.compute_digest().hex()[:12],
from_node=source_node.name,
)
# Also re-deliver older messages that nodes might be missing
# (simulates pull gossip catching up)
for msg in self.all_messages:
for node in self.nodes:
node.graph.extend(msg) # extend() is idempotent (integrity check)
node.graph.extend(msg) # extend() is idempotent
# Phase 3: Compute consensus on each node
self._last_orders: dict[str, list] = {}
for node in self.nodes:
compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k)
if rec:
record_rounds(node.graph, self.difficulty_oracle,
self.connectivity_k, rec,
self.step_count, node.name)
else:
compute_rounds(node.graph, self.difficulty_oracle,
self.connectivity_k)
# Compute SVP for all last vertices
for vertex in node.graph.all_vertices():
if vertex.is_last:
compute_safe_voting_pattern(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k
)
if rec:
record_voting(vertex, node.graph,
self.difficulty_oracle,
self.connectivity_k, rec,
self.step_count, node.name)
else:
compute_safe_voting_pattern(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k
)
# Compute leader election in round order (lower rounds first).
# This ensures that when a higher-round vertex reads votes from
# its voting set members, those members have already computed
# their own votes.
# Compute leader election in round order
leader_dict: dict[int, list[tuple[int, Message]]] = {}
svp_vertices = [v for v in node.graph.all_vertices() if v.svp]
svp_vertices.sort(key=lambda v: v.round if v.round is not None else 0)
for vertex in svp_vertices:
compute_virtual_leader_election(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k, leader_dict
)
if rec:
record_leader_election(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k, leader_dict, rec,
self.step_count, node.name
)
else:
compute_virtual_leader_election(
vertex, node.graph, self.difficulty_oracle,
self.connectivity_k, leader_dict
)
for round_num, entries in leader_dict.items():
for deciding_round, leader_msg in entries:
node.leader_stream.update(round_num, deciding_round, leader_msg)
ordered = compute_order(node.graph, node.leader_stream)
self._last_orders[node.name] = ordered
if rec:
rec.record(
self.step_count, EventType.ORDER_COMPUTED, node.name,
count=len(ordered),
)
mr = max_round(node.graph)
step_results["node_states"].append({
@ -192,6 +247,10 @@ class Simulation:
"is_byzantine": node.is_byzantine,
})
if rec:
rec.record(self.step_count, EventType.STEP_END, "",
sim_step=self.step_count)
return step_results
def _byzantine_message(self, node: SimulatedNode) -> Optional[Message]:
@ -221,15 +280,39 @@ class Simulation:
else:
return node.generate_message(payload)
def run(self, num_steps: int = 10, verbose: bool = True) -> list[dict]:
"""Run the simulation for a number of steps."""
def run(self, num_steps: int = 10, verbose: bool = True,
progress_callback=None) -> list[dict]:
"""Run the simulation for a number of steps.
Args:
num_steps: Number of simulation steps to run.
verbose: Print step results to stdout.
progress_callback: Optional callable(step, total) for progress UI.
"""
results = []
for _ in range(num_steps):
for i in range(num_steps):
result = self.step()
results.append(result)
if verbose:
_print_step(result)
# Capture snapshot for visualization
if self.recorder:
snap = capture_snapshot(self.step_count, self.nodes,
self.weight_system,
precomputed_orders=self._last_orders)
self.recorder.snapshots.append(snap)
# Convergence check event
self.recorder.record(
self.step_count, EventType.CONVERGENCE_CHECK, "",
convergence=snap.convergence,
agreed_prefix=snap.agreed_prefix_length,
)
if progress_callback:
progress_callback(i + 1, num_steps)
if verbose:
_print_convergence_summary(self)

176
src/crisis/export_json.py Normal file
View file

@ -0,0 +1,176 @@
"""
JSON Exporter: exports simulation data for the native macOS visualizer.
Runs the simulation and writes a complete JSON file containing:
- Configuration parameters
- Per-step snapshots (vertices, edges, rounds, leaders, order)
- Per-step events (message creation, gossip, round assignment, etc.)
Usage:
python -m crisis.export_json [--nodes 8] [--steps 10] [-o crisis_data.json]
"""
from __future__ import annotations
import json
import sys
import argparse
from dataclasses import asdict
from crisis.demo import Simulation
from crisis.recorder import EventRecorder, EventType
def export_simulation(
    num_honest: int = 8,
    num_byzantine: int = 1,
    num_steps: int = 10,
    pow_zeros: int = 1,
    difficulty: int = 1,
    connectivity_k: int = 0,
    seed: int = 42,
) -> dict:
    """Run a fully recorded simulation and return a JSON-serializable dict.

    An EventRecorder is attached to the Simulation so that every protocol
    event and per-step snapshot is captured, then everything is converted
    to plain dicts/lists with camelCase keys for the Swift consumer.

    Args:
        num_honest: Number of honest nodes.
        num_byzantine: Number of byzantine nodes.
        num_steps: Simulation steps to run.
        pow_zeros: Proof-of-work leading-zero requirement.
        difficulty: Constant difficulty for the DifficultyOracle.
        connectivity_k: Connectivity parameter passed through to consensus.
        seed: RNG seed for reproducible runs.

    Returns:
        Dict with four top-level keys: "config", "nodes", "steps", "events".
    """
    recorder = EventRecorder()
    sim = Simulation(
        num_honest=num_honest,
        num_byzantine=num_byzantine,
        pow_zeros=pow_zeros,
        difficulty=difficulty,
        connectivity_k=connectivity_k,
        seed=seed,
        recorder=recorder,
    )
    sim.run(num_steps=num_steps, verbose=False)

    # Node metadata: short process id derived from the digest of the node
    # name, truncated to 8 hex chars (same truncation the events use).
    from crisis.crypto import digest
    from crisis.message import ID_LENGTH
    node_meta = []
    for n in sim.nodes:
        pid = digest(n.name.encode())[:ID_LENGTH].hex()[:8]
        node_meta.append({
            "name": n.name,
            "processIdHex": pid,
            "isByzantine": n.is_byzantine,
        })

    # Config: echo the run parameters so the visualizer can display them.
    config = {
        "numHonest": num_honest,
        "numByzantine": num_byzantine,
        "numSteps": num_steps,
        "powZeros": pow_zeros,
        "difficulty": difficulty,
        "connectivityK": connectivity_k,
        "seed": seed,
    }

    # Snapshots: flatten each StepSnapshot into camelCase dicts.
    steps_data = []
    for snap in recorder.snapshots:
        step_obj = {
            "step": snap.step,
            "convergence": snap.convergence,
            "agreedPrefixLength": snap.agreed_prefix_length,
            "nodeSnapshots": {},
        }
        for name, ns in snap.node_snapshots.items():
            vertices = []
            for v in ns.vertices:
                vertices.append({
                    "digestHex": v.digest_hex,
                    "digestFull": v.digest_full,
                    "processIdHex": v.process_id_hex,
                    "roundNumber": v.round_number,
                    "isLast": v.is_last,
                    "weight": v.weight,
                    "payloadStr": v.payload_str,
                    "totalPosition": v.total_position,
                    "isByzantineSource": v.is_byzantine_source,
                })
            edges = [{"from": e[0], "to": e[1]} for e in ns.edges]
            # JSON object keys must be strings; leader_rounds is keyed by int.
            leader_rounds = {str(k): v for k, v in ns.leader_rounds.items()}
            step_obj["nodeSnapshots"][name] = {
                "name": ns.name,
                "vertexCount": ns.vertex_count,
                "maxRound": ns.max_round,
                "numLeaders": ns.num_leaders,
                "numOrdered": ns.num_ordered,
                "isByzantine": ns.is_byzantine,
                "vertices": vertices,
                "edges": edges,
                "leaderRounds": leader_rounds,
            }
        steps_data.append(step_obj)

    # Events grouped by step; setdefault replaces the manual
    # membership-test-then-append pattern.
    events_by_step: dict[int, list] = {}
    for e in recorder.events:
        events_by_step.setdefault(e.step, []).append({
            "seq": e.seq,
            "type": e.event_type.name,
            "nodeName": e.node_name,
            "data": _clean_data(e.data),
        })

    return {
        "config": config,
        "nodes": node_meta,
        "steps": steps_data,
        "events": events_by_step,
    }
def _clean_data(data: dict) -> dict:
"""Ensure all values are JSON-serializable."""
clean = {}
for k, v in data.items():
if isinstance(v, bytes):
clean[k] = v.hex()
elif isinstance(v, (int, float, str, bool, type(None))):
clean[k] = v
elif isinstance(v, (list, tuple)):
clean[k] = [x.hex() if isinstance(x, bytes) else x for x in v]
elif isinstance(v, dict):
clean[k] = _clean_data(v)
else:
clean[k] = str(v)
return clean
def main():
    """Command-line entry point: parse flags, run the export, write JSON."""
    ap = argparse.ArgumentParser(description="Export Crisis simulation to JSON")
    ap.add_argument("--nodes", type=int, default=8)
    ap.add_argument("--byzantine", type=int, default=1)
    ap.add_argument("--steps", type=int, default=10)
    ap.add_argument("--pow-zeros", type=int, default=1)
    ap.add_argument("--difficulty", type=int, default=1)
    ap.add_argument("--connectivity-k", type=int, default=0)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("-o", "--output", default="crisis_data.json")
    opts = ap.parse_args()

    # Run the recorded simulation with the CLI-selected parameters.
    payload = export_simulation(
        num_honest=opts.nodes,
        num_byzantine=opts.byzantine,
        num_steps=opts.steps,
        pow_zeros=opts.pow_zeros,
        difficulty=opts.difficulty,
        connectivity_k=opts.connectivity_k,
        seed=opts.seed,
    )

    with open(opts.output, "w") as fh:
        json.dump(payload, fh, indent=2)

    n_events = sum(len(v) for v in payload["events"].values())
    n_snaps = len(payload["steps"])
    print(f"Exported: {n_events} events, {n_snaps} snapshots → {opts.output}")


if __name__ == "__main__":
    main()

View file

@ -60,6 +60,10 @@ class LamportGraph:
# digest -> set of digests that reference this vertex
self.reverse_edges: dict[bytes, set[bytes]] = {}
# Cache for past() results: digest -> frozenset of digests
# Invalidated when new vertices are added via extend()
self._past_cache: dict[bytes, frozenset[bytes]] = {}
# ------------------------------------------------------------------
# Graph queries
# ------------------------------------------------------------------
@ -124,8 +128,13 @@ class LamportGraph:
Returns the set of all vertices that are causally before v
(including v itself -- reflexivity).
"""
d = v.message_digest
cached = self._past_cache.get(d)
if cached is not None:
return {self.vertices[dd] for dd in cached if dd in self.vertices}
visited: set[bytes] = set()
stack = [v.message_digest]
stack = [d]
while stack:
current = stack.pop()
@ -136,7 +145,8 @@ class LamportGraph:
if neighbor in self.vertices and neighbor not in visited:
stack.append(neighbor)
return {self.vertices[d] for d in visited if d in self.vertices}
self._past_cache[d] = frozenset(visited)
return {self.vertices[dd] for dd in visited if dd in self.vertices}
def future(self, v: Vertex) -> set[Vertex]:
"""All vertices that are causally after v (including v itself)."""

309
src/crisis/recorder.py Normal file
View file

@ -0,0 +1,309 @@
"""
Event Recording System for Crisis Protocol Visualization
Records all protocol events during a simulation run, producing a structured
event log and per-step snapshots. The visualization application replays
these recordings with full timeline control.
Design: instrumentation wrappers diff state before/after calling the original
protocol functions, so the core algorithm files remain unmodified.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Optional
from crisis.graph import LamportGraph
from crisis.message import Vertex
from crisis.order import LeaderStream, compute_order
from crisis.rounds import compute_rounds, max_round
from crisis.voting import (
compute_safe_voting_pattern,
compute_virtual_leader_election,
)
from crisis.weight import DifficultyOracle
# ---------------------------------------------------------------------------
# Event Types
# ---------------------------------------------------------------------------
class EventType(Enum):
    """Kinds of protocol events recorded during a simulation run.

    Numeric values are assigned by auto() and depend on declaration order;
    consumers serialize the .name (see export_json), so do not reorder
    members unless the numeric values are known to be unused.
    """
    # Phase 1: Message generation
    MESSAGE_CREATED = auto()
    BYZANTINE_MUTATION = auto()
    # Phase 2: Gossip / delivery
    MESSAGE_DELIVERED = auto()
    # Phase 3: Consensus
    ROUND_ASSIGNED = auto()
    VERTEX_BECOMES_LAST = auto()
    SVP_COMPUTED = auto()
    VOTE_CAST = auto()
    LEADER_ELECTED = auto()
    ORDER_COMPUTED = auto()
    # Meta
    STEP_BEGIN = auto()
    STEP_END = auto()
    CONVERGENCE_CHECK = auto()
# ---------------------------------------------------------------------------
# Event
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class SimEvent:
    """A single recorded protocol event.

    Immutable; instances are appended to EventRecorder.events in
    emission order.
    """
    seq: int                # globally unique, monotonically increasing sequence number
    step: int               # simulation step during which the event was emitted
    event_type: EventType   # category of the event
    node_name: str          # emitting node's name ("" for simulation-global events)
    data: dict[str, Any]    # free-form, event-type-specific payload (kwargs of record())
# ---------------------------------------------------------------------------
# Snapshots
# ---------------------------------------------------------------------------
@dataclass
class VertexSnapshot:
    """Snapshot of a single vertex at a point in time."""
    digest_hex: str       # 12-char hex prefix for display
    digest_full: str      # full hex for lookup
    process_id_hex: str   # 8-char hex prefix
    round_number: Optional[int] = None   # Vertex.round at capture time; None if unassigned
    is_last: bool = False                # value of Vertex.is_last at capture time
    weight: int = 0                      # weight_system.weight(...) result for the message
    payload_str: str = ""                # decoded payload, truncated to 60 chars
    total_position: Optional[int] = None # copied from Vertex.total_position
    svp: list[int] = field(default_factory=list)  # safe voting pattern (empty when unset)
    # NOTE(review): capture_snapshot never sets this field, so it is always
    # the default False in captured snapshots — confirm whether per-vertex
    # byzantine attribution was intended.
    is_byzantine_source: bool = False
@dataclass
class NodeSnapshot:
    """Snapshot of a node's full state at a given step."""
    name: str                  # node name (also the key in StepSnapshot.node_snapshots)
    step: int                  # simulation step at which this snapshot was taken
    vertex_count: int = 0      # number of vertices in the node's graph
    max_round: int = 0         # max_round() of the node's graph
    num_leaders: int = 0       # entries in the node's leader stream
    num_ordered: int = 0       # length of the node's computed total order
    is_byzantine: bool = False
    vertices: list[VertexSnapshot] = field(default_factory=list)
    edges: list[tuple[str, str]] = field(default_factory=list)  # (from_hex, to_hex)
    leader_rounds: dict[int, str] = field(default_factory=dict)  # round -> leader digest_hex
@dataclass
class StepSnapshot:
    """Full simulation state captured at a step boundary."""
    step: int  # simulation step number
    node_snapshots: dict[str, NodeSnapshot] = field(default_factory=dict)  # keyed by node name
    convergence: bool = False       # True when all honest nodes share identical non-empty orders
    agreed_prefix_length: int = 0   # longest common order prefix across honest nodes
# ---------------------------------------------------------------------------
# EventRecorder
# ---------------------------------------------------------------------------
class EventRecorder:
    """Accumulates events and snapshots during a simulation run.

    Append-only container: record() stamps every event with the next
    global sequence number, so self.events is totally ordered by seq.
    """

    def __init__(self):
        # All recorded events, in emission order.
        self.events: list[SimEvent] = []
        # One StepSnapshot per completed simulation step.
        self.snapshots: list[StepSnapshot] = []
        # Last sequence number handed out; the first event gets seq 1.
        self._seq = 0

    def record(self, step: int, event_type: EventType,
               node_name: str, **data) -> SimEvent:
        """Create, store and return a new SimEvent carrying **data."""
        self._seq += 1
        evt = SimEvent(self._seq, step, event_type, node_name, data)
        self.events.append(evt)
        return evt

    def events_at_step(self, step: int) -> list[SimEvent]:
        """Events emitted during the given step, in emission order."""
        return [evt for evt in self.events if evt.step == step]

    def events_of_type(self, et: EventType) -> list[SimEvent]:
        """Events of a single EventType, across all steps."""
        return [evt for evt in self.events if evt.event_type == et]

    def max_step(self) -> int:
        """Highest step number recorded so far; 0 when no events exist."""
        steps = [evt.step for evt in self.events]
        return max(steps, default=0)
# ---------------------------------------------------------------------------
# Snapshot capture
# ---------------------------------------------------------------------------
def capture_snapshot(step: int, nodes, weight_system,
                     precomputed_orders: dict | None = None) -> StepSnapshot:
    """Capture a full StepSnapshot from the current simulation state.

    Args:
        step: The simulation step number.
        nodes: List of SimulatedNode objects.
        weight_system: The weight system for computing vertex weights.
        precomputed_orders: Optional dict node_name -> list[Vertex] to skip
            recomputing total order (expensive).

    Returns:
        A StepSnapshot with one NodeSnapshot per node plus a convergence
        summary computed over the honest nodes' orders.
    """
    snap = StepSnapshot(step=step)
    for node in nodes:
        g = node.graph
        mr = max_round(g)
        # Reuse the caller's precomputed order when available;
        # compute_order is the expensive call this avoids.
        if precomputed_orders and node.name in precomputed_orders:
            ordered = precomputed_orders[node.name]
        else:
            ordered = compute_order(g, node.leader_stream)
        # Build vertex snapshots
        v_snaps = []
        for v in g.all_vertices():
            vs = VertexSnapshot(
                digest_hex=v.message_digest.hex()[:12],
                digest_full=v.message_digest.hex(),
                process_id_hex=v.id.hex()[:8],
                round_number=v.round,
                is_last=bool(v.is_last),
                weight=weight_system.weight(v.m),
                # Truncate payload to keep snapshots/export small.
                payload_str=v.payload.decode(errors="replace")[:60],
                total_position=v.total_position,
                svp=list(v.svp) if v.svp else [],
                # NOTE(review): is_byzantine_source is never set here and
                # keeps its default (False) — confirm intended.
            )
            v_snaps.append(vs)
        # Build edge list (only edges whose target vertex is present)
        edge_list = []
        for d_from, refs in g.edges.items():
            from_hex = d_from.hex()[:12]
            for d_to in refs:
                if d_to in g.vertices:
                    edge_list.append((from_hex, d_to.hex()[:12]))
        # Leader digest map: round -> 12-char digest prefix of the leader
        leader_rounds = {}
        for rn, (_, msg) in node.leader_stream.leaders.items():
            leader_rounds[rn] = msg.compute_digest().hex()[:12]
        ns = NodeSnapshot(
            name=node.name,
            step=step,
            vertex_count=g.vertex_count(),
            max_round=mr,
            num_leaders=len(node.leader_stream.leaders),
            num_ordered=len(ordered),
            is_byzantine=node.is_byzantine,
            vertices=v_snaps,
            edges=edge_list,
            leader_rounds=leader_rounds,
        )
        snap.node_snapshots[node.name] = ns
    # Convergence check across honest nodes (needs at least two to compare)
    honest = [n for n in nodes if not n.is_byzantine]
    if len(honest) >= 2:
        orders = []
        for n in honest:
            if precomputed_orders and n.name in precomputed_orders:
                o = precomputed_orders[n.name]
            else:
                o = compute_order(n.graph, n.leader_stream)
            orders.append([v.message_digest.hex()[:12] for v in o])
        # Find longest common prefix
        if orders:
            min_len = min(len(o) for o in orders)
            agreed = 0
            for i in range(min_len):
                if all(o[i] == orders[0][i] for o in orders[1:]):
                    agreed = i + 1
                else:
                    break
            snap.agreed_prefix_length = agreed
            # Converged = every honest order is identical: common prefix
            # spans the (non-empty) shortest order AND all lengths match.
            snap.convergence = (agreed == min_len and min_len > 0
                                and all(len(o) == len(orders[0]) for o in orders))
    return snap
# ---------------------------------------------------------------------------
# Instrumentation wrappers
# ---------------------------------------------------------------------------
def record_rounds(graph: LamportGraph, difficulty: DifficultyOracle,
                  connectivity_k: int, recorder: EventRecorder,
                  step: int, node_name: str) -> None:
    """Wrapper around compute_rounds that records state changes.

    Snapshots (round, is_last) per vertex before the call, runs the real
    compute_rounds, then emits ROUND_ASSIGNED / VERTEX_BECOMES_LAST events
    for every vertex whose state changed.
    """
    before = {
        vtx.message_digest: (vtx.round, vtx.is_last)
        for vtx in graph.all_vertices()
    }

    compute_rounds(graph, difficulty, connectivity_k)

    for vtx in graph.all_vertices():
        dig = vtx.message_digest
        prev_round, prev_last = before.get(dig, (None, None))
        # A vertex newly received (or changed) a round number.
        if vtx.round is not None and vtx.round != prev_round:
            recorder.record(
                step, EventType.ROUND_ASSIGNED, node_name,
                digest_hex=dig.hex()[:12],
                round_number=vtx.round,
                process_id_hex=vtx.id.hex()[:8],
            )
        # A vertex transitioned into the "last" state.
        if vtx.is_last and not prev_last:
            recorder.record(
                step, EventType.VERTEX_BECOMES_LAST, node_name,
                digest_hex=dig.hex()[:12],
                round_number=vtx.round,
                process_id_hex=vtx.id.hex()[:8],
            )
def record_voting(vertex: Vertex, graph: LamportGraph,
                  difficulty: DifficultyOracle, connectivity_k: int,
                  recorder: EventRecorder, step: int,
                  node_name: str) -> None:
    """Wrapper around compute_safe_voting_pattern that records SVP.

    Emits an SVP_COMPUTED event only when the vertex's safe voting
    pattern was assigned or changed by this call.
    """
    previous = list(vertex.svp) if vertex.svp else []
    compute_safe_voting_pattern(vertex, graph, difficulty, connectivity_k)
    changed = bool(vertex.svp) and vertex.svp != previous
    if changed:
        recorder.record(
            step, EventType.SVP_COMPUTED, node_name,
            digest_hex=vertex.message_digest.hex()[:12],
            svp=list(vertex.svp),
            round_number=vertex.round,
        )
def record_leader_election(vertex: Vertex, graph: LamportGraph,
                           difficulty: DifficultyOracle,
                           connectivity_k: int,
                           leader_dict: dict,
                           recorder: EventRecorder, step: int,
                           node_name: str) -> None:
    """Wrapper around compute_virtual_leader_election that records results.

    Emits one LEADER_ELECTED event per entry under every round key this
    call newly added to leader_dict.

    NOTE(review): entries appended under a round key that already existed
    before the call are never recorded — only brand-new keys are diffed.
    Confirm that matches the intended event semantics.
    """
    seen = set(leader_dict)
    compute_virtual_leader_election(
        vertex, graph, difficulty, connectivity_k, leader_dict
    )
    for round_num in set(leader_dict) - seen:
        # Each value is a list[tuple[int, Message]] of (deciding_round, leader).
        for deciding_round, winner in leader_dict[round_num]:
            recorder.record(
                step, EventType.LEADER_ELECTED, node_name,
                round_number=round_num,
                deciding_round=deciding_round,
                leader_digest_hex=winner.compute_digest().hex()[:12],
            )