diff --git a/src/crisis/demo.py b/src/crisis/demo.py index 66f7da1..236a234 100644 --- a/src/crisis/demo.py +++ b/src/crisis/demo.py @@ -78,8 +78,8 @@ class Simulation: """ def __init__(self, num_honest: int = 3, num_byzantine: int = 0, - pow_zeros: int = 0, difficulty: int = 2, - connectivity_k: int = 1, seed: int = 42): + pow_zeros: int = 2, difficulty: int = 1, + connectivity_k: int = 0, seed: int = 42): self.difficulty_oracle = DifficultyOracle(constant_difficulty=difficulty) self.connectivity_k = connectivity_k self.weight_system = ProofOfWorkWeight(min_leading_zeros=pow_zeros) @@ -154,6 +154,7 @@ class Simulation: for node in self.nodes: compute_rounds(node.graph, self.difficulty_oracle, self.connectivity_k) + # Compute SVP for all last vertices for vertex in node.graph.all_vertices(): if vertex.is_last: compute_safe_voting_pattern( @@ -161,13 +162,19 @@ class Simulation: self.connectivity_k ) + # Compute leader election in round order (lower rounds first). + # This ensures that when a higher-round vertex reads votes from + # its voting set members, those members have already computed + # their own votes. leader_dict: dict[int, list[tuple[int, Message]]] = {} - for vertex in node.graph.all_vertices(): - if vertex.svp: - compute_virtual_leader_election( - vertex, node.graph, self.difficulty_oracle, - self.connectivity_k, leader_dict - ) + svp_vertices = [v for v in node.graph.all_vertices() if v.svp] + svp_vertices.sort(key=lambda v: v.round if v.round is not None else 0) + + for vertex in svp_vertices: + compute_virtual_leader_election( + vertex, node.graph, self.difficulty_oracle, + self.connectivity_k, leader_dict + ) for round_num, entries in leader_dict.items(): for deciding_round, leader_msg in entries: @@ -324,10 +331,10 @@ def main(): help="Number of byzantine nodes (default: 0)") parser.add_argument("--steps", type=int, default=10, help="Number of simulation steps (default: 10)") - parser.add_argument("--pow-zeros", type=int, default=0, - help="Min PoW leading zeros (default: 0 = no PoW)") - parser.add_argument("--difficulty", type=int, default=2, - help="Difficulty oracle constant (default: 2)") + parser.add_argument("--pow-zeros", type=int, default=2, + help="Min PoW leading zeros (default: 2)") + parser.add_argument("--difficulty", type=int, default=1, + help="Difficulty oracle constant (default: 1)") parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility (default: 42)") diff --git a/src/crisis/rounds.py b/src/crisis/rounds.py index 227991a..df4d97b 100644 --- a/src/crisis/rounds.py +++ b/src/crisis/rounds.py @@ -145,15 +145,22 @@ def _is_k_reachable(v_from: Vertex, v_to: Vertex, expensive and not really necessary in our setting... all we need is some insurance that information flows through enough real world processes." We use total path weight as a simpler proxy. + + Special case: k <= 0 degenerates to simple reachability (is v_from in + the past of v_to?). This is the appropriate setting for small demos + where weight accumulation is limited. """ - if v_from not in graph.past(v_to): + past_of_to = graph.past(v_to) + + if v_from not in past_of_to: return False - # Compute the weight of all vertices in the path from v_from to v_to - # (all vertices that are in both the future of v_from and the past of v_to) - past_of_to = graph.past(v_to) - future_of_from = graph.future(v_from) + # k <= 0: simple reachability suffices + if k <= 0: + return True + # k > 0: check that enough weight exists on the path + future_of_from = graph.future(v_from) path_vertices = past_of_to & future_of_from total_weight = graph.set_weight(path_vertices) diff --git a/src/crisis/voting.py b/src/crisis/voting.py index 02e08cd..992b0ca 100644 --- a/src/crisis/voting.py +++ b/src/crisis/voting.py @@ -71,6 +71,12 @@ def build_knowledge_graph(vertex: Vertex, round_s: int, Collects all round-s vertices in v's past, groups them by id, and builds the quotient graph. + + Each node represents a virtual process (an id). An edge from id to id' + means some round-s vertex with that id acknowledges a round-s vertex + with id'. Isolated nodes (no same-round edges) are still included -- + they represent processes whose messages are known but who didn't + cross-reference other round-s processes. """ kg = KnowledgeGraph() past = graph.past(vertex) @@ -78,7 +84,7 @@ def build_knowledge_graph(vertex: Vertex, round_s: int, # Find all round-s vertices in v's past round_s_vertices = [v for v in past if v.round == round_s] - # Group by id and compute edges + # Group by id and compute weights for v_s in round_s_vertices: vid = v_s.id if vid not in kg.edges: @@ -90,11 +96,18 @@ def build_knowledge_graph(vertex: Vertex, round_s: int, kg.weights[vid], graph.vertex_weight(v_s) ) - # Add edges based on what this vertex references + # Add edges: if this vertex references another round-s vertex for cause in graph.direct_causes(v_s): if cause.round is not None and cause.round == round_s: kg.edges[vid].add(cause.id) + # Also check if any round-s vertex references this one (reverse) + for other in round_s_vertices: + if other.id != vid: + for cause in graph.direct_causes(other): + if cause.id == vid and cause.round == round_s: + kg.edges[other.id].add(vid) + return kg @@ -112,41 +125,61 @@ def select_quorum(knowledge_graph: KnowledgeGraph, n: int = 3) -> set[bytes]: The quorum selector serves as a filter to reduce byzantine noise that might appear in the voting process. By restricting to a heavily connected component, faulty behavior based on graph partition is reduced. + + Special case: when all processes are isolated (no edges between them, + typical for round 0 genesis vertices), we treat all of them as one + component. This is the bootstrapping case -- we know about all these + processes through the vertex that triggered this query. """ if not knowledge_graph.edges: return set() - # Find weakly connected components using simple BFS all_ids = set(knowledge_graph.edges.keys()) - visited: set[bytes] = set() - components: list[set[bytes]] = [] - for start_id in all_ids: - if start_id in visited: - continue - component: set[bytes] = set() - queue = [start_id] - while queue: - current = queue.pop(0) - if current in visited: + # Check if all edges are empty (all processes are isolated). + # This happens at round 0: genesis vertices don't reference each other. + # In that case, treat all processes as a single component -- the + # triggering vertex has all of them in its past, which is sufficient + # evidence of connectivity. + all_isolated = all( + len(neighbors) == 0 + for neighbors in knowledge_graph.edges.values() + ) + + if all_isolated: + # All ids form one virtual component + best_component = all_ids + else: + # Find weakly connected components using BFS + visited: set[bytes] = set() + components: list[set[bytes]] = [] + + for start_id in all_ids: + if start_id in visited: continue - visited.add(current) - component.add(current) - # Follow edges in both directions (weakly connected) - for neighbor in knowledge_graph.edges.get(current, set()): - if neighbor not in visited and neighbor in all_ids: - queue.append(neighbor) - # Reverse edges - for other_id, neighbors in knowledge_graph.edges.items(): - if current in neighbors and other_id not in visited: - queue.append(other_id) - components.append(component) + component: set[bytes] = set() + queue = [start_id] + while queue: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + component.add(current) + # Follow edges in both directions (weakly connected) + for neighbor in knowledge_graph.edges.get(current, set()): + if neighbor not in visited and neighbor in all_ids: + queue.append(neighbor) + # Reverse edges + for other_id, neighbors in knowledge_graph.edges.items(): + if current in neighbors and other_id not in visited: + queue.append(other_id) + components.append(component) - # Choose the component with highest total weight - def component_weight(comp: set[bytes]) -> int: - return sum(knowledge_graph.weights.get(pid, 0) for pid in comp) + # Choose the component with highest total weight + def component_weight(comp: set[bytes]) -> int: + return sum(knowledge_graph.weights.get(pid, 0) for pid in comp) - best_component = max(components, key=component_weight) + best_component = max(components, key=component_weight) # Take the n heaviest processes from this component sorted_by_weight = sorted( @@ -232,7 +265,10 @@ def compute_safe_voting_pattern(vertex: Vertex, graph: LamportGraph, r = vertex.round - # Check each previous round for safe voting pattern membership + # Check each previous round for safe voting pattern membership. + # The SVP contains rounds strictly LESS than v.round (Algorithm 6). + # It does NOT include v.round itself -- the current round's peers + # are spacelike and cannot be part of v's voting set. for s in range(r): d_s = difficulty.difficulty(s) @@ -263,10 +299,6 @@ def compute_safe_voting_pattern(vertex: Vertex, graph: LamportGraph, if svps_agree: vertex.svp.append(s) - # svp is a nested sequence: add current round - if vertex.svp: - vertex.svp.append(r) - # --------------------------------------------------------------------------- # Initial Vote Function (Definition 5.16, Example 4) @@ -301,79 +333,104 @@ def compute_virtual_leader_election(vertex: Vertex, graph: LamportGraph, """Algorithm 7: compute votes for all rounds in v's safe voting pattern. This is the core virtual BA* protocol. For each element t in v.svp, - the vertex computes a vote v.vote(t) = (l, b) based on the stage δ - (the position of that round in the svp). + the vertex computes a vote v.vote(t) = (l, b) based on the stage δ. - Stage types (determined by δ = d_{v.svp}(s, t)): - δ = 0: Initial leader proposal - δ = 1: Leader presorting (gradecast step) - δ = 2: BBA* initialization (gradecast step) - δ ≥ 3: Binary agreement rounds - δ mod 3 = 0: Coin fixed to 0 - δ mod 3 = 1: Coin fixed to 1 - δ mod 3 = 2: Genuine coin flip + Key details from the paper (pages 19-20): + - s = max(v.svp): the highest round in the safe voting pattern + - S = S_v(s, k): the voting set is computed ONCE for round s + - δ = d_{v.svp}(s, t): the SVP distance from s to t + - δ = 0 at t = s (the newest round): Initial leader proposal + - δ = 1: Leader presorting (Feldman-Micali gradecast step 1) + - δ = 2: BBA* initialization (gradecast step 2) + - δ ≥ 3: Binary byzantine agreement rounds + - δ mod 3 = 0: Coin fixed to 0 + - δ mod 3 = 1: Coin fixed to 1 + - δ mod 3 = 2: Genuine coin flip (using hash LSB) - The paper notes: "every step is entirely virtual and no votes are - actually sent to other real world processes." + The votes of S members at round t are read from x.vote(t), which those + members computed in their own execution of Algorithm 7. Processing + goes from δ=0 (no dependency) to higher δ (reads lower δ votes from + S members), so votes cascade correctly. """ if not vertex.svp: return - s = max(vertex.svp) if vertex.svp else None - if s is None: + s = max(vertex.svp) + last_idx = len(vertex.svp) - 1 + + # Line 6: S ← v's safe voting pattern S_v(s, k) + # The voting set is computed ONCE for the max round s + voting_set_s = voting_set(vertex, s, connectivity_k, graph) + n = graph.set_weight(voting_set_s) + d_s = difficulty.difficulty(s) + + if n == 0: return - for t_idx, t in enumerate(vertex.svp): - delta = t_idx # stage = position in svp - _compute_vote_for_stage(vertex, t, delta, s, graph, difficulty, - connectivity_k, leader_stream) + # Process in δ order: start at δ=0 (t=s), then δ=1, δ=2, ... + # This means iterating the svp in REVERSE order (newest first) + for t_idx_reversed in range(len(vertex.svp)): + t_idx = last_idx - t_idx_reversed + t = vertex.svp[t_idx] + delta = t_idx_reversed # δ = distance from s (the last element) + + _compute_vote_for_stage(vertex, t, delta, s, voting_set_s, n, d_s, + graph, leader_stream) def _compute_vote_for_stage(vertex: Vertex, t: int, delta: int, s: int, - graph: LamportGraph, difficulty: DifficultyOracle, - connectivity_k: int, + vs: set[Vertex], n: int, d_s: int, + graph: LamportGraph, leader_stream: dict[int, list[tuple[int, Message]]]) -> None: """Compute vertex's vote for a specific stage of the virtual leader election. Implements the branching logic of Algorithm 7 (pages 19-20 of the paper). - """ - d_s = difficulty.difficulty(s) - vs = voting_set(vertex, t, connectivity_k, graph) - n = graph.set_weight(vs) + Args: + vertex: The vertex computing its vote. + t: The round number being voted on. + delta: The SVP distance (stage type). + s: The max round of the SVP. + vs: The voting set S (round-s last vertices). + n: Total weight of S. + d_s: Difficulty for round s. + graph: The Lamport graph. + leader_stream: Dict to update with decided leaders. + """ NON_LEADER = None # ∅ in the paper if delta == 0: # Stage 0: Initial leader proposal + # v.vote(t) ← (INITIAL_VOTE(S), ⊥) l = initial_vote(vs, graph) - vertex.vote[t] = Vote(message=l, binary=None) # (INITIAL_VOTE(S), ⊥) + vertex.vote[t] = Vote(message=l, binary=None) elif delta == 1: - # Stage 1: Leader presorting - # Find message with highest round-t voting weight in S - l = _highest_weight_message(vs, graph) + # Stage 1: Leader presorting (gradecast step 1) + # Read S members' votes at round t (their δ=0 votes, i.e. vote(s)) + # "l ← message with highest round t voting weight in S" + l = _leader_with_most_weight(vs, t, graph) if l is not None: - # Check if l has super majority weight - l_weight = _vote_weight_for(vs, t, l, None, graph) # votes for (l, ⊥) + l_weight = _vote_weight_for(vs, t, l, None, graph) if l_weight > n - d_s: - vertex.vote[t] = Vote(message=l, binary=None) # (l, ⊥) + vertex.vote[t] = Vote(message=l, binary=None) else: - vertex.vote[t] = Vote(message=NON_LEADER, binary=None) # (∅, ⊥) + vertex.vote[t] = Vote(message=NON_LEADER, binary=None) else: vertex.vote[t] = Vote(message=NON_LEADER, binary=None) elif delta == 2: - # Stage 2: BBA* initialization (gradecast) - l = _highest_weight_message(vs, graph) + # Stage 2: BBA* initialization (gradecast step 2) + l = _leader_with_most_weight(vs, t, graph) - if l is not None: + if l is not None and l is not NON_LEADER: l_weight_undecided = _vote_weight_for(vs, t, l, None, graph) if l_weight_undecided > n - d_s: vertex.vote[t] = Vote(message=l, binary=0) else: - l_weight_1 = _vote_weight_for(vs, t, l, 1, graph) - if l_weight_1 > d_s: + l_weight_any = _vote_weight_for_message(vs, t, l, graph) + if l_weight_any > d_s: vertex.vote[t] = Vote(message=l, binary=1) else: vertex.vote[t] = Vote(message=NON_LEADER, binary=1) @@ -383,18 +440,15 @@ def _compute_vote_for_stage(vertex: Vertex, t: int, delta: int, s: int, else: # Stage δ ≥ 3: Binary agreement (BBA*) coin_stage = delta % 3 - l = _highest_weight_message(vs, graph) + l = _leader_with_most_weight(vs, t, graph) if coin_stage == 0: - # Coin fixed to 0 _bba_coin_fixed(vertex, t, vs, l, n, d_s, graph, leader_stream, s, fixed_value=0) elif coin_stage == 1: - # Coin fixed to 1 _bba_coin_fixed(vertex, t, vs, l, n, d_s, graph, leader_stream, s, fixed_value=1) else: - # Genuine coin flip (coin_stage == 2) _bba_genuine_coin(vertex, t, vs, l, n, d_s, graph) @@ -460,10 +514,50 @@ def _highest_weight_message(vs: set[Vertex], graph: LamportGraph) -> Optional[Me return best.m +def _leader_with_most_weight(vs: set[Vertex], round_t: int, + graph: LamportGraph) -> Optional[Message]: + """Find the leader message l that received the most voting weight at round t. + + Looks at x.vote(t) for all x ∈ vs, groups by the leader message, + and returns the one with the highest total weight. + Falls back to the highest-weight vertex's message if no votes exist. + """ + # Tally weight per leader message + leader_weights: dict[bytes, tuple[int, Message]] = {} + has_any_vote = False + + for v in vs: + vote = v.vote.get(round_t) + if vote is None: + continue + has_any_vote = True + if vote.message is not None: + key = vote.message.compute_digest() + w = graph.vertex_weight(v) + if key in leader_weights: + old_w, msg = leader_weights[key] + leader_weights[key] = (graph.weight_system.weight_sum(old_w, w), msg) + else: + leader_weights[key] = (w, vote.message) + + if leader_weights: + _, best_msg = max(leader_weights.values(), key=lambda x: x[0]) + return best_msg + + # No votes yet: fall back to highest weight vertex's message + if not has_any_vote: + return _highest_weight_message(vs, graph) + + return None + + def _vote_weight_for(vs: set[Vertex], round_t: int, target_msg: Optional[Message], target_binary: Optional[int], graph: LamportGraph) -> int: - """Compute total voting weight for a specific vote (l, b) in a voting set.""" + """Compute total voting weight for a specific vote (l, b) in a voting set. + + w(S_v(s,k), t, (l,b)) := w({x ∈ S | x.vote(t) = (l, b)}) + """ total = 0 for v in vs: vote = v.vote.get(round_t) @@ -478,6 +572,25 @@ def _vote_weight_for(vs: set[Vertex], round_t: int, return total +def _vote_weight_for_message(vs: set[Vertex], round_t: int, + target_msg: Message, + graph: LamportGraph) -> int: + """Compute total weight for any vote with a given leader message, any binary. + + Used in the gradecast stage to check if l received any significant weight, + regardless of the binary part. + """ + total = 0 + target_digest = target_msg.compute_digest() + for v in vs: + vote = v.vote.get(round_t) + if vote is None or vote.message is None: + continue + if vote.message.compute_digest() == target_digest: + total = graph.weight_system.weight_sum(total, graph.vertex_weight(v)) + return total + + def _vote_weight_for_binary(vs: set[Vertex], round_t: int, target_binary: int, graph: LamportGraph) -> int: