From c68d016d4c922a38e3b16475f749005785bb2656 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Sat, 11 Jan 2025 17:09:55 -0800 Subject: [PATCH 01/19] Eliminate double inverts in pyrtl.optimize --- pyrtl/passes.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 0b684303..33e4d24e 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -52,11 +52,61 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): constant_propagation(block, True) _remove_unlistened_nets(block) common_subexp_elimination(block) + _remove_double_inverts(block, skip_sanity_check) if (not skip_sanity_check) or _get_debug_mode(): block.sanity_check() return block +def _remove_double_inverts(block, skip_sanity_check=False): + """ Removes all double invert nets from the block. """ + + # checks if the wirevector is used at a LogicNet other than used_at_net + def is_wirevector_used_elsewhere(wire, used_at_nets): + for net in block.logic: + if net not in used_at_nets: + if wire.name in [x.name for x in net.args] \ + or wire.name in [x.name for x in net.dests]: + return True + return False + + new_logic = set() + net_exclude_set = set() # removed nets + wire_removal_set = set() + for net1 in block.logic: + for net2 in block.logic: + # Conditions need to be satisfied for the nets to be removed: + # 1. Both nets should be invert nets + # 2. Nets should not be in net_exclude_set (nets that are already removed) + # 3. The destination of net1 should be the argument of net2 + # (so we know the nets are connected) + # 4. The destination of net1 should not be used elsewhere + # (because we can't remove a wire that is used in another net) + if net1.op == '~' and net2.op == '~' \ + and net1 not in net_exclude_set and net2 not in net_exclude_set \ + and net1.dests[0].name == net2.args[0].name \ + and not is_wirevector_used_elsewhere(net1.dests[0], (net1, net2)): + new_logic.add(LogicNet('w', None, args=net1.args, dests=net2.dests)) + net_exclude_set.add(net1) + net_exclude_set.add(net2) + wire_removal_set.add(net1.dests[0]) + break + + for net in block.logic: + if net not in net_exclude_set: + new_logic.add(net) + + block.logic = new_logic + for dead_wirevector in wire_removal_set: + block.remove_wirevector(dead_wirevector) + + if (not skip_sanity_check) or _get_debug_mode(): + block.sanity_check() + + # clean up wire nodes + _remove_wire_nets(block, skip_sanity_check) + + class _ProducerList(object): """ Maps from wire to its immediate producer and finds ultimate producers. """ def __init__(self): From 4cc9e6a083f63a01d01956be91ca6dcfae62c5e3 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Fri, 17 Jan 2025 12:21:56 -0800 Subject: [PATCH 02/19] Optimize _remove_double_inverts --- pyrtl/passes.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 33e4d24e..5ec8c0b4 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -61,7 +61,7 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): def _remove_double_inverts(block, skip_sanity_check=False): """ Removes all double invert nets from the block. """ - # checks if the wirevector is used at a LogicNet other than used_at_net + # checks if the wirevector is used at a LogicNet other than used_at_nets def is_wirevector_used_elsewhere(wire, used_at_nets): for net in block.logic: if net not in used_at_nets: @@ -73,24 +73,22 @@ def is_wirevector_used_elsewhere(wire, used_at_nets): new_logic = set() net_exclude_set = set() # removed nets wire_removal_set = set() - for net1 in block.logic: - for net2 in block.logic: - # Conditions need to be satisfied for the nets to be removed: - # 1. Both nets should be invert nets - # 2. Nets should not be in net_exclude_set (nets that are already removed) - # 3. The destination of net1 should be the argument of net2 - # (so we know the nets are connected) - # 4. The destination of net1 should not be used elsewhere - # (because we can't remove a wire that is used in another net) - if net1.op == '~' and net2.op == '~' \ - and net1 not in net_exclude_set and net2 not in net_exclude_set \ - and net1.dests[0].name == net2.args[0].name \ - and not is_wirevector_used_elsewhere(net1.dests[0], (net1, net2)): - new_logic.add(LogicNet('w', None, args=net1.args, dests=net2.dests)) - net_exclude_set.add(net1) - net_exclude_set.add(net2) - wire_removal_set.add(net1.dests[0]) - break + # Dictionary, key is the destination wire of the invert net, value is the invert net + invert_destination_wires = {} + for net in block.logic: + if net.op == "~": + invert_destination_wires[net.dests[0].name] = net + for net in invert_destination_wires.values(): + # If the argument of the net is in invert_destination_wires, then it is a double invert + # If the net is in net_exclude_set, then it was already removed, so we do not process it + if net.args[0].name in invert_destination_wires and net not in net_exclude_set: + previous_net = invert_destination_wires[net.args[0].name] + if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \ + and previous_net not in net_exclude_set: + new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests)) + wire_removal_set.add(net.args[0]) + net_exclude_set.add(net) + net_exclude_set.add(previous_net) for net in block.logic: if net not in net_exclude_set: From 378f86a684d14ea3c1a30c0e5af75afd1179088f Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Fri, 17 Jan 2025 14:17:21 -0800 Subject: [PATCH 03/19] Fix bug that sometimes double inverts are not removed --- pyrtl/passes.py | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 5ec8c0b4..c63d1e6d 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -78,17 +78,42 @@ def is_wirevector_used_elsewhere(wire, used_at_nets): for net in block.logic: if net.op == "~": invert_destination_wires[net.dests[0].name] = net - for net in invert_destination_wires.values(): - # If the argument of the net is in invert_destination_wires, then it is a double invert - # If the net is in net_exclude_set, then it was already removed, so we do not process it - if net.args[0].name in invert_destination_wires and net not in net_exclude_set: - previous_net = invert_destination_wires[net.args[0].name] - if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \ - and previous_net not in net_exclude_set: - new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests)) - wire_removal_set.add(net.args[0]) - net_exclude_set.add(net) - net_exclude_set.add(previous_net) + # If double invert nets are removed randomly, this may leave some double inverts behind. + # Example: ~(~(~(~a))) + # If we remove the middle two inverts first, we will end up with ~((~a)). These remaining + # double inverts won't get removed because they aren't directly connected. + # To avoid this, we remove double inverts in a chain sequentially from start to end. + # For example, we first remove the two outer inverts from ~(~(~(~a))) to get ~(~a), + # and then remove the remaining two inner inverts. To do this, we need iterate through + # the invert_destination_wires dictionary multiple times, hence the outer while loop. + repeat = True + while repeat: + repeat = False + removed_nets = set() + for net in invert_destination_wires.values(): + # If the argument of the net is in invert_destination_wires, then it is a double invert + # If the net is in net_exclude_set, then it was already removed, so we do not process it + if net.args[0].name in invert_destination_wires and net not in net_exclude_set: + previous_net = invert_destination_wires[net.args[0].name] + if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \ + and previous_net not in net_exclude_set: + # If previous_net is in invert_destination_wires, we have a chain of + # 3 or more double inverts. To make sure we remove double inverts + # in these chains sequentially, we only remove the double invert + # we found if the invert net whose destination is previous_net + # was not removed yet. If it was not yet removed, the for loop + # needs to run again, so we set repeat to True. + if previous_net.args[0].name in invert_destination_wires: + repeat = True + else: + new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests)) + wire_removal_set.add(net.args[0]) + removed_nets.add(net) + removed_nets.add(previous_net) + # remove removed_nets from invert_destination_wires to optimize the for loop + for net in removed_nets: + del invert_destination_wires[net.dests[0].name] + net_exclude_set.update(removed_nets) for net in block.logic: if net not in net_exclude_set: From 6e872b3760a3b76b815e419af53674353d69ba3b Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Fri, 17 Jan 2025 14:19:01 -0800 Subject: [PATCH 04/19] Add tests for removing double inverts --- tests/test_passes.py | 81 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/test_passes.py b/tests/test_passes.py index 332ad873..f5dcbcd5 100644 --- a/tests/test_passes.py +++ b/tests/test_passes.py @@ -316,6 +316,87 @@ def test_slice_net_removal_4(self): self.num_net_of_type('s', 1, block) self.num_net_of_type('w', 2, block) + def test_remove_double_inverts_1_invert(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~inwire + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + + def test_remove_double_inverts_3_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~inwire)) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + + def test_remove_double_inverts_5_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~(~(~inwire)))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + + def test_remove_double_inverts_2_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~inwire) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(1, block) + self.assert_num_wires(2, block) + + def test_remove_double_inverts_4_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~(~inwire))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(1, block) + self.assert_num_wires(2, block) + + def test_remove_double_inverts_6_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~(~(~(~inwire))))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(1, block) + self.assert_num_wires(2, block) + + def test_dont_remove_double_inverts_another_user(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire2 = pyrtl.Output(bitwidth=1) + tempwire = pyrtl.WireVector() + tempwire <<= ~inwire + outwire <<= ~tempwire + outwire2 <<= tempwire + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(4, block) + self.assert_num_wires(5, block) + + def test_multiple_double_invert_chains(self): + # _remove_double_inverts removes double inverts by chains, + # so it is useful to make sure it can remove + # double inverts from multiple chains + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire2 = pyrtl.Output(bitwidth=1) + outwire <<= ~(~inwire) + outwire2 <<= ~(~(~(~(inwire)))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + class TestConstFolding(NetWireNumTestCases): def setUp(self): From 8f3246fffed613920707242f889519afa70f0d4e Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Sat, 22 Feb 2025 18:05:39 -0800 Subject: [PATCH 05/19] Remove unnecessary outer while loop --- pyrtl/passes.py | 52 +++++++++++++------------------------------------ 1 file changed, 14 insertions(+), 38 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index c63d1e6d..0970af17 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -63,7 +63,7 @@ def _remove_double_inverts(block, skip_sanity_check=False): # checks if the wirevector is used at a LogicNet other than used_at_nets def is_wirevector_used_elsewhere(wire, used_at_nets): - for net in block.logic: + for net in block: if net not in used_at_nets: if wire.name in [x.name for x in net.args] \ or wire.name in [x.name for x in net.dests]: @@ -75,45 +75,21 @@ def is_wirevector_used_elsewhere(wire, used_at_nets): wire_removal_set = set() # Dictionary, key is the destination wire of the invert net, value is the invert net invert_destination_wires = {} - for net in block.logic: + for net in block: if net.op == "~": invert_destination_wires[net.dests[0].name] = net - # If double invert nets are removed randomly, this may leave some double inverts behind. - # Example: ~(~(~(~a))) - # If we remove the middle two inverts first, we will end up with ~((~a)). These remaining - # double inverts won't get removed because they aren't directly connected. - # To avoid this, we remove double inverts in a chain sequentially from start to end. - # For example, we first remove the two outer inverts from ~(~(~(~a))) to get ~(~a), - # and then remove the remaining two inner inverts. To do this, we need iterate through - # the invert_destination_wires dictionary multiple times, hence the outer while loop. - repeat = True - while repeat: - repeat = False - removed_nets = set() - for net in invert_destination_wires.values(): - # If the argument of the net is in invert_destination_wires, then it is a double invert - # If the net is in net_exclude_set, then it was already removed, so we do not process it - if net.args[0].name in invert_destination_wires and net not in net_exclude_set: - previous_net = invert_destination_wires[net.args[0].name] - if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \ - and previous_net not in net_exclude_set: - # If previous_net is in invert_destination_wires, we have a chain of - # 3 or more double inverts. To make sure we remove double inverts - # in these chains sequentially, we only remove the double invert - # we found if the invert net whose destination is previous_net - # was not removed yet. If it was not yet removed, the for loop - # needs to run again, so we set repeat to True. - if previous_net.args[0].name in invert_destination_wires: - repeat = True - else: - new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests)) - wire_removal_set.add(net.args[0]) - removed_nets.add(net) - removed_nets.add(previous_net) - # remove removed_nets from invert_destination_wires to optimize the for loop - for net in removed_nets: - del invert_destination_wires[net.dests[0].name] - net_exclude_set.update(removed_nets) + + for net in invert_destination_wires.values(): + # If the argument of the net is in invert_destination_wires, then it is a double invert + # If the net is in net_exclude_set, then it was already removed, so we do not process it + if net.args[0].name in invert_destination_wires and net not in net_exclude_set: + previous_net = invert_destination_wires[net.args[0].name] + if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \ + and previous_net not in net_exclude_set: + new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests)) + wire_removal_set.add(net.args[0]) + net_exclude_set.add(net) + net_exclude_set.add(previous_net) for net in block.logic: if net not in net_exclude_set: From 54d0c2ce359014e040ebf1ae8e91f36bd89ddeec Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Tue, 4 Mar 2025 12:17:13 -0800 Subject: [PATCH 06/19] Work on optimizing double invert removal --- pyrtl/passes.py | 96 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 25 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 0970af17..5d0f3790 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -61,35 +61,81 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): def _remove_double_inverts(block, skip_sanity_check=False): """ Removes all double invert nets from the block. """ - # checks if the wirevector is used at a LogicNet other than used_at_nets - def is_wirevector_used_elsewhere(wire, used_at_nets): - for net in block: - if net not in used_at_nets: - if wire.name in [x.name for x in net.args] \ - or wire.name in [x.name for x in net.dests]: - return True - return False + # wire_creator maps from WireVector to the LogicNet that defines its value. + # wire_users maps from WireVector to a list of LogicNets that use its value. + wire_creator, wire_users = block.net_connections() + + # Build a list of inverter chains. Each inverter chain is a list of WireVectors, + # from source to destination. + inverter_wirenet_chains = [] + for current_dest, current_creator in wire_creator.items(): + if current_creator.op != "~": + # Skip non-inverters. + continue + + # The current inverter connects current_arg (a WireVector) to current_dest (also + # a WireVector). + current_arg = current_creator.args[0] + # current_users is the number of LogicNets that use current_dest. + current_users = len(wire_users[current_dest]) + + # Add the current inverter to the end of this inverter chain. + append_to = None + # Add the current inverter to the beginning of this inverter chain. + prepend_to = None + next_inverter_chains = [] + for inverter_wirenet_chain in inverter_wirenet_chains: + chain_arg = inverter_wirenet_chain[0] + chain_dest = inverter_wirenet_chain[-1] + chain_users = len(wire_users[chain_dest]) + + if chain_dest is current_arg and chain_users == 1: + # This chain's only destination is the current inverter. Append the + # current inverter to the chain. + append_to = inverter_wirenet_chain + elif chain_arg is current_dest and current_users <= 1: + # This chain's only argument is the current inverter. Add the current + # inverter to the beginning of the chain. + prepend_to = inverter_wirenet_chain + else: + next_inverter_chains.append(inverter_wirenet_chain) + + #print("current inverter: ", current_arg, "->", current_dest) + if append_to and prepend_to: + # The current inverter joins two existing inverter chains. + next_inverter_chains.append(append_to + prepend_to) + #print(" joined", next_inverter_chains) + elif append_to: + # Add the current inverter after 'append_to'. + next_inverter_chains.append(append_to + [current_dest]) + #print(" appended", next_inverter_chains) + elif prepend_to: + # Add the current inverter before 'prepend_to'. + next_inverter_chains.append([current_arg] + prepend_to) + #print(" prepended", next_inverter_chains) + else: + next_inverter_chains.append([current_arg, current_dest]) + #print(" start new chain", next_inverter_chains) + + inverter_wirenet_chains = next_inverter_chains new_logic = set() net_exclude_set = set() # removed nets wire_removal_set = set() - # Dictionary, key is the destination wire of the invert net, value is the invert net - invert_destination_wires = {} - for net in block: - if net.op == "~": - invert_destination_wires[net.dests[0].name] = net - - for net in invert_destination_wires.values(): - # If the argument of the net is in invert_destination_wires, then it is a double invert - # If the net is in net_exclude_set, then it was already removed, so we do not process it - if net.args[0].name in invert_destination_wires and net not in net_exclude_set: - previous_net = invert_destination_wires[net.args[0].name] - if not is_wirevector_used_elsewhere(net.args[0], (net, previous_net)) \ - and previous_net not in net_exclude_set: - new_logic.add(LogicNet('w', None, args=previous_net.args, dests=net.dests)) - wire_removal_set.add(net.args[0]) - net_exclude_set.add(net) - net_exclude_set.add(previous_net) + for inverter_wirenet_chain in inverter_wirenet_chains: + if len(inverter_wirenet_chain) > 1: + if len(inverter_wirenet_chain) % 2 == 1: # even number of inverters in a chain + end_idx = len(inverter_wirenet_chain) - 1 + else: # odd number of inverters in a chain + end_idx = len(inverter_wirenet_chain) - 2 + wires_to_remove = inverter_wirenet_chain[1:end_idx] + new_logic.add(LogicNet('w', None, args=(inverter_wirenet_chain[0],), \ + dests=(inverter_wirenet_chain[end_idx],))) + inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} + inverters_to_remove.add(wire_creator[inverter_wirenet_chain[end_idx]]) + + wire_removal_set.update(wires_to_remove) + net_exclude_set.update(inverters_to_remove) for net in block.logic: if net not in net_exclude_set: From 2a352acd7abd4792e0665704655de60e7a0b4aa6 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Tue, 4 Mar 2025 14:55:50 -0800 Subject: [PATCH 07/19] Fix bug checking single-invert in remove double inverts --- pyrtl/passes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 5d0f3790..86022203 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -123,7 +123,7 @@ def _remove_double_inverts(block, skip_sanity_check=False): net_exclude_set = set() # removed nets wire_removal_set = set() for inverter_wirenet_chain in inverter_wirenet_chains: - if len(inverter_wirenet_chain) > 1: + if len(inverter_wirenet_chain) > 2: if len(inverter_wirenet_chain) % 2 == 1: # even number of inverters in a chain end_idx = len(inverter_wirenet_chain) - 1 else: # odd number of inverters in a chain @@ -134,8 +134,8 @@ def _remove_double_inverts(block, skip_sanity_check=False): inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} inverters_to_remove.add(wire_creator[inverter_wirenet_chain[end_idx]]) - wire_removal_set.update(wires_to_remove) - net_exclude_set.update(inverters_to_remove) + wire_removal_set.update(wires_to_remove) + net_exclude_set.update(inverters_to_remove) for net in block.logic: if net not in net_exclude_set: From 986ae629a393323a0268066b5376d46f88c5be10 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Tue, 4 Mar 2025 16:55:20 -0800 Subject: [PATCH 08/19] Clean up code --- pyrtl/passes.py | 48 +++++++++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 86022203..029d109b 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -67,7 +67,7 @@ def _remove_double_inverts(block, skip_sanity_check=False): # Build a list of inverter chains. Each inverter chain is a list of WireVectors, # from source to destination. - inverter_wirenet_chains = [] + inverter_chains = [] for current_dest, current_creator in wire_creator.items(): if current_creator.op != "~": # Skip non-inverters. @@ -84,61 +84,55 @@ def _remove_double_inverts(block, skip_sanity_check=False): # Add the current inverter to the beginning of this inverter chain. prepend_to = None next_inverter_chains = [] - for inverter_wirenet_chain in inverter_wirenet_chains: - chain_arg = inverter_wirenet_chain[0] - chain_dest = inverter_wirenet_chain[-1] + for inverter_chain in inverter_chains: + chain_arg = inverter_chain[0] + chain_dest = inverter_chain[-1] chain_users = len(wire_users[chain_dest]) if chain_dest is current_arg and chain_users == 1: # This chain's only destination is the current inverter. Append the # current inverter to the chain. - append_to = inverter_wirenet_chain + append_to = inverter_chain elif chain_arg is current_dest and current_users <= 1: # This chain's only argument is the current inverter. Add the current # inverter to the beginning of the chain. - prepend_to = inverter_wirenet_chain + prepend_to = inverter_chain else: - next_inverter_chains.append(inverter_wirenet_chain) + next_inverter_chains.append(inverter_chain) - #print("current inverter: ", current_arg, "->", current_dest) if append_to and prepend_to: # The current inverter joins two existing inverter chains. next_inverter_chains.append(append_to + prepend_to) - #print(" joined", next_inverter_chains) elif append_to: # Add the current inverter after 'append_to'. next_inverter_chains.append(append_to + [current_dest]) - #print(" appended", next_inverter_chains) elif prepend_to: # Add the current inverter before 'prepend_to'. next_inverter_chains.append([current_arg] + prepend_to) - #print(" prepended", next_inverter_chains) else: next_inverter_chains.append([current_arg, current_dest]) - #print(" start new chain", next_inverter_chains) - inverter_wirenet_chains = next_inverter_chains + inverter_chains = next_inverter_chains new_logic = set() - net_exclude_set = set() # removed nets + net_removal_set = set() wire_removal_set = set() - for inverter_wirenet_chain in inverter_wirenet_chains: - if len(inverter_wirenet_chain) > 2: - if len(inverter_wirenet_chain) % 2 == 1: # even number of inverters in a chain - end_idx = len(inverter_wirenet_chain) - 1 - else: # odd number of inverters in a chain - end_idx = len(inverter_wirenet_chain) - 2 - wires_to_remove = inverter_wirenet_chain[1:end_idx] - new_logic.add(LogicNet('w', None, args=(inverter_wirenet_chain[0],), \ - dests=(inverter_wirenet_chain[end_idx],))) + for inverter_chain in inverter_chains: + if len(inverter_chain) > 2: + if len(inverter_chain) % 2 == 1: # even number of inverters in a chain + end_idx = len(inverter_chain) - 1 + else: # odd number of inverters in a chain + end_idx = len(inverter_chain) - 2 + wires_to_remove = inverter_chain[1:end_idx] + new_logic.add(LogicNet('w', None, args=(inverter_chain[0],), + dests=(inverter_chain[end_idx],))) inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} - inverters_to_remove.add(wire_creator[inverter_wirenet_chain[end_idx]]) - + inverters_to_remove.add(wire_creator[inverter_chain[end_idx]]) wire_removal_set.update(wires_to_remove) - net_exclude_set.update(inverters_to_remove) + net_removal_set.update(inverters_to_remove) for net in block.logic: - if net not in net_exclude_set: + if net not in net_removal_set: new_logic.add(net) block.logic = new_logic From c430106919afbc370cbbbc34a9c93e71eaa8eec8 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 5 Mar 2025 13:50:40 -0800 Subject: [PATCH 09/19] Add comments and better code organization --- pyrtl/passes.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 029d109b..aca884a7 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -58,12 +58,8 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): return block -def _remove_double_inverts(block, skip_sanity_check=False): - """ Removes all double invert nets from the block. """ - - # wire_creator maps from WireVector to the LogicNet that defines its value. - # wire_users maps from WireVector to a list of LogicNets that use its value. - wire_creator, wire_users = block.net_connections() +def _get_inverter_chains(wire_creator, wire_users): + """Returns all inverter chains in the block""" # Build a list of inverter chains. Each inverter chain is a list of WireVectors, # from source to destination. @@ -93,11 +89,13 @@ def _remove_double_inverts(block, skip_sanity_check=False): # This chain's only destination is the current inverter. Append the # current inverter to the chain. append_to = inverter_chain - elif chain_arg is current_dest and current_users <= 1: + elif chain_arg is current_dest and current_users == 1: # This chain's only argument is the current inverter. Add the current # inverter to the beginning of the chain. prepend_to = inverter_chain else: + # The current inverter is not connected to the inverter chain, so we + # pass the inverter chain through to next_inverter_chains next_inverter_chains.append(inverter_chain) if append_to and prepend_to: @@ -110,14 +108,27 @@ def _remove_double_inverts(block, skip_sanity_check=False): # Add the current inverter before 'prepend_to'. next_inverter_chains.append([current_arg] + prepend_to) else: + # The current inverter is not connected to any inverter chain, so + # we start a new inverter chain with it next_inverter_chains.append([current_arg, current_dest]) inverter_chains = next_inverter_chains + return inverter_chains + + +def _remove_double_inverts(block, skip_sanity_check=False): + """ Removes all double invert nets from the block. """ + + # wire_creator maps from WireVector to the LogicNet that defines its value. + # wire_users maps from WireVector to a list of LogicNets that use its value. + wire_creator, wire_users = block.net_connections() new_logic = set() net_removal_set = set() wire_removal_set = set() - for inverter_chain in inverter_chains: + for inverter_chain in _get_inverter_chains(wire_creator, wire_users): + # if len(inverter_chain) = n, there are n-1 inverters in the chain + # only remove inverters if there are at least two inverters in a chain if len(inverter_chain) > 2: if len(inverter_chain) % 2 == 1: # even number of inverters in a chain end_idx = len(inverter_chain) - 1 From 99c351099326b92d607876bddf5c8534536ba5cd Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 5 Mar 2025 15:16:27 -0800 Subject: [PATCH 10/19] Remove double inverts do not create temporary wire nets working partially --- pyrtl/passes.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index aca884a7..2ba78e0b 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -126,6 +126,8 @@ def _remove_double_inverts(block, skip_sanity_check=False): new_logic = set() net_removal_set = set() wire_removal_set = set() + wire_destinations_replacement = dict() + wire_sources_replacement = dict() for inverter_chain in _get_inverter_chains(wire_creator, wire_users): # if len(inverter_chain) = n, there are n-1 inverters in the chain # only remove inverters if there are at least two inverters in a chain @@ -134,9 +136,18 @@ def _remove_double_inverts(block, skip_sanity_check=False): end_idx = len(inverter_chain) - 1 else: # odd number of inverters in a chain end_idx = len(inverter_chain) - 2 - wires_to_remove = inverter_chain[1:end_idx] - new_logic.add(LogicNet('w', None, args=(inverter_chain[0],), - dests=(inverter_chain[end_idx],))) + wires_to_remove = inverter_chain[1:end_idx+1] + #if inverter_chain[0] in wire_creator: + # chain_source = wire_creator[inverter_chain[0]] + # new_wire_dests = tuple(dest for dest in chain_source.dests if dest is not inverter_chain[0]) + # new_wire_dests += (inverter_chain[end_idx],) + # wire_destinations_replacement[chain_source] = new_wire_dests + for chain_dest in wire_users[inverter_chain[end_idx]]: + new_wire_sources = tuple(arg for arg in chain_dest.args if arg is not inverter_chain[end_idx]) + new_wire_sources += (inverter_chain[0],) + wire_sources_replacement[chain_dest] = new_wire_sources + #new_logic.add(LogicNet('w', None, args=(inverter_chain[0],), + # dests=(inverter_chain[end_idx],))) inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} inverters_to_remove.add(wire_creator[inverter_chain[end_idx]]) wire_removal_set.update(wires_to_remove) @@ -144,7 +155,16 @@ def _remove_double_inverts(block, skip_sanity_check=False): for net in block.logic: if net not in net_removal_set: - new_logic.add(net) + #new_logic.add(net) + #continue + if net in wire_sources_replacement: + new_logic.add(LogicNet(net.op, net.op_param, args=wire_sources_replacement[net], + dests=net.dests)) + #if net in wire_destinations_replacement: + # new_logic.add(LogicNet(net.op, net.op_param, args=net.args, + # dests=wire_destinations_replacement[net])) + else: + new_logic.add(net) block.logic = new_logic for dead_wirevector in wire_removal_set: @@ -154,7 +174,7 @@ def _remove_double_inverts(block, skip_sanity_check=False): block.sanity_check() # clean up wire nodes - _remove_wire_nets(block, skip_sanity_check) + #_remove_wire_nets(block, skip_sanity_check) class _ProducerList(object): From 626c78309a128824f04eea707751cba10f49cbcc Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 5 Mar 2025 16:33:51 -0800 Subject: [PATCH 11/19] all tests passing for optimizing wire nets in remove double inverts --- pyrtl/passes.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 2ba78e0b..23d3f877 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -126,8 +126,7 @@ def _remove_double_inverts(block, skip_sanity_check=False): new_logic = set() net_removal_set = set() wire_removal_set = set() - wire_destinations_replacement = dict() - wire_sources_replacement = dict() + wire_src_dict = _ProducerList() for inverter_chain in _get_inverter_chains(wire_creator, wire_users): # if len(inverter_chain) = n, there are n-1 inverters in the chain # only remove inverters if there are at least two inverters in a chain @@ -137,17 +136,9 @@ def _remove_double_inverts(block, skip_sanity_check=False): else: # odd number of inverters in a chain end_idx = len(inverter_chain) - 2 wires_to_remove = inverter_chain[1:end_idx+1] - #if inverter_chain[0] in wire_creator: - # chain_source = wire_creator[inverter_chain[0]] - # new_wire_dests = tuple(dest for dest in chain_source.dests if dest is not inverter_chain[0]) - # new_wire_dests += (inverter_chain[end_idx],) - # wire_destinations_replacement[chain_source] = new_wire_dests for chain_dest in wire_users[inverter_chain[end_idx]]: - new_wire_sources = tuple(arg for arg in chain_dest.args if arg is not inverter_chain[end_idx]) - new_wire_sources += (inverter_chain[0],) - wire_sources_replacement[chain_dest] = new_wire_sources - #new_logic.add(LogicNet('w', None, args=(inverter_chain[0],), - # dests=(inverter_chain[end_idx],))) + for arg in chain_dest.args: + wire_src_dict[arg] = inverter_chain[0] inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} inverters_to_remove.add(wire_creator[inverter_chain[end_idx]]) wire_removal_set.update(wires_to_remove) @@ -155,16 +146,8 @@ def _remove_double_inverts(block, skip_sanity_check=False): for net in block.logic: if net not in net_removal_set: - #new_logic.add(net) - #continue - if net in wire_sources_replacement: - new_logic.add(LogicNet(net.op, net.op_param, args=wire_sources_replacement[net], + new_logic.add(LogicNet(net.op, net.op_param, args=tuple(wire_src_dict.find_producer(x) for x in net.args), dests=net.dests)) - #if net in wire_destinations_replacement: - # new_logic.add(LogicNet(net.op, net.op_param, args=net.args, - # dests=wire_destinations_replacement[net])) - else: - new_logic.add(net) block.logic = new_logic for dead_wirevector in wire_removal_set: @@ -173,9 +156,6 @@ def _remove_double_inverts(block, skip_sanity_check=False): if (not skip_sanity_check) or _get_debug_mode(): block.sanity_check() - # clean up wire nodes - #_remove_wire_nets(block, skip_sanity_check) - class _ProducerList(object): """ Maps from wire to its immediate producer and finds ultimate producers. """ From f1e5b99e1091956e55188fcf5fcc6673377c6e15 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 5 Mar 2025 17:28:21 -0800 Subject: [PATCH 12/19] fix bug destinations with multiple args have their args screwed up --- pyrtl/passes.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 23d3f877..86378c19 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -136,9 +136,7 @@ def _remove_double_inverts(block, skip_sanity_check=False): else: # odd number of inverters in a chain end_idx = len(inverter_chain) - 2 wires_to_remove = inverter_chain[1:end_idx+1] - for chain_dest in wire_users[inverter_chain[end_idx]]: - for arg in chain_dest.args: - wire_src_dict[arg] = inverter_chain[0] + wire_src_dict[inverter_chain[end_idx]] = inverter_chain[0] inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} inverters_to_remove.add(wire_creator[inverter_chain[end_idx]]) wire_removal_set.update(wires_to_remove) From 6dcad0d8423d7c151c0231d2523283102422945c Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 5 Mar 2025 17:37:48 -0800 Subject: [PATCH 13/19] code cleanup and add comments --- pyrtl/passes.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 86378c19..466483c6 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -135,17 +135,21 @@ def _remove_double_inverts(block, skip_sanity_check=False): end_idx = len(inverter_chain) - 1 else: # odd number of inverters in a chain end_idx = len(inverter_chain) - 2 - wires_to_remove = inverter_chain[1:end_idx+1] - wire_src_dict[inverter_chain[end_idx]] = inverter_chain[0] + # remove wires used in the inverter chain + wires_to_remove = inverter_chain[1:end_idx + 1] + wire_removal_set.update(wires_to_remove) + # remove inverters used in the chain inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} inverters_to_remove.add(wire_creator[inverter_chain[end_idx]]) - wire_removal_set.update(wires_to_remove) net_removal_set.update(inverters_to_remove) + # map the end wire of the inverter chain to the beginning wire + wire_src_dict[inverter_chain[end_idx]] = inverter_chain[0] for net in block.logic: if net not in net_removal_set: - new_logic.add(LogicNet(net.op, net.op_param, args=tuple(wire_src_dict.find_producer(x) for x in net.args), - dests=net.dests)) + new_logic.add(LogicNet(net.op, net.op_param, + args=tuple(wire_src_dict.find_producer(x) for x in net.args), + dests=net.dests)) block.logic = new_logic for dead_wirevector in wire_removal_set: From d2a63034514f6e5b2a6b2287b5d7c4c205db4666 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Thu, 6 Mar 2025 11:24:27 -0800 Subject: [PATCH 14/19] remove unnecessary line --- pyrtl/passes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 466483c6..8adf2d7f 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -140,7 +140,6 @@ def _remove_double_inverts(block, skip_sanity_check=False): wire_removal_set.update(wires_to_remove) # remove inverters used in the chain inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} - inverters_to_remove.add(wire_creator[inverter_chain[end_idx]]) net_removal_set.update(inverters_to_remove) # map the end wire of the inverter chain to the beginning wire wire_src_dict[inverter_chain[end_idx]] = inverter_chain[0] From 7cadcb2218a72df5e41277bb5e7d4d86b0eda09c Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Thu, 6 Mar 2025 11:25:45 -0800 Subject: [PATCH 15/19] use start_idx instead of end_idx --- pyrtl/passes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 8adf2d7f..05566195 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -132,17 +132,17 @@ def _remove_double_inverts(block, skip_sanity_check=False): # only remove inverters if there are at least two inverters in a chain if len(inverter_chain) > 2: if len(inverter_chain) % 2 == 1: # even number of inverters in a chain - end_idx = len(inverter_chain) - 1 + start_idx = 1 else: # odd number of inverters in a chain - end_idx = len(inverter_chain) - 2 + start_idx = 2 # remove wires used in the inverter chain - wires_to_remove = inverter_chain[1:end_idx + 1] + wires_to_remove = inverter_chain[start_idx:] wire_removal_set.update(wires_to_remove) # remove inverters used in the chain inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} net_removal_set.update(inverters_to_remove) # map the end wire of the inverter chain to the beginning wire - wire_src_dict[inverter_chain[end_idx]] = inverter_chain[0] + wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx-1] for net in block.logic: if net not in net_removal_set: From 9a602454ab2778a2d3ad77eb419ca82c68c8e8cf Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Thu, 20 Mar 2025 14:10:12 -0700 Subject: [PATCH 16/19] Add comments and improve function name --- pyrtl/passes.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 05566195..70228299 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -52,7 +52,7 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): constant_propagation(block, True) _remove_unlistened_nets(block) common_subexp_elimination(block) - _remove_double_inverts(block, skip_sanity_check) + _remove_inverter_chains(block, skip_sanity_check) if (not skip_sanity_check) or _get_debug_mode(): block.sanity_check() return block @@ -116,8 +116,22 @@ def _get_inverter_chains(wire_creator, wire_users): return inverter_chains -def _remove_double_inverts(block, skip_sanity_check=False): - """ Removes all double invert nets from the block. """ +def _remove_inverter_chains(block, skip_sanity_check=False): + """ Removes all inverter chains from the block. + + An inverter chain means two or more inverters directly connected + to each other. Inverter chains are redundant and can be removed. + For example, A -~-> B -~-> C -w-> X can be reduced to A -w-> X. + + After optimization, a chain of an even number of inverters will + be reduced a direct connection, and a chain of an odd number of + inverters will be reduced to one inverter. + + If an inverter chain has intermediate users it won't be removed. + For example, the inverter chain in the following circuit won't be removed: + A -~-> B -~-> C -w-> X + B -w-> Y + """ # wire_creator maps from WireVector to the LogicNet that defines its value. # wire_users maps from WireVector to a list of LogicNets that use its value. @@ -126,7 +140,28 @@ def _remove_double_inverts(block, skip_sanity_check=False): new_logic = set() net_removal_set = set() wire_removal_set = set() + + # This ProducerList maps the end wire of an inverter chain to its beginning wire. + # We need this because when removing an inverter chain its end wire gets removed, + # so we need to replace the source of LogicNets using the end wire of the inverter + # chain with the chain's beginning wire. + # + # We need a ProducerList, rather than a simple dict, because if an inverter chain + # of more than two inverters has intermediate users, we may have to query the dict + # multiple times to get the replacement for the inverter chain's last wire. + # Consider the following circuit, for example: + # A -~-> B -~-> C -w-> X + # C -~-> D -~-> E -w-> Y + # This is the optimized version of the circuit: + # A -w-> X + # A -w-> Y + # The inverter chains found will be B-C and D-E (two separate chains will be + # found instead of B-C-D-E because C has an intermediate user). In the dict, + # C will be mapped to A and E will be maped to C. Hence, when finding the + # replacement of E, we have to first query the dict to get C, and then query + # the dict again on C to get A. wire_src_dict = _ProducerList() + for inverter_chain in _get_inverter_chains(wire_creator, wire_users): # if len(inverter_chain) = n, there are n-1 inverters in the chain # only remove inverters if there are at least two inverters in a chain @@ -142,8 +177,12 @@ def _remove_double_inverts(block, skip_sanity_check=False): inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} net_removal_set.update(inverters_to_remove) # map the end wire of the inverter chain to the beginning wire - wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx-1] + wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1] + # This loop recreates the LogicNet with inverter chains removed. It adds each + # block in the original LogicNet to the new LogicNet if it is not marked for + # removal, and replaces the source of the block if its source was the end wire + # of a removed inverter chain. for net in block.logic: if net not in net_removal_set: new_logic.add(LogicNet(net.op, net.op_param, From 7e7b69c1b895f085cde76bef48b10479b1aa3753 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Fri, 21 Mar 2025 13:32:45 -0700 Subject: [PATCH 17/19] Use full sentences in commits --- pyrtl/passes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 70228299..6c810435 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -163,20 +163,20 @@ def _remove_inverter_chains(block, skip_sanity_check=False): wire_src_dict = _ProducerList() for inverter_chain in _get_inverter_chains(wire_creator, wire_users): - # if len(inverter_chain) = n, there are n-1 inverters in the chain - # only remove inverters if there are at least two inverters in a chain + # If len(inverter_chain) = n, there are n-1 inverters in the chain. + # We only remove inverters if there are at least two inverters in a chain. if len(inverter_chain) > 2: - if len(inverter_chain) % 2 == 1: # even number of inverters in a chain + if len(inverter_chain) % 2 == 1: # There is an even number of inverters in a chain. start_idx = 1 - else: # odd number of inverters in a chain + else: # There is an odd number of inverters in a chain. start_idx = 2 - # remove wires used in the inverter chain + # Remove wires used in the inverter chain. wires_to_remove = inverter_chain[start_idx:] wire_removal_set.update(wires_to_remove) - # remove inverters used in the chain + # Remove inverters used in the chain. inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} net_removal_set.update(inverters_to_remove) - # map the end wire of the inverter chain to the beginning wire + # Map the end wire of the inverter chain to the beginning wire. wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1] # This loop recreates the LogicNet with inverter chains removed. It adds each From 70d0571335aa0d4965b9d9d513f5d77cb325fe0d Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Mon, 24 Mar 2025 14:39:07 -0700 Subject: [PATCH 18/19] fix comments --- pyrtl/passes.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 6c810435..d2a415e6 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -52,14 +52,25 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): constant_propagation(block, True) _remove_unlistened_nets(block) common_subexp_elimination(block) - _remove_inverter_chains(block, skip_sanity_check) + _optimize_inverter_chains(block, skip_sanity_check) if (not skip_sanity_check) or _get_debug_mode(): block.sanity_check() return block def _get_inverter_chains(wire_creator, wire_users): - """Returns all inverter chains in the block""" + """Returns all inverter chains in the block. + + The function returns a list of inverter chains in the block. + Each inverter chain is represented as a list of the LogicNets + in the chain. + + Consider the following circuit, for example: + A -~-> B -~-> C -w-> X + D -~-> E -w-> Y + If the function is called on this circuit, it will return + [[A, B, C], [D, E]]. + """ # Build a list of inverter chains. Each inverter chain is a list of WireVectors, # from source to destination. @@ -116,8 +127,8 @@ def _get_inverter_chains(wire_creator, wire_users): return inverter_chains -def _remove_inverter_chains(block, skip_sanity_check=False): - """ Removes all inverter chains from the block. +def _optimize_inverter_chains(block, skip_sanity_check=False): + """ Optimizes inverter chains in the block. An inverter chain means two or more inverters directly connected to each other. Inverter chains are redundant and can be removed. @@ -155,9 +166,9 @@ def _remove_inverter_chains(block, skip_sanity_check=False): # This is the optimized version of the circuit: # A -w-> X # A -w-> Y - # The inverter chains found will be B-C and D-E (two separate chains will be - # found instead of B-C-D-E because C has an intermediate user). In the dict, - # C will be mapped to A and E will be maped to C. Hence, when finding the + # The inverter chains found will be A-B-C and C-D-E (two separate chains will be + # found instead of A-B-C-D-E because C has an intermediate user). In the dict, + # C will be mapped to A and E will be mapped to C. Hence, when finding the # replacement of E, we have to first query the dict to get C, and then query # the dict again on C to get A. wire_src_dict = _ProducerList() @@ -179,9 +190,9 @@ def _remove_inverter_chains(block, skip_sanity_check=False): # Map the end wire of the inverter chain to the beginning wire. wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1] - # This loop recreates the LogicNet with inverter chains removed. It adds each - # block in the original LogicNet to the new LogicNet if it is not marked for - # removal, and replaces the source of the block if its source was the end wire + # This loop recreates the block with inverter chains removed. It adds each + # LogicNet in the original block to the new block if it is not marked for + # removal, and replaces the source of the LocigNet if its source was the end wire # of a removed inverter chain. for net in block.logic: if net not in net_removal_set: From 4306a0082bd4f1d937db1d2a40d97ee1a1561be7 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Tue, 25 Mar 2025 02:00:19 -0700 Subject: [PATCH 19/19] fix comment typos --- pyrtl/passes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index d2a415e6..44d78237 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -62,7 +62,7 @@ def _get_inverter_chains(wire_creator, wire_users): """Returns all inverter chains in the block. The function returns a list of inverter chains in the block. - Each inverter chain is represented as a list of the LogicNets + Each inverter chain is represented as a list of the WireVectors in the chain. Consider the following circuit, for example: @@ -192,7 +192,7 @@ def _optimize_inverter_chains(block, skip_sanity_check=False): # This loop recreates the block with inverter chains removed. It adds each # LogicNet in the original block to the new block if it is not marked for - # removal, and replaces the source of the LocigNet if its source was the end wire + # removal, and replaces the source of the LogicNet if its source was the end wire # of a removed inverter chain. for net in block.logic: if net not in net_removal_set: