From bab5ba64bdbd449613a7525da9d87e7fc971cae4 Mon Sep 17 00:00:00 2001 From: vaani arora <91294025+vaaniarora@users.noreply.github.com> Date: Tue, 10 Dec 2024 21:07:04 -0800 Subject: [PATCH 1/2] Adding one-hot to binary helper function --- pyrtl/__init__.py | 1 + pyrtl/helperfuncs.py | 26 +++++++++++++++++++- tests/test_helperfuncs.py | 52 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index cdc3f2c2..821b9941 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -40,6 +40,7 @@ from .helperfuncs import find_and_print_loop from .helperfuncs import wire_struct from .helperfuncs import wire_matrix +from .helperfuncs import one_hot_to_binary from .corecircuits import and_all_bits from .corecircuits import or_all_bits diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 728dc2b3..27bc6897 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -13,7 +13,7 @@ from .core import working_block, _NameIndexer, _get_debug_mode, Block from .pyrtlexceptions import PyrtlError, PyrtlInternalError from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector -from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list +from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, mux, select # ----------------------------------------------------------------- # ___ __ ___ __ __ @@ -1683,3 +1683,27 @@ def __len__(self): return len(self._components) return _WireMatrix + +def one_hot_to_binary(w) -> WireVector: + '''Takes a one-hot input and returns the bit position of the high bit in binary. + + :param w: WireVector or a WireVector-like object or something that can be converted + into a Const (in accordance with the as_wires() required input). + :return: The bit position of the high bit in binary as a WireVector. + + If the input contains multiple 1s, the bit position of the first 1 will + be returned. If the input contains no 1s, 0 will be returned. + ''' + + if not isinstance(w, WireVector): + w = as_wires(w) + + n = len(w) + pos = 0 + found = as_wires(0) + + for i in range(n): + pos = select(found | ~w[i], pos, i) + found = select(found, found, w[i]) + + return pos diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 4d9e75ef..74e97c58 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -10,7 +10,6 @@ import pyrtl.helperfuncs from pyrtl.rtllib import testingutils as utils - # --------------------------------------------------------------- class TestWireVectorList(unittest.TestCase): @@ -1771,6 +1770,57 @@ def test_byte_matrix_input_concatenate(self): self.assertEqual(sim.inspect('byte_matrix[0].high'), 0xA) self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB) +class TestOneHotToBinary(unittest.TestCase): + + def test_simple_onehot(self): + + pyrtl.reset_working_block() + + o1 = pyrtl.WireVector(name='o1') + o1 <<= pyrtl.one_hot_to_binary(0b00000001) + + o2 = pyrtl.WireVector(name='o2') + o2 <<= pyrtl.one_hot_to_binary(0b10000000) + + o3 = pyrtl.WireVector(name='o3') + o3 <<= pyrtl.one_hot_to_binary(0b00100000) + + o4 = pyrtl.WireVector(name='o4') + o4 <<= pyrtl.one_hot_to_binary(0b00010000) + + sim = pyrtl.Simulation() + sim.step({}) + self.assertEqual(sim.inspect(o1), 0) + self.assertEqual(sim.inspect(o2), 7) + self.assertEqual(sim.inspect(o3), 5) + self.assertEqual(sim.inspect(o4), 4) + + def test_multiple_ones(self): + + pyrtl.reset_working_block() + + o5 = pyrtl.WireVector(name='o5') + o5 <<= pyrtl.one_hot_to_binary(0b00000101) + + o6 = pyrtl.WireVector(name='o6') + o6 <<= pyrtl.one_hot_to_binary(0b11000000) + + sim = pyrtl.Simulation() + sim.step({}) + self.assertEqual(sim.inspect(o5), 0) + self.assertEqual(sim.inspect(o6), 6) + + def test_no_ones(self): + + pyrtl.reset_working_block() + + o7 = pyrtl.WireVector(name='o7') + o7 <<= pyrtl.one_hot_to_binary(0b00000000) + + sim = pyrtl.Simulation() + sim.step({}) + self.assertEqual(sim.inspect(o7), 0) + if __name__ == "__main__": unittest.main() From e9bbb0c7a48972ea5a76b1056a8bec1ea99b9f29 Mon Sep 17 00:00:00 2001 From: vaani arora <91294025+vaaniarora@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:35:36 -0800 Subject: [PATCH 2/2] Addressed comments for one-hot to binary helper function --- pyrtl/helperfuncs.py | 29 ++++++++++------- tests/test_helperfuncs.py | 66 +++++++++++++++++---------------------- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 27bc6897..f4d3ea6b 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -13,7 +13,7 @@ from .core import working_block, _NameIndexer, _get_debug_mode, Block from .pyrtlexceptions import PyrtlError, PyrtlInternalError from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector -from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, mux, select +from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select # ----------------------------------------------------------------- # ___ __ ___ __ __ @@ -1684,26 +1684,33 @@ def __len__(self): return _WireMatrix + def one_hot_to_binary(w) -> WireVector: '''Takes a one-hot input and returns the bit position of the high bit in binary. :param w: WireVector or a WireVector-like object or something that can be converted - into a Const (in accordance with the as_wires() required input). + into a Const (in accordance with the :py:func:`as_wires()` required input). Example + inputs: 0b0010, 64, 0b01. :return: The bit position of the high bit in binary as a WireVector. If the input contains multiple 1s, the bit position of the first 1 will - be returned. If the input contains no 1s, 0 will be returned. + be returned. If the input contains no 1s, 0 will be returned. + + Examples:: + + one_hot_to_binary(0b0010) # returns 1 + one_hot_to_binary(64) # returns 6 + one_hot_to_binary(0b1100) # returns 2, the bit position of the first 1 + one_hot_to_binary(0) # returns 0 ''' - if not isinstance(w, WireVector): - w = as_wires(w) + w = as_wires(w) - n = len(w) - pos = 0 - found = as_wires(0) + pos = 0 # Bit position of the first 1 + already_found = as_wires(False) # True if first 1 already found, False otherwise - for i in range(n): - pos = select(found | ~w[i], pos, i) - found = select(found, found, w[i]) + for i in range(len(w)): + pos = select(w[i] & ~already_found, i, pos) + already_found = already_found | w[i] return pos diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 74e97c58..972fb502 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -12,6 +12,7 @@ # --------------------------------------------------------------- + class TestWireVectorList(unittest.TestCase): def setUp(self): pass @@ -1770,56 +1771,47 @@ def test_byte_matrix_input_concatenate(self): self.assertEqual(sim.inspect('byte_matrix[0].high'), 0xA) self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB) + class TestOneHotToBinary(unittest.TestCase): - - def test_simple_onehot(self): - + def setUp(self): pyrtl.reset_working_block() - - o1 = pyrtl.WireVector(name='o1') - o1 <<= pyrtl.one_hot_to_binary(0b00000001) - o2 = pyrtl.WireVector(name='o2') - o2 <<= pyrtl.one_hot_to_binary(0b10000000) - - o3 = pyrtl.WireVector(name='o3') - o3 <<= pyrtl.one_hot_to_binary(0b00100000) + def test_simple_onehot(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=3, name='o') + o <<= pyrtl.one_hot_to_binary(i) - o4 = pyrtl.WireVector(name='o4') - o4 <<= pyrtl.one_hot_to_binary(0b00010000) - sim = pyrtl.Simulation() - sim.step({}) - self.assertEqual(sim.inspect(o1), 0) - self.assertEqual(sim.inspect(o2), 7) - self.assertEqual(sim.inspect(o3), 5) - self.assertEqual(sim.inspect(o4), 4) + sim.step({i: 0b00000001}) + self.assertEqual(sim.inspect('o'), 0) + sim.step({i: 0b10000000}) + self.assertEqual(sim.inspect('o'), 7) + sim.step({i: 32}) + self.assertEqual(sim.inspect('o'), 5) + sim.step({i: 16}) + self.assertEqual(sim.inspect('o'), 4) def test_multiple_ones(self): - - pyrtl.reset_working_block() - - o5 = pyrtl.WireVector(name='o5') - o5 <<= pyrtl.one_hot_to_binary(0b00000101) - - o6 = pyrtl.WireVector(name='o6') - o6 <<= pyrtl.one_hot_to_binary(0b11000000) + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=3, name='o') + o <<= pyrtl.one_hot_to_binary(i) sim = pyrtl.Simulation() - sim.step({}) - self.assertEqual(sim.inspect(o5), 0) - self.assertEqual(sim.inspect(o6), 6) + sim.step({i: 0b00000101}) + self.assertEqual(sim.inspect('o'), 0) + sim.step({i: 0b11000000}) + self.assertEqual(sim.inspect('o'), 6) + sim.step({i: 0b10010010}) + self.assertEqual(sim.inspect('o'), 1) def test_no_ones(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=3, name='o') + o <<= pyrtl.one_hot_to_binary(i) - pyrtl.reset_working_block() - - o7 = pyrtl.WireVector(name='o7') - o7 <<= pyrtl.one_hot_to_binary(0b00000000) - sim = pyrtl.Simulation() - sim.step({}) - self.assertEqual(sim.inspect(o7), 0) + sim.step({i: 0b00000000}) + self.assertEqual(sim.inspect('o'), 0) if __name__ == "__main__":