diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 0b684303..44d78237 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -52,11 +52,162 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False): constant_propagation(block, True) _remove_unlistened_nets(block) common_subexp_elimination(block) + _optimize_inverter_chains(block, skip_sanity_check) if (not skip_sanity_check) or _get_debug_mode(): block.sanity_check() return block +def _get_inverter_chains(wire_creator, wire_users): + """Returns all inverter chains in the block. + + The function returns a list of inverter chains in the block. + Each inverter chain is represented as a list of the WireVectors + in the chain. + + Consider the following circuit, for example: + A -~-> B -~-> C -w-> X + D -~-> E -w-> Y + If the function is called on this circuit, it will return + [[A, B, C], [D, E]]. + """ + + # Build a list of inverter chains. Each inverter chain is a list of WireVectors, + # from source to destination. + inverter_chains = [] + for current_dest, current_creator in wire_creator.items(): + if current_creator.op != "~": + # Skip non-inverters. + continue + + # The current inverter connects current_arg (a WireVector) to current_dest (also + # a WireVector). + current_arg = current_creator.args[0] + # current_users is the number of LogicNets that use current_dest. + current_users = len(wire_users[current_dest]) + + # Add the current inverter to the end of this inverter chain. + append_to = None + # Add the current inverter to the beginning of this inverter chain. + prepend_to = None + next_inverter_chains = [] + for inverter_chain in inverter_chains: + chain_arg = inverter_chain[0] + chain_dest = inverter_chain[-1] + chain_users = len(wire_users[chain_dest]) + + if chain_dest is current_arg and chain_users == 1: + # This chain's only destination is the current inverter. Append the + # current inverter to the chain. + append_to = inverter_chain + elif chain_arg is current_dest and current_users == 1: + # This chain's only argument is the current inverter. Add the current + # inverter to the beginning of the chain. + prepend_to = inverter_chain + else: + # The current inverter is not connected to the inverter chain, so we + # pass the inverter chain through to next_inverter_chains + next_inverter_chains.append(inverter_chain) + + if append_to and prepend_to: + # The current inverter joins two existing inverter chains. + next_inverter_chains.append(append_to + prepend_to) + elif append_to: + # Add the current inverter after 'append_to'. + next_inverter_chains.append(append_to + [current_dest]) + elif prepend_to: + # Add the current inverter before 'prepend_to'. + next_inverter_chains.append([current_arg] + prepend_to) + else: + # The current inverter is not connected to any inverter chain, so + # we start a new inverter chain with it + next_inverter_chains.append([current_arg, current_dest]) + + inverter_chains = next_inverter_chains + return inverter_chains + + +def _optimize_inverter_chains(block, skip_sanity_check=False): + """ Optimizes inverter chains in the block. + + An inverter chain means two or more inverters directly connected + to each other. Inverter chains are redundant and can be removed. + For example, A -~-> B -~-> C -w-> X can be reduced to A -w-> X. + + After optimization, a chain of an even number of inverters will + be reduced a direct connection, and a chain of an odd number of + inverters will be reduced to one inverter. + + If an inverter chain has intermediate users it won't be removed. + For example, the inverter chain in the following circuit won't be removed: + A -~-> B -~-> C -w-> X + B -w-> Y + """ + + # wire_creator maps from WireVector to the LogicNet that defines its value. + # wire_users maps from WireVector to a list of LogicNets that use its value. + wire_creator, wire_users = block.net_connections() + + new_logic = set() + net_removal_set = set() + wire_removal_set = set() + + # This ProducerList maps the end wire of an inverter chain to its beginning wire. + # We need this because when removing an inverter chain its end wire gets removed, + # so we need to replace the source of LogicNets using the end wire of the inverter + # chain with the chain's beginning wire. + # + # We need a ProducerList, rather than a simple dict, because if an inverter chain + # of more than two inverters has intermediate users, we may have to query the dict + # multiple times to get the replacement for the inverter chain's last wire. + # Consider the following circuit, for example: + # A -~-> B -~-> C -w-> X + # C -~-> D -~-> E -w-> Y + # This is the optimized version of the circuit: + # A -w-> X + # A -w-> Y + # The inverter chains found will be A-B-C and C-D-E (two separate chains will be + # found instead of A-B-C-D-E because C has an intermediate user). In the dict, + # C will be mapped to A and E will be mapped to C. Hence, when finding the + # replacement of E, we have to first query the dict to get C, and then query + # the dict again on C to get A. + wire_src_dict = _ProducerList() + + for inverter_chain in _get_inverter_chains(wire_creator, wire_users): + # If len(inverter_chain) = n, there are n-1 inverters in the chain. + # We only remove inverters if there are at least two inverters in a chain. + if len(inverter_chain) > 2: + if len(inverter_chain) % 2 == 1: # There is an even number of inverters in a chain. + start_idx = 1 + else: # There is an odd number of inverters in a chain. + start_idx = 2 + # Remove wires used in the inverter chain. + wires_to_remove = inverter_chain[start_idx:] + wire_removal_set.update(wires_to_remove) + # Remove inverters used in the chain. + inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove} + net_removal_set.update(inverters_to_remove) + # Map the end wire of the inverter chain to the beginning wire. + wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1] + + # This loop recreates the block with inverter chains removed. It adds each + # LogicNet in the original block to the new block if it is not marked for + # removal, and replaces the source of the LogicNet if its source was the end wire + # of a removed inverter chain. + for net in block.logic: + if net not in net_removal_set: + new_logic.add(LogicNet(net.op, net.op_param, + args=tuple(wire_src_dict.find_producer(x) for x in net.args), + dests=net.dests)) + + block.logic = new_logic + for dead_wirevector in wire_removal_set: + block.remove_wirevector(dead_wirevector) + + if (not skip_sanity_check) or _get_debug_mode(): + block.sanity_check() + + class _ProducerList(object): """ Maps from wire to its immediate producer and finds ultimate producers. """ def __init__(self): diff --git a/tests/test_passes.py b/tests/test_passes.py index 332ad873..f5dcbcd5 100644 --- a/tests/test_passes.py +++ b/tests/test_passes.py @@ -316,6 +316,87 @@ def test_slice_net_removal_4(self): self.num_net_of_type('s', 1, block) self.num_net_of_type('w', 2, block) + def test_remove_double_inverts_1_invert(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~inwire + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + + def test_remove_double_inverts_3_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~inwire)) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + + def test_remove_double_inverts_5_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~(~(~inwire)))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + + def test_remove_double_inverts_2_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~inwire) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(1, block) + self.assert_num_wires(2, block) + + def test_remove_double_inverts_4_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~(~inwire))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(1, block) + self.assert_num_wires(2, block) + + def test_remove_double_inverts_6_inverts(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire <<= ~(~(~(~(~(~inwire))))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(1, block) + self.assert_num_wires(2, block) + + def test_dont_remove_double_inverts_another_user(self): + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire2 = pyrtl.Output(bitwidth=1) + tempwire = pyrtl.WireVector() + tempwire <<= ~inwire + outwire <<= ~tempwire + outwire2 <<= tempwire + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(4, block) + self.assert_num_wires(5, block) + + def test_multiple_double_invert_chains(self): + # _remove_double_inverts removes double inverts by chains, + # so it is useful to make sure it can remove + # double inverts from multiple chains + inwire = pyrtl.Input(bitwidth=1) + outwire = pyrtl.Output(bitwidth=1) + outwire2 = pyrtl.Output(bitwidth=1) + outwire <<= ~(~inwire) + outwire2 <<= ~(~(~(~(inwire)))) + pyrtl.optimize() + block = pyrtl.working_block() + self.assert_num_net(2, block) + self.assert_num_wires(3, block) + class TestConstFolding(NetWireNumTestCases): def setUp(self):