Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions pyrtl/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,162 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False):
constant_propagation(block, True)
_remove_unlistened_nets(block)
common_subexp_elimination(block)
_optimize_inverter_chains(block, skip_sanity_check)
if (not skip_sanity_check) or _get_debug_mode():
block.sanity_check()
return block


def _get_inverter_chains(wire_creator, wire_users):
"""Returns all inverter chains in the block.

The function returns a list of inverter chains in the block.
Each inverter chain is represented as a list of the WireVectors
in the chain.

Consider the following circuit, for example:
A -~-> B -~-> C -w-> X
D -~-> E -w-> Y
If the function is called on this circuit, it will return
[[A, B, C], [D, E]].
"""

# Build a list of inverter chains. Each inverter chain is a list of WireVectors,
# from source to destination.
inverter_chains = []
for current_dest, current_creator in wire_creator.items():
if current_creator.op != "~":
# Skip non-inverters.
continue

# The current inverter connects current_arg (a WireVector) to current_dest (also
# a WireVector).
current_arg = current_creator.args[0]
# current_users is the number of LogicNets that use current_dest.
current_users = len(wire_users[current_dest])

# Add the current inverter to the end of this inverter chain.
append_to = None
# Add the current inverter to the beginning of this inverter chain.
prepend_to = None
next_inverter_chains = []
for inverter_chain in inverter_chains:
chain_arg = inverter_chain[0]
chain_dest = inverter_chain[-1]
chain_users = len(wire_users[chain_dest])

if chain_dest is current_arg and chain_users == 1:
# This chain's only destination is the current inverter. Append the
# current inverter to the chain.
append_to = inverter_chain
elif chain_arg is current_dest and current_users == 1:
# This chain's only argument is the current inverter. Add the current
# inverter to the beginning of the chain.
prepend_to = inverter_chain
else:
# The current inverter is not connected to the inverter chain, so we
# pass the inverter chain through to next_inverter_chains
next_inverter_chains.append(inverter_chain)

if append_to and prepend_to:
# The current inverter joins two existing inverter chains.
next_inverter_chains.append(append_to + prepend_to)
elif append_to:
# Add the current inverter after 'append_to'.
next_inverter_chains.append(append_to + [current_dest])
elif prepend_to:
# Add the current inverter before 'prepend_to'.
next_inverter_chains.append([current_arg] + prepend_to)
else:
# The current inverter is not connected to any inverter chain, so
# we start a new inverter chain with it
next_inverter_chains.append([current_arg, current_dest])

inverter_chains = next_inverter_chains
return inverter_chains


def _optimize_inverter_chains(block, skip_sanity_check=False):
""" Optimizes inverter chains in the block.

An inverter chain means two or more inverters directly connected
to each other. Inverter chains are redundant and can be removed.
For example, A -~-> B -~-> C -w-> X can be reduced to A -w-> X.

After optimization, a chain of an even number of inverters will
be reduced a direct connection, and a chain of an odd number of
inverters will be reduced to one inverter.

If an inverter chain has intermediate users it won't be removed.
For example, the inverter chain in the following circuit won't be removed:
A -~-> B -~-> C -w-> X
B -w-> Y
"""

# wire_creator maps from WireVector to the LogicNet that defines its value.
# wire_users maps from WireVector to a list of LogicNets that use its value.
wire_creator, wire_users = block.net_connections()

new_logic = set()
net_removal_set = set()
wire_removal_set = set()

# This ProducerList maps the end wire of an inverter chain to its beginning wire.
# We need this because when removing an inverter chain its end wire gets removed,
# so we need to replace the source of LogicNets using the end wire of the inverter
# chain with the chain's beginning wire.
#
# We need a ProducerList, rather than a simple dict, because if an inverter chain
# of more than two inverters has intermediate users, we may have to query the dict
# multiple times to get the replacement for the inverter chain's last wire.
# Consider the following circuit, for example:
# A -~-> B -~-> C -w-> X
# C -~-> D -~-> E -w-> Y
# This is the optimized version of the circuit:
# A -w-> X
# A -w-> Y
# The inverter chains found will be A-B-C and C-D-E (two separate chains will be
# found instead of A-B-C-D-E because C has an intermediate user). In the dict,
# C will be mapped to A and E will be mapped to C. Hence, when finding the
# replacement of E, we have to first query the dict to get C, and then query
# the dict again on C to get A.
wire_src_dict = _ProducerList()

for inverter_chain in _get_inverter_chains(wire_creator, wire_users):
# If len(inverter_chain) = n, there are n-1 inverters in the chain.
# We only remove inverters if there are at least two inverters in a chain.
if len(inverter_chain) > 2:
if len(inverter_chain) % 2 == 1: # There is an even number of inverters in a chain.
start_idx = 1
else: # There is an odd number of inverters in a chain.
start_idx = 2
# Remove wires used in the inverter chain.
wires_to_remove = inverter_chain[start_idx:]
wire_removal_set.update(wires_to_remove)
# Remove inverters used in the chain.
inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove}
net_removal_set.update(inverters_to_remove)
# Map the end wire of the inverter chain to the beginning wire.
wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1]

# This loop recreates the block with inverter chains removed. It adds each
# LogicNet in the original block to the new block if it is not marked for
# removal, and replaces the source of the LogicNet if its source was the end wire
# of a removed inverter chain.
for net in block.logic:
if net not in net_removal_set:
new_logic.add(LogicNet(net.op, net.op_param,
args=tuple(wire_src_dict.find_producer(x) for x in net.args),
dests=net.dests))

block.logic = new_logic
for dead_wirevector in wire_removal_set:
block.remove_wirevector(dead_wirevector)

if (not skip_sanity_check) or _get_debug_mode():
block.sanity_check()


class _ProducerList(object):
""" Maps from wire to its immediate producer and finds ultimate producers. """
def __init__(self):
Expand Down
81 changes: 81 additions & 0 deletions tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,87 @@ def test_slice_net_removal_4(self):
self.num_net_of_type('s', 1, block)
self.num_net_of_type('w', 2, block)

def test_remove_double_inverts_1_invert(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~inwire
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_3_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~inwire))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_5_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~(~inwire))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_2_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~inwire)
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_remove_double_inverts_4_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~inwire)))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_remove_double_inverts_6_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~(~(~inwire)))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_dont_remove_double_inverts_another_user(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire2 = pyrtl.Output(bitwidth=1)
tempwire = pyrtl.WireVector()
tempwire <<= ~inwire
outwire <<= ~tempwire
outwire2 <<= tempwire
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(4, block)
self.assert_num_wires(5, block)

def test_multiple_double_invert_chains(self):
# _remove_double_inverts removes double inverts by chains,
# so it is useful to make sure it can remove
# double inverts from multiple chains
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire2 = pyrtl.Output(bitwidth=1)
outwire <<= ~(~inwire)
outwire2 <<= ~(~(~(~(inwire))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)


class TestConstFolding(NetWireNumTestCases):
def setUp(self):
Expand Down