From 637e77893c5c63590282f9f4f52d06e77469ea74 Mon Sep 17 00:00:00 2001 From: vaani arora <91294025+vaaniarora@users.noreply.github.com> Date: Thu, 20 Feb 2025 22:10:21 -0800 Subject: [PATCH 1/3] Add helper function for binary to 1-hot --- docs/helpers.rst | 1 + pyrtl/__init__.py | 1 + pyrtl/helperfuncs.py | 44 ++++++++++++++++++++++++++++++++++++++- tests/test_helperfuncs.py | 42 +++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) diff --git a/docs/helpers.rst b/docs/helpers.rst index 1dba2319..894a2dcf 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -111,4 +111,5 @@ Encoders and Decoders --------------------- .. autofunction:: pyrtl.helperfuncs.one_hot_to_binary +.. autofunction:: pyrtl.helperfuncs.binary_to_one_hot diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index 821b9941..e30e6b41 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -41,6 +41,7 @@ from .helperfuncs import wire_struct from .helperfuncs import wire_matrix from .helperfuncs import one_hot_to_binary +from .helperfuncs import binary_to_one_hot from .corecircuits import and_all_bits from .corecircuits import or_all_bits diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 2c8fec8b..6dd8a817 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -13,7 +13,15 @@ from .core import working_block, _NameIndexer, _get_debug_mode, Block from .pyrtlexceptions import PyrtlError, PyrtlInternalError from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector -from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select +from .corecircuits import ( + as_wires, + rtl_all, + rtl_any, + concat, + concat_list, + select, + shift_left_logical +) # ----------------------------------------------------------------- # ___ __ ___ __ __ @@ -1715,3 +1723,37 @@ def one_hot_to_binary(w) -> WireVector: already_found = already_found | w[i] return pos + + +def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector: + '''Takes an input representing a bit position and returns a WireVector + with that bit position set to 1 and the others to 0. + + :param bit_position: WireVector, WireVector-like object, or something that can be converted + into a :py:class:`.Const` (in accordance with the :py:func:`.as_wires()` + required input). Example inputs: ``0b10``, ``0b1000, ``4``. + :param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector. + :return: WireVector with the bit position, given by the input, set to 1 and all others set to 0. + + If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1, + a ``0`` WireVector of size max_bitwidth will be returned. + + Examples:: + + binary_to_onehot(2) # returns 0b0100 + binary_to_onehot(8) # returns 0b00100000 + binary_to_onehot(12) # returns 0b0001000000000000 + binary_to_onehot(15) # returns 0b1000000000000000 + + ''' + + bit_position = as_wires(bit_position) + + if max_bitwidth is not None: + bitwidth = max_bitwidth + else: + bitwidth = 2 ** len(bit_position) + + onehot = Const(1, bitwidth=bitwidth) + + return shift_left_logical(onehot, bit_position) diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 972fb502..7f9b3d89 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1814,5 +1814,47 @@ def test_no_ones(self): self.assertEqual(sim.inspect('o'), 0) +class TestBinaryToOneHot(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + + def test_simple_binary_to_one_hot(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=16, name='o') + o <<= pyrtl.binary_to_one_hot(i) + + sim = pyrtl.Simulation() + sim.step({i: 0}) + self.assertEqual(sim.inspect('o'), 0b01) + sim.step({i: 2}) + self.assertEqual(sim.inspect('o'), 0b0100) + sim.step({i: 5}) + self.assertEqual(sim.inspect('o'), 0b00100000) + sim.step({i: 12}) + self.assertEqual(sim.inspect('o'), 0b0001000000000000) + sim.step({i: 15}) + self.assertEqual(sim.inspect('o'), 0b1000000000000000) + + def test_sufficient_max_bitwidth(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=16, name='o') + o <<= pyrtl.binary_to_one_hot(i, max_bitwidth=8) + + sim = pyrtl.Simulation() + sim.step({i: 0}) + self.assertEqual(sim.inspect('o'), 0b0001) + sim.step({i: 6}) + self.assertEqual(sim.inspect('o'), 0b01000000) + + def test_insufficient_max_bitwidth(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=16, name='o') + o <<= pyrtl.binary_to_one_hot(i, max_bitwidth=4) + + sim = pyrtl.Simulation() + sim.step({i: 5}) + self.assertEqual(sim.inspect('o'), 0b0000) + + if __name__ == "__main__": unittest.main() From 53061cb3fe0e7a910bd0435b5cb59e56dfd50c49 Mon Sep 17 00:00:00 2001 From: vaani arora <91294025+vaaniarora@users.noreply.github.com> Date: Sun, 23 Feb 2025 22:26:20 -0800 Subject: [PATCH 2/3] Addressed comments for binary to onehot helper function --- pyrtl/helperfuncs.py | 18 ++++++------ tests/test_helperfuncs.py | 58 ++++++++++++++++++--------------------- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 6dd8a817..d57722e8 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -1731,20 +1731,19 @@ def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector: :param bit_position: WireVector, WireVector-like object, or something that can be converted into a :py:class:`.Const` (in accordance with the :py:func:`.as_wires()` - required input). Example inputs: ``0b10``, ``0b1000, ``4``. + required input). Example inputs: ``0b10``, ``0b1000``, ``4``. :param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector. - :return: WireVector with the bit position, given by the input, set to 1 and all others set to 0. + :return: WireVector with the bit position given by the input set to 1 and all other bits + set to 0 (bit position 0 being the least significant bit). If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1, a ``0`` WireVector of size max_bitwidth will be returned. Examples:: - binary_to_onehot(2) # returns 0b0100 - binary_to_onehot(8) # returns 0b00100000 - binary_to_onehot(12) # returns 0b0001000000000000 - binary_to_onehot(15) # returns 0b1000000000000000 - + binary_to_onehot(0) # returns 0b01 + binary_to_onehot(3) # returns 0b1000 + binary_to_onehot(0b100) # returns 0b10000 ''' bit_position = as_wires(bit_position) @@ -1754,6 +1753,5 @@ def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector: else: bitwidth = 2 ** len(bit_position) - onehot = Const(1, bitwidth=bitwidth) - - return shift_left_logical(onehot, bit_position) + # Need to dynamically set the appropriate bit position since bit_position may not be a Const + return shift_left_logical(Const(1, bitwidth=bitwidth), bit_position) diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 7f9b3d89..2546608e 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1819,41 +1819,37 @@ def setUp(self): pyrtl.reset_working_block() def test_simple_binary_to_one_hot(self): - i = pyrtl.Input(bitwidth=8, name='i') - o = pyrtl.Output(bitwidth=16, name='o') - o <<= pyrtl.binary_to_one_hot(i) - - sim = pyrtl.Simulation() - sim.step({i: 0}) - self.assertEqual(sim.inspect('o'), 0b01) - sim.step({i: 2}) - self.assertEqual(sim.inspect('o'), 0b0100) - sim.step({i: 5}) - self.assertEqual(sim.inspect('o'), 0b00100000) - sim.step({i: 12}) - self.assertEqual(sim.inspect('o'), 0b0001000000000000) - sim.step({i: 15}) - self.assertEqual(sim.inspect('o'), 0b1000000000000000) - - def test_sufficient_max_bitwidth(self): - i = pyrtl.Input(bitwidth=8, name='i') - o = pyrtl.Output(bitwidth=16, name='o') - o <<= pyrtl.binary_to_one_hot(i, max_bitwidth=8) + bit_position = pyrtl.Input(bitwidth=8, name='bit_position') + one_hot = pyrtl.Output(bitwidth=16, name='one_hot') + one_hot <<= pyrtl.binary_to_one_hot(bit_position) sim = pyrtl.Simulation() - sim.step({i: 0}) - self.assertEqual(sim.inspect('o'), 0b0001) - sim.step({i: 6}) - self.assertEqual(sim.inspect('o'), 0b01000000) - - def test_insufficient_max_bitwidth(self): - i = pyrtl.Input(bitwidth=8, name='i') - o = pyrtl.Output(bitwidth=16, name='o') - o <<= pyrtl.binary_to_one_hot(i, max_bitwidth=4) + sim.step({bit_position: 0}) + self.assertEqual(sim.inspect('one_hot'), 0b01) + sim.step({bit_position: 2}) + self.assertEqual(sim.inspect('one_hot'), 0b0100) + sim.step({bit_position: 5}) + self.assertEqual(sim.inspect('one_hot'), 0b00100000) + sim.step({bit_position: 12}) + self.assertEqual(sim.inspect('one_hot'), 0b0001000000000000) + sim.step({bit_position: 15}) + self.assertEqual(sim.inspect('one_hot'), 0b1000000000000000) + + # Tests with the max_bitwidth set + def test_with_max_bitwidth(self): + bit_position = pyrtl.Input(bitwidth=8, name='bit_position') + one_hot = pyrtl.Output(bitwidth=16, name='one_hot') + one_hot <<= pyrtl.binary_to_one_hot(bit_position, max_bitwidth=4) sim = pyrtl.Simulation() - sim.step({i: 5}) - self.assertEqual(sim.inspect('o'), 0b0000) + sim.step({bit_position: 0}) + self.assertEqual(sim.inspect('one_hot'), 0b0001) + sim.step({bit_position: 3}) + self.assertEqual(sim.inspect('one_hot'), 0b1000) + + # The max_bitwidth set is not enough for a bit position of 4 + sim.step({bit_position: 4}) + self.assertEqual(sim.inspect('one_hot'), 0b0000) if __name__ == "__main__": From c51b4dc24dac5d86fe14c9825ffe03652e069b5a Mon Sep 17 00:00:00 2001 From: vaani arora <91294025+vaaniarora@users.noreply.github.com> Date: Mon, 24 Feb 2025 11:01:29 -0800 Subject: [PATCH 3/3] Fixed test cases to check for bitwidth --- pyrtl/helperfuncs.py | 2 +- tests/test_helperfuncs.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index d57722e8..9731b239 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -1734,7 +1734,7 @@ def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector: required input). Example inputs: ``0b10``, ``0b1000``, ``4``. :param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector. :return: WireVector with the bit position given by the input set to 1 and all other bits - set to 0 (bit position 0 being the least significant bit). + set to 0 (bit position 0 being the least significant bit). If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1, a ``0`` WireVector of size max_bitwidth will be returned. diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 2546608e..a9798fee 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1820,9 +1820,11 @@ def setUp(self): def test_simple_binary_to_one_hot(self): bit_position = pyrtl.Input(bitwidth=8, name='bit_position') - one_hot = pyrtl.Output(bitwidth=16, name='one_hot') + one_hot = pyrtl.Output(name='one_hot') one_hot <<= pyrtl.binary_to_one_hot(bit_position) + self.assertEqual(one_hot.bitwidth, 256) + sim = pyrtl.Simulation() sim.step({bit_position: 0}) self.assertEqual(sim.inspect('one_hot'), 0b01) @@ -1838,9 +1840,11 @@ def test_simple_binary_to_one_hot(self): # Tests with the max_bitwidth set def test_with_max_bitwidth(self): bit_position = pyrtl.Input(bitwidth=8, name='bit_position') - one_hot = pyrtl.Output(bitwidth=16, name='one_hot') + one_hot = pyrtl.Output(name='one_hot') one_hot <<= pyrtl.binary_to_one_hot(bit_position, max_bitwidth=4) + self.assertEqual(one_hot.bitwidth, 4) + sim = pyrtl.Simulation() sim.step({bit_position: 0}) self.assertEqual(sim.inspect('one_hot'), 0b0001)