From 46b4bac0ee85c28054d237a7d228e68df4b98b41 Mon Sep 17 00:00:00 2001 From: Edwin Chang Date: Tue, 17 Jun 2025 23:29:12 -0400 Subject: [PATCH 1/3] fix XOR implementation in and_inverter_synth pass --- pyrtl/passes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 6765ef47..e26d1af3 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -848,7 +848,7 @@ def arg(num): elif net.op == '^': all_1 = arg(0) & arg(1) all_0 = ~arg(0) & ~arg(1) - dest <<= all_0 & ~all_1 + dest <<= ~all_0 & ~all_1 elif net.op == 'n': dest <<= ~(arg(0) & arg(1)) else: From f20fbfe887d536bc7b39a51ce541110d523ea7e0 Mon Sep 17 00:00:00 2001 From: Edwin Chang Date: Wed, 18 Jun 2025 22:36:55 -0400 Subject: [PATCH 2/3] add tests for nand_synth and and_inverter_synth --- tests/test_passes.py | 65 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/tests/test_passes.py b/tests/test_passes.py index f5dcbcd5..5a3ea201 100644 --- a/tests/test_passes.py +++ b/tests/test_passes.py @@ -5,9 +5,8 @@ import unittest import pyrtl -from pyrtl.wire import Const, Output from pyrtl.rtllib import testingutils as utils - +from pyrtl.wire import Const, Output from .test_transform import NetWireNumTestCases @@ -733,6 +732,68 @@ def test_nested_elimination(self): pyrtl.working_block().sanity_check() +class TestSynthPasses(NetWireNumTestCases): + in1: pyrtl.Input + in2: pyrtl.Input + out: pyrtl.Output + + def setUp(self): + pyrtl.reset_working_block() + self.in1 = pyrtl.Input(bitwidth=5, name='in1') + self.in2 = pyrtl.Input(bitwidth=5, name='in2') + self.out = pyrtl.Output(bitwidth=5, name='out') + + def run_nand_synth(self, values: dict[tuple[int, int], int]): + pyrtl.synthesize() + pyrtl.nand_synth() + sim_trace = pyrtl.SimulationTrace() + sim = pyrtl.Simulation(tracer=sim_trace) + for inputs, output in values.items(): + sim.step({'in1': inputs[0], 'in2': inputs[1]}) + self.assertEqual(sim.inspect('out'), output, msg=f"Failed on inputs {inputs[0]} and {inputs[1]}") + + def test_nand_synth_and(self): + self.out <<= self.in1 & self.in2 + self.run_nand_synth({(1, 2): 0, (4, 5): 4, (7, 11): 3}) + + def test_nand_synth_or(self): + self.out <<= self.in1 | self.in2 + self.run_nand_synth({(1, 2): 3, (4, 5): 5, (7, 11): 15}) + + def test_nand_synth_xor(self): + self.out <<= self.in1 ^ self.in2 + self.run_nand_synth({(1, 2): 3, (4, 5): 1, (7, 11): 12}) + + def test_nand_synth_adder(self): + self.out <<= self.in1 + self.in2 + self.run_nand_synth({(1, 2): 3, (4, 5): 9, (7, 11): 18}) + + def run_and_inverter_synth(self, values: dict[tuple[int, int], int]): + pyrtl.synthesize() + pyrtl.and_inverter_synth() + sim_trace = pyrtl.SimulationTrace() + sim = pyrtl.Simulation(tracer=sim_trace) + for inputs, output in values.items(): + sim.step({'in1': inputs[0], 'in2': inputs[1]}) + self.assertEqual(sim.inspect('out'), output, msg=f"Failed on inputs {inputs[0]} and {inputs[1]}") + + def test_and_inverter_synth_and(self): + self.out <<= self.in1 & self.in2 + self.run_and_inverter_synth({(1, 2): 0, (4, 5): 4, (7, 11): 3}) + + def test_and_inverter_synth_or(self): + self.out <<= self.in1 | self.in2 + self.run_and_inverter_synth({(1, 2): 3, (4, 5): 5, (7, 11): 15}) + + def test_and_inverter_synth_xor(self): + self.out <<= self.in1 ^ self.in2 + self.run_and_inverter_synth({(1, 2): 3, (4, 5): 1, (7, 11): 12}) + + def test_and_inverter_synth_adder(self): + self.out <<= self.in1 + self.in2 + self.run_and_inverter_synth({(1, 2): 3, (4, 5): 9, (7, 11): 18}) + + class TestSynthOptTiming(NetWireNumTestCases): def setUp(self): pyrtl.reset_working_block() From 8f35941e26ed999e66630f2dc1d5823ac0e25192 Mon Sep 17 00:00:00 2001 From: Edwin Chang Date: Thu, 19 Jun 2025 11:03:34 -0400 Subject: [PATCH 3/3] rework synth pass tests --- tests/test_passes.py | 86 +++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 33 deletions(-) diff --git a/tests/test_passes.py b/tests/test_passes.py index 5a3ea201..bdd66112 100644 --- a/tests/test_passes.py +++ b/tests/test_passes.py @@ -3,6 +3,7 @@ import os import sys import unittest +from typing import Callable import pyrtl from pyrtl.rtllib import testingutils as utils @@ -732,66 +733,85 @@ def test_nested_elimination(self): pyrtl.working_block().sanity_check() -class TestSynthPasses(NetWireNumTestCases): +class TestSynthPasses(unittest.TestCase): + in0: pyrtl.Input in1: pyrtl.Input - in2: pyrtl.Input out: pyrtl.Output def setUp(self): pyrtl.reset_working_block() + self.in0 = pyrtl.Input(bitwidth=5, name='in0') self.in1 = pyrtl.Input(bitwidth=5, name='in1') - self.in2 = pyrtl.Input(bitwidth=5, name='in2') self.out = pyrtl.Output(bitwidth=5, name='out') - def run_nand_synth(self, values: dict[tuple[int, int], int]): - pyrtl.synthesize() - pyrtl.nand_synth() + def check_synth(self, operation: Callable[[int, int], int]): + """ + Simulates the current circuit with some test input pairs (in0, in1) and checks the outputs + against the provided operation. Any synthesis/passes should be run before this gets called. + + :param operation: The operation to test against. This should be a lambda taking an input + pair (in0, in1) and returning the expected output from this circuit. + """ sim_trace = pyrtl.SimulationTrace() sim = pyrtl.Simulation(tracer=sim_trace) - for inputs, output in values.items(): - sim.step({'in1': inputs[0], 'in2': inputs[1]}) - self.assertEqual(sim.inspect('out'), output, msg=f"Failed on inputs {inputs[0]} and {inputs[1]}") + + values = [(1, 2), (4, 5), (7, 11)] + """A list of input pairs (in0, in1) to test on.""" + + for in0, in1 in values: + expected_output = operation(in0, in1) + sim.step({'in0': in0, 'in1': in1}) + # compare simulation output to expected output + self.assertEqual(sim.inspect('out'), expected_output, + msg=f"Failed on inputs {in0} and {in1}") def test_nand_synth_and(self): - self.out <<= self.in1 & self.in2 - self.run_nand_synth({(1, 2): 0, (4, 5): 4, (7, 11): 3}) + self.out <<= self.in0 & self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a & b) def test_nand_synth_or(self): - self.out <<= self.in1 | self.in2 - self.run_nand_synth({(1, 2): 3, (4, 5): 5, (7, 11): 15}) + self.out <<= self.in0 | self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a | b) def test_nand_synth_xor(self): - self.out <<= self.in1 ^ self.in2 - self.run_nand_synth({(1, 2): 3, (4, 5): 1, (7, 11): 12}) + self.out <<= self.in0 ^ self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a ^ b) def test_nand_synth_adder(self): - self.out <<= self.in1 + self.in2 - self.run_nand_synth({(1, 2): 3, (4, 5): 9, (7, 11): 18}) - - def run_and_inverter_synth(self, values: dict[tuple[int, int], int]): + self.out <<= self.in0 + self.in1 pyrtl.synthesize() - pyrtl.and_inverter_synth() - sim_trace = pyrtl.SimulationTrace() - sim = pyrtl.Simulation(tracer=sim_trace) - for inputs, output in values.items(): - sim.step({'in1': inputs[0], 'in2': inputs[1]}) - self.assertEqual(sim.inspect('out'), output, msg=f"Failed on inputs {inputs[0]} and {inputs[1]}") + pyrtl.nand_synth() + self.check_synth(lambda a, b: a + b) def test_and_inverter_synth_and(self): - self.out <<= self.in1 & self.in2 - self.run_and_inverter_synth({(1, 2): 0, (4, 5): 4, (7, 11): 3}) + self.out <<= self.in0 & self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a & b) def test_and_inverter_synth_or(self): - self.out <<= self.in1 | self.in2 - self.run_and_inverter_synth({(1, 2): 3, (4, 5): 5, (7, 11): 15}) + self.out <<= self.in0 | self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a | b) def test_and_inverter_synth_xor(self): - self.out <<= self.in1 ^ self.in2 - self.run_and_inverter_synth({(1, 2): 3, (4, 5): 1, (7, 11): 12}) + self.out <<= self.in0 ^ self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a ^ b) def test_and_inverter_synth_adder(self): - self.out <<= self.in1 + self.in2 - self.run_and_inverter_synth({(1, 2): 3, (4, 5): 9, (7, 11): 18}) + self.out <<= self.in0 + self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a + b) class TestSynthOptTiming(NetWireNumTestCases):