mirror of
https://github.com/saymrwulf/crisis.git
synced 2026-05-14 20:37:54 +00:00
Add JSON export pipeline + event recorder for visualization
The simulation now optionally records structured events (message creation, delivery, round computation, voting, leader election) via EventRecorder and exports a complete simulation dump to JSON via the new export_json module. crisis_data.json captures a 10-step run that the SwiftUI visualizer consumes. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
37e9f26204
commit
1491422527
5 changed files with 186369 additions and 25 deletions
185766
crisis_data.json
Normal file
185766
crisis_data.json
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -28,6 +28,10 @@ from crisis.crypto import digest
|
|||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Message, Vertex, ID_LENGTH, NONCE_LENGTH
|
||||
from crisis.order import LeaderStream, compute_order
|
||||
from crisis.recorder import (
|
||||
EventRecorder, EventType, capture_snapshot,
|
||||
record_rounds, record_voting, record_leader_election,
|
||||
)
|
||||
from crisis.rounds import compute_rounds, max_round, last_vertices_in_round
|
||||
from crisis.voting import compute_safe_voting_pattern, compute_virtual_leader_election
|
||||
from crisis.weight import ProofOfWorkWeight, DifficultyOracle
|
||||
|
|
@ -79,11 +83,13 @@ class Simulation:
|
|||
|
||||
def __init__(self, num_honest: int = 3, num_byzantine: int = 0,
|
||||
pow_zeros: int = 2, difficulty: int = 1,
|
||||
connectivity_k: int = 0, seed: int = 42):
|
||||
connectivity_k: int = 0, seed: int = 42,
|
||||
recorder: Optional[EventRecorder] = None):
|
||||
self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty)
|
||||
self.connectivity_k = connectivity_k
|
||||
self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros)
|
||||
self.seed = seed
|
||||
self.recorder = recorder
|
||||
random.seed(seed)
|
||||
|
||||
# Create nodes
|
||||
|
|
@ -106,6 +112,7 @@ class Simulation:
|
|||
|
||||
self.step_count = 0
|
||||
self.all_messages: list[Message] = []
|
||||
self.snapshots: list[capture_snapshot.__class__] = [] # type: ignore
|
||||
|
||||
def step(self) -> dict:
|
||||
"""Execute one simulation step.
|
||||
|
|
@ -113,6 +120,12 @@ class Simulation:
|
|||
Returns a dict with step results for display.
|
||||
"""
|
||||
self.step_count += 1
|
||||
rec = self.recorder
|
||||
|
||||
if rec:
|
||||
rec.record(self.step_count, EventType.STEP_BEGIN, "",
|
||||
sim_step=self.step_count)
|
||||
|
||||
step_results = {
|
||||
"step": self.step_count,
|
||||
"new_messages": [],
|
||||
|
|
@ -130,57 +143,99 @@ class Simulation:
|
|||
|
||||
if msg is not None:
|
||||
new_messages.append((node, msg))
|
||||
msg_digest = msg.compute_digest().hex()[:12]
|
||||
msg_weight = self.weight_system.weight(msg)
|
||||
|
||||
step_results["new_messages"].append({
|
||||
"from": node.name,
|
||||
"digest": msg.compute_digest().hex()[:12],
|
||||
"weight": self.weight_system.weight(msg),
|
||||
"digest": msg_digest,
|
||||
"weight": msg_weight,
|
||||
"payload": msg.payload.decode(errors="replace"),
|
||||
})
|
||||
|
||||
if rec:
|
||||
evt = EventType.BYZANTINE_MUTATION if node.is_byzantine else EventType.MESSAGE_CREATED
|
||||
rec.record(
|
||||
self.step_count, evt, node.name,
|
||||
digest_hex=msg_digest,
|
||||
process_id_hex=msg.id.hex()[:8],
|
||||
payload_str=msg.payload.decode(errors="replace")[:60],
|
||||
weight=msg_weight,
|
||||
num_refs=len(msg.digests),
|
||||
)
|
||||
|
||||
# Phase 2: Gossip -- deliver all messages to all nodes
|
||||
for source_node, msg in new_messages:
|
||||
self.all_messages.append(msg)
|
||||
for target_node in self.nodes:
|
||||
# Deliver to all nodes (including source, for consistency)
|
||||
target_node.graph.extend(msg)
|
||||
result = target_node.graph.extend(msg)
|
||||
if rec and result is not None:
|
||||
rec.record(
|
||||
self.step_count, EventType.MESSAGE_DELIVERED, target_node.name,
|
||||
digest_hex=msg.compute_digest().hex()[:12],
|
||||
from_node=source_node.name,
|
||||
)
|
||||
|
||||
# Also re-deliver older messages that nodes might be missing
|
||||
# (simulates pull gossip catching up)
|
||||
for msg in self.all_messages:
|
||||
for node in self.nodes:
|
||||
node.graph.extend(msg) # extend() is idempotent (integrity check)
|
||||
node.graph.extend(msg) # extend() is idempotent
|
||||
|
||||
# Phase 3: Compute consensus on each node
|
||||
self._last_orders: dict[str, list] = {}
|
||||
for node in self.nodes:
|
||||
compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k)
|
||||
if rec:
|
||||
record_rounds(node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k, rec,
|
||||
self.step_count, node.name)
|
||||
else:
|
||||
compute_rounds(node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k)
|
||||
|
||||
# Compute SVP for all last vertices
|
||||
for vertex in node.graph.all_vertices():
|
||||
if vertex.is_last:
|
||||
compute_safe_voting_pattern(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k
|
||||
)
|
||||
if rec:
|
||||
record_voting(vertex, node.graph,
|
||||
self.difficulty_oracle,
|
||||
self.connectivity_k, rec,
|
||||
self.step_count, node.name)
|
||||
else:
|
||||
compute_safe_voting_pattern(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k
|
||||
)
|
||||
|
||||
# Compute leader election in round order (lower rounds first).
|
||||
# This ensures that when a higher-round vertex reads votes from
|
||||
# its voting set members, those members have already computed
|
||||
# their own votes.
|
||||
# Compute leader election in round order
|
||||
leader_dict: dict[int, list[tuple[int, Message]]] = {}
|
||||
svp_vertices = [v for v in node.graph.all_vertices() if v.svp]
|
||||
svp_vertices.sort(key=lambda v: v.round if v.round is not None else 0)
|
||||
|
||||
for vertex in svp_vertices:
|
||||
compute_virtual_leader_election(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k, leader_dict
|
||||
)
|
||||
if rec:
|
||||
record_leader_election(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k, leader_dict, rec,
|
||||
self.step_count, node.name
|
||||
)
|
||||
else:
|
||||
compute_virtual_leader_election(
|
||||
vertex, node.graph, self.difficulty_oracle,
|
||||
self.connectivity_k, leader_dict
|
||||
)
|
||||
|
||||
for round_num, entries in leader_dict.items():
|
||||
for deciding_round, leader_msg in entries:
|
||||
node.leader_stream.update(round_num, deciding_round, leader_msg)
|
||||
|
||||
ordered = compute_order(node.graph, node.leader_stream)
|
||||
self._last_orders[node.name] = ordered
|
||||
|
||||
if rec:
|
||||
rec.record(
|
||||
self.step_count, EventType.ORDER_COMPUTED, node.name,
|
||||
count=len(ordered),
|
||||
)
|
||||
|
||||
mr = max_round(node.graph)
|
||||
step_results["node_states"].append({
|
||||
|
|
@ -192,6 +247,10 @@ class Simulation:
|
|||
"is_byzantine": node.is_byzantine,
|
||||
})
|
||||
|
||||
if rec:
|
||||
rec.record(self.step_count, EventType.STEP_END, "",
|
||||
sim_step=self.step_count)
|
||||
|
||||
return step_results
|
||||
|
||||
def _byzantine_message(self, node: SimulatedNode) -> Optional[Message]:
|
||||
|
|
@ -221,15 +280,39 @@ class Simulation:
|
|||
else:
|
||||
return node.generate_message(payload)
|
||||
|
||||
def run(self, num_steps: int = 10, verbose: bool = True) -> list[dict]:
|
||||
"""Run the simulation for a number of steps."""
|
||||
def run(self, num_steps: int = 10, verbose: bool = True,
|
||||
progress_callback=None) -> list[dict]:
|
||||
"""Run the simulation for a number of steps.
|
||||
|
||||
Args:
|
||||
num_steps: Number of simulation steps to run.
|
||||
verbose: Print step results to stdout.
|
||||
progress_callback: Optional callable(step, total) for progress UI.
|
||||
"""
|
||||
results = []
|
||||
for _ in range(num_steps):
|
||||
for i in range(num_steps):
|
||||
result = self.step()
|
||||
results.append(result)
|
||||
if verbose:
|
||||
_print_step(result)
|
||||
|
||||
# Capture snapshot for visualization
|
||||
if self.recorder:
|
||||
snap = capture_snapshot(self.step_count, self.nodes,
|
||||
self.weight_system,
|
||||
precomputed_orders=self._last_orders)
|
||||
self.recorder.snapshots.append(snap)
|
||||
|
||||
# Convergence check event
|
||||
self.recorder.record(
|
||||
self.step_count, EventType.CONVERGENCE_CHECK, "",
|
||||
convergence=snap.convergence,
|
||||
agreed_prefix=snap.agreed_prefix_length,
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(i + 1, num_steps)
|
||||
|
||||
if verbose:
|
||||
_print_convergence_summary(self)
|
||||
|
||||
|
|
|
|||
176
src/crisis/export_json.py
Normal file
176
src/crisis/export_json.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
"""
|
||||
JSON Exporter — Exports simulation data for the native macOS visualizer.
|
||||
|
||||
Runs the simulation and writes a complete JSON file containing:
|
||||
- Configuration parameters
|
||||
- Per-step snapshots (vertices, edges, rounds, leaders, order)
|
||||
- Per-step events (message creation, gossip, round assignment, etc.)
|
||||
|
||||
Usage:
|
||||
python -m crisis.export_json [--nodes 8] [--steps 10] [-o crisis_data.json]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
from dataclasses import asdict
|
||||
|
||||
from crisis.demo import Simulation
|
||||
from crisis.recorder import EventRecorder, EventType
|
||||
|
||||
|
||||
def export_simulation(
|
||||
num_honest: int = 8,
|
||||
num_byzantine: int = 1,
|
||||
num_steps: int = 10,
|
||||
pow_zeros: int = 1,
|
||||
difficulty: int = 1,
|
||||
connectivity_k: int = 0,
|
||||
seed: int = 42,
|
||||
) -> dict:
|
||||
"""Run simulation and return exportable dict."""
|
||||
recorder = EventRecorder()
|
||||
sim = Simulation(
|
||||
num_honest=num_honest,
|
||||
num_byzantine=num_byzantine,
|
||||
pow_zeros=pow_zeros,
|
||||
difficulty=difficulty,
|
||||
connectivity_k=connectivity_k,
|
||||
seed=seed,
|
||||
recorder=recorder,
|
||||
)
|
||||
sim.run(num_steps=num_steps, verbose=False)
|
||||
|
||||
# Node metadata
|
||||
from crisis.crypto import digest
|
||||
from crisis.message import ID_LENGTH
|
||||
node_meta = []
|
||||
for n in sim.nodes:
|
||||
pid = digest(n.name.encode())[:ID_LENGTH].hex()[:8]
|
||||
node_meta.append({
|
||||
"name": n.name,
|
||||
"processIdHex": pid,
|
||||
"isByzantine": n.is_byzantine,
|
||||
})
|
||||
|
||||
# Config
|
||||
config = {
|
||||
"numHonest": num_honest,
|
||||
"numByzantine": num_byzantine,
|
||||
"numSteps": num_steps,
|
||||
"powZeros": pow_zeros,
|
||||
"difficulty": difficulty,
|
||||
"connectivityK": connectivity_k,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
# Snapshots
|
||||
steps_data = []
|
||||
for snap in recorder.snapshots:
|
||||
step_obj = {
|
||||
"step": snap.step,
|
||||
"convergence": snap.convergence,
|
||||
"agreedPrefixLength": snap.agreed_prefix_length,
|
||||
"nodeSnapshots": {},
|
||||
}
|
||||
for name, ns in snap.node_snapshots.items():
|
||||
vertices = []
|
||||
for v in ns.vertices:
|
||||
vertices.append({
|
||||
"digestHex": v.digest_hex,
|
||||
"digestFull": v.digest_full,
|
||||
"processIdHex": v.process_id_hex,
|
||||
"roundNumber": v.round_number,
|
||||
"isLast": v.is_last,
|
||||
"weight": v.weight,
|
||||
"payloadStr": v.payload_str,
|
||||
"totalPosition": v.total_position,
|
||||
"isByzantineSource": v.is_byzantine_source,
|
||||
})
|
||||
edges = [{"from": e[0], "to": e[1]} for e in ns.edges]
|
||||
leader_rounds = {str(k): v for k, v in ns.leader_rounds.items()}
|
||||
step_obj["nodeSnapshots"][name] = {
|
||||
"name": ns.name,
|
||||
"vertexCount": ns.vertex_count,
|
||||
"maxRound": ns.max_round,
|
||||
"numLeaders": ns.num_leaders,
|
||||
"numOrdered": ns.num_ordered,
|
||||
"isByzantine": ns.is_byzantine,
|
||||
"vertices": vertices,
|
||||
"edges": edges,
|
||||
"leaderRounds": leader_rounds,
|
||||
}
|
||||
steps_data.append(step_obj)
|
||||
|
||||
# Events (grouped by step)
|
||||
events_by_step: dict[int, list] = {}
|
||||
for e in recorder.events:
|
||||
step = e.step
|
||||
if step not in events_by_step:
|
||||
events_by_step[step] = []
|
||||
events_by_step[step].append({
|
||||
"seq": e.seq,
|
||||
"type": e.event_type.name,
|
||||
"nodeName": e.node_name,
|
||||
"data": _clean_data(e.data),
|
||||
})
|
||||
|
||||
return {
|
||||
"config": config,
|
||||
"nodes": node_meta,
|
||||
"steps": steps_data,
|
||||
"events": events_by_step,
|
||||
}
|
||||
|
||||
|
||||
def _clean_data(data: dict) -> dict:
|
||||
"""Ensure all values are JSON-serializable."""
|
||||
clean = {}
|
||||
for k, v in data.items():
|
||||
if isinstance(v, bytes):
|
||||
clean[k] = v.hex()
|
||||
elif isinstance(v, (int, float, str, bool, type(None))):
|
||||
clean[k] = v
|
||||
elif isinstance(v, (list, tuple)):
|
||||
clean[k] = [x.hex() if isinstance(x, bytes) else x for x in v]
|
||||
elif isinstance(v, dict):
|
||||
clean[k] = _clean_data(v)
|
||||
else:
|
||||
clean[k] = str(v)
|
||||
return clean
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Export Crisis simulation to JSON")
|
||||
parser.add_argument("--nodes", type=int, default=8)
|
||||
parser.add_argument("--byzantine", type=int, default=1)
|
||||
parser.add_argument("--steps", type=int, default=10)
|
||||
parser.add_argument("--pow-zeros", type=int, default=1)
|
||||
parser.add_argument("--difficulty", type=int, default=1)
|
||||
parser.add_argument("--connectivity-k", type=int, default=0)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("-o", "--output", default="crisis_data.json")
|
||||
args = parser.parse_args()
|
||||
|
||||
data = export_simulation(
|
||||
num_honest=args.nodes,
|
||||
num_byzantine=args.byzantine,
|
||||
num_steps=args.steps,
|
||||
pow_zeros=args.pow_zeros,
|
||||
difficulty=args.difficulty,
|
||||
connectivity_k=args.connectivity_k,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
n_events = sum(len(v) for v in data["events"].values())
|
||||
n_snaps = len(data["steps"])
|
||||
print(f"Exported: {n_events} events, {n_snaps} snapshots → {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -60,6 +60,10 @@ class LamportGraph:
|
|||
# digest -> set of digests that reference this vertex
|
||||
self.reverse_edges: dict[bytes, set[bytes]] = {}
|
||||
|
||||
# Cache for past() results: digest -> frozenset of digests
|
||||
# Invalidated when new vertices are added via extend()
|
||||
self._past_cache: dict[bytes, frozenset[bytes]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph queries
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -124,8 +128,13 @@ class LamportGraph:
|
|||
Returns the set of all vertices that are causally before v
|
||||
(including v itself -- reflexivity).
|
||||
"""
|
||||
d = v.message_digest
|
||||
cached = self._past_cache.get(d)
|
||||
if cached is not None:
|
||||
return {self.vertices[dd] for dd in cached if dd in self.vertices}
|
||||
|
||||
visited: set[bytes] = set()
|
||||
stack = [v.message_digest]
|
||||
stack = [d]
|
||||
|
||||
while stack:
|
||||
current = stack.pop()
|
||||
|
|
@ -136,7 +145,8 @@ class LamportGraph:
|
|||
if neighbor in self.vertices and neighbor not in visited:
|
||||
stack.append(neighbor)
|
||||
|
||||
return {self.vertices[d] for d in visited if d in self.vertices}
|
||||
self._past_cache[d] = frozenset(visited)
|
||||
return {self.vertices[dd] for dd in visited if dd in self.vertices}
|
||||
|
||||
def future(self, v: Vertex) -> set[Vertex]:
|
||||
"""All vertices that are causally after v (including v itself)."""
|
||||
|
|
|
|||
309
src/crisis/recorder.py
Normal file
309
src/crisis/recorder.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""
|
||||
Event Recording System for Crisis Protocol Visualization
|
||||
|
||||
Records all protocol events during a simulation run, producing a structured
|
||||
event log and per-step snapshots. The visualization application replays
|
||||
these recordings with full timeline control.
|
||||
|
||||
Design: instrumentation wrappers diff state before/after calling the original
|
||||
protocol functions, so the core algorithm files remain unmodified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from crisis.graph import LamportGraph
|
||||
from crisis.message import Vertex
|
||||
from crisis.order import LeaderStream, compute_order
|
||||
from crisis.rounds import compute_rounds, max_round
|
||||
from crisis.voting import (
|
||||
compute_safe_voting_pattern,
|
||||
compute_virtual_leader_election,
|
||||
)
|
||||
from crisis.weight import DifficultyOracle
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EventType(Enum):
|
||||
# Phase 1: Message generation
|
||||
MESSAGE_CREATED = auto()
|
||||
BYZANTINE_MUTATION = auto()
|
||||
|
||||
# Phase 2: Gossip / delivery
|
||||
MESSAGE_DELIVERED = auto()
|
||||
|
||||
# Phase 3: Consensus
|
||||
ROUND_ASSIGNED = auto()
|
||||
VERTEX_BECOMES_LAST = auto()
|
||||
SVP_COMPUTED = auto()
|
||||
VOTE_CAST = auto()
|
||||
LEADER_ELECTED = auto()
|
||||
ORDER_COMPUTED = auto()
|
||||
|
||||
# Meta
|
||||
STEP_BEGIN = auto()
|
||||
STEP_END = auto()
|
||||
CONVERGENCE_CHECK = auto()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SimEvent:
|
||||
"""A single recorded protocol event."""
|
||||
seq: int
|
||||
step: int
|
||||
event_type: EventType
|
||||
node_name: str
|
||||
data: dict[str, Any]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshots
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class VertexSnapshot:
|
||||
"""Snapshot of a single vertex at a point in time."""
|
||||
digest_hex: str # 12-char hex prefix for display
|
||||
digest_full: str # full hex for lookup
|
||||
process_id_hex: str # 8-char hex prefix
|
||||
round_number: Optional[int] = None
|
||||
is_last: bool = False
|
||||
weight: int = 0
|
||||
payload_str: str = ""
|
||||
total_position: Optional[int] = None
|
||||
svp: list[int] = field(default_factory=list)
|
||||
is_byzantine_source: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeSnapshot:
|
||||
"""Snapshot of a node's full state at a given step."""
|
||||
name: str
|
||||
step: int
|
||||
vertex_count: int = 0
|
||||
max_round: int = 0
|
||||
num_leaders: int = 0
|
||||
num_ordered: int = 0
|
||||
is_byzantine: bool = False
|
||||
vertices: list[VertexSnapshot] = field(default_factory=list)
|
||||
edges: list[tuple[str, str]] = field(default_factory=list) # (from_hex, to_hex)
|
||||
leader_rounds: dict[int, str] = field(default_factory=dict) # round -> leader digest_hex
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepSnapshot:
|
||||
"""Full simulation state captured at a step boundary."""
|
||||
step: int
|
||||
node_snapshots: dict[str, NodeSnapshot] = field(default_factory=dict)
|
||||
convergence: bool = False
|
||||
agreed_prefix_length: int = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EventRecorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EventRecorder:
|
||||
"""Accumulates events and snapshots during a simulation run."""
|
||||
|
||||
def __init__(self):
|
||||
self.events: list[SimEvent] = []
|
||||
self.snapshots: list[StepSnapshot] = []
|
||||
self._seq = 0
|
||||
|
||||
def record(self, step: int, event_type: EventType,
|
||||
node_name: str, **data) -> SimEvent:
|
||||
self._seq += 1
|
||||
event = SimEvent(self._seq, step, event_type, node_name, data)
|
||||
self.events.append(event)
|
||||
return event
|
||||
|
||||
def events_at_step(self, step: int) -> list[SimEvent]:
|
||||
return [e for e in self.events if e.step == step]
|
||||
|
||||
def events_of_type(self, et: EventType) -> list[SimEvent]:
|
||||
return [e for e in self.events if e.event_type == et]
|
||||
|
||||
def max_step(self) -> int:
|
||||
return max((e.step for e in self.events), default=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot capture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def capture_snapshot(step: int, nodes, weight_system,
|
||||
precomputed_orders: dict | None = None) -> StepSnapshot:
|
||||
"""Capture a full StepSnapshot from the current simulation state.
|
||||
|
||||
Args:
|
||||
step: The simulation step number.
|
||||
nodes: List of SimulatedNode objects.
|
||||
weight_system: The weight system for computing vertex weights.
|
||||
precomputed_orders: Optional dict node_name -> list[Vertex] to skip
|
||||
recomputing total order (expensive).
|
||||
"""
|
||||
snap = StepSnapshot(step=step)
|
||||
|
||||
for node in nodes:
|
||||
g = node.graph
|
||||
mr = max_round(g)
|
||||
if precomputed_orders and node.name in precomputed_orders:
|
||||
ordered = precomputed_orders[node.name]
|
||||
else:
|
||||
ordered = compute_order(g, node.leader_stream)
|
||||
|
||||
# Build vertex snapshots
|
||||
v_snaps = []
|
||||
for v in g.all_vertices():
|
||||
vs = VertexSnapshot(
|
||||
digest_hex=v.message_digest.hex()[:12],
|
||||
digest_full=v.message_digest.hex(),
|
||||
process_id_hex=v.id.hex()[:8],
|
||||
round_number=v.round,
|
||||
is_last=bool(v.is_last),
|
||||
weight=weight_system.weight(v.m),
|
||||
payload_str=v.payload.decode(errors="replace")[:60],
|
||||
total_position=v.total_position,
|
||||
svp=list(v.svp) if v.svp else [],
|
||||
)
|
||||
v_snaps.append(vs)
|
||||
|
||||
# Build edge list
|
||||
edge_list = []
|
||||
for d_from, refs in g.edges.items():
|
||||
from_hex = d_from.hex()[:12]
|
||||
for d_to in refs:
|
||||
if d_to in g.vertices:
|
||||
edge_list.append((from_hex, d_to.hex()[:12]))
|
||||
|
||||
# Leader digest map
|
||||
leader_rounds = {}
|
||||
for rn, (_, msg) in node.leader_stream.leaders.items():
|
||||
leader_rounds[rn] = msg.compute_digest().hex()[:12]
|
||||
|
||||
ns = NodeSnapshot(
|
||||
name=node.name,
|
||||
step=step,
|
||||
vertex_count=g.vertex_count(),
|
||||
max_round=mr,
|
||||
num_leaders=len(node.leader_stream.leaders),
|
||||
num_ordered=len(ordered),
|
||||
is_byzantine=node.is_byzantine,
|
||||
vertices=v_snaps,
|
||||
edges=edge_list,
|
||||
leader_rounds=leader_rounds,
|
||||
)
|
||||
snap.node_snapshots[node.name] = ns
|
||||
|
||||
# Convergence check across honest nodes
|
||||
honest = [n for n in nodes if not n.is_byzantine]
|
||||
if len(honest) >= 2:
|
||||
orders = []
|
||||
for n in honest:
|
||||
if precomputed_orders and n.name in precomputed_orders:
|
||||
o = precomputed_orders[n.name]
|
||||
else:
|
||||
o = compute_order(n.graph, n.leader_stream)
|
||||
orders.append([v.message_digest.hex()[:12] for v in o])
|
||||
# Find longest common prefix
|
||||
if orders:
|
||||
min_len = min(len(o) for o in orders)
|
||||
agreed = 0
|
||||
for i in range(min_len):
|
||||
if all(o[i] == orders[0][i] for o in orders[1:]):
|
||||
agreed = i + 1
|
||||
else:
|
||||
break
|
||||
snap.agreed_prefix_length = agreed
|
||||
snap.convergence = (agreed == min_len and min_len > 0
|
||||
and all(len(o) == len(orders[0]) for o in orders))
|
||||
|
||||
return snap
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instrumentation wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def record_rounds(graph: LamportGraph, difficulty: DifficultyOracle,
|
||||
connectivity_k: int, recorder: EventRecorder,
|
||||
step: int, node_name: str) -> None:
|
||||
"""Wrapper around compute_rounds that records state changes."""
|
||||
old_state = {
|
||||
v.message_digest: (v.round, v.is_last)
|
||||
for v in graph.all_vertices()
|
||||
}
|
||||
|
||||
compute_rounds(graph, difficulty, connectivity_k)
|
||||
|
||||
for v in graph.all_vertices():
|
||||
d = v.message_digest
|
||||
old_r, old_last = old_state.get(d, (None, None))
|
||||
if v.round != old_r and v.round is not None:
|
||||
recorder.record(
|
||||
step, EventType.ROUND_ASSIGNED, node_name,
|
||||
digest_hex=d.hex()[:12],
|
||||
round_number=v.round,
|
||||
process_id_hex=v.id.hex()[:8],
|
||||
)
|
||||
if v.is_last and not old_last:
|
||||
recorder.record(
|
||||
step, EventType.VERTEX_BECOMES_LAST, node_name,
|
||||
digest_hex=d.hex()[:12],
|
||||
round_number=v.round,
|
||||
process_id_hex=v.id.hex()[:8],
|
||||
)
|
||||
|
||||
|
||||
def record_voting(vertex: Vertex, graph: LamportGraph,
|
||||
difficulty: DifficultyOracle, connectivity_k: int,
|
||||
recorder: EventRecorder, step: int,
|
||||
node_name: str) -> None:
|
||||
"""Wrapper around compute_safe_voting_pattern that records SVP."""
|
||||
old_svp = list(vertex.svp) if vertex.svp else []
|
||||
|
||||
compute_safe_voting_pattern(vertex, graph, difficulty, connectivity_k)
|
||||
|
||||
if vertex.svp and vertex.svp != old_svp:
|
||||
recorder.record(
|
||||
step, EventType.SVP_COMPUTED, node_name,
|
||||
digest_hex=vertex.message_digest.hex()[:12],
|
||||
svp=list(vertex.svp),
|
||||
round_number=vertex.round,
|
||||
)
|
||||
|
||||
|
||||
def record_leader_election(vertex: Vertex, graph: LamportGraph,
|
||||
difficulty: DifficultyOracle,
|
||||
connectivity_k: int,
|
||||
leader_dict: dict,
|
||||
recorder: EventRecorder, step: int,
|
||||
node_name: str) -> None:
|
||||
"""Wrapper around compute_virtual_leader_election that records results."""
|
||||
old_keys = set(leader_dict.keys())
|
||||
|
||||
compute_virtual_leader_election(
|
||||
vertex, graph, difficulty, connectivity_k, leader_dict
|
||||
)
|
||||
|
||||
new_keys = set(leader_dict.keys()) - old_keys
|
||||
for rn in new_keys:
|
||||
entries = leader_dict[rn] # list[tuple[int, Message]]
|
||||
for deciding_round, leader_msg in entries:
|
||||
recorder.record(
|
||||
step, EventType.LEADER_ELECTED, node_name,
|
||||
round_number=rn,
|
||||
deciding_round=deciding_round,
|
||||
leader_digest_hex=leader_msg.compute_digest().hex()[:12],
|
||||
)
|
||||
Loading…
Reference in a new issue