diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index cdc3f2c2..821b9941 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -40,6 +40,7 @@ from .helperfuncs import find_and_print_loop from .helperfuncs import wire_struct from .helperfuncs import wire_matrix +from .helperfuncs import one_hot_to_binary from .corecircuits import and_all_bits from .corecircuits import or_all_bits diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 728dc2b3..f4d3ea6b 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -13,7 +13,7 @@ from .core import working_block, _NameIndexer, _get_debug_mode, Block from .pyrtlexceptions import PyrtlError, PyrtlInternalError from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector -from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list +from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select # ----------------------------------------------------------------- # ___ __ ___ __ __ @@ -1683,3 +1683,34 @@ def __len__(self): return len(self._components) return _WireMatrix + + +def one_hot_to_binary(w) -> WireVector: + '''Takes a one-hot input and returns the bit position of the high bit in binary. + + :param w: WireVector or a WireVector-like object or something that can be converted + into a Const (in accordance with the :py:func:`as_wires()` required input). Example + inputs: 0b0010, 64, 0b01. + :return: The bit position of the high bit in binary as a WireVector. + + If the input contains multiple 1s, the bit position of the first 1 will + be returned. If the input contains no 1s, 0 will be returned. + + Examples:: + + one_hot_to_binary(0b0010) # returns 1 + one_hot_to_binary(64) # returns 6 + one_hot_to_binary(0b1100) # returns 2, the bit position of the first 1 + one_hot_to_binary(0) # returns 0 + ''' + + w = as_wires(w) + + pos = 0 # Bit position of the first 1 + already_found = as_wires(False) # True if first 1 already found, False otherwise + + for i in range(len(w)): + pos = select(w[i] & ~already_found, i, pos) + already_found = already_found | w[i] + + return pos diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 4d9e75ef..972fb502 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -10,9 +10,9 @@ import pyrtl.helperfuncs from pyrtl.rtllib import testingutils as utils - # --------------------------------------------------------------- + class TestWireVectorList(unittest.TestCase): def setUp(self): pass @@ -1772,5 +1772,47 @@ def test_byte_matrix_input_concatenate(self): self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB) +class TestOneHotToBinary(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + + def test_simple_onehot(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=3, name='o') + o <<= pyrtl.one_hot_to_binary(i) + + sim = pyrtl.Simulation() + sim.step({i: 0b00000001}) + self.assertEqual(sim.inspect('o'), 0) + sim.step({i: 0b10000000}) + self.assertEqual(sim.inspect('o'), 7) + sim.step({i: 32}) + self.assertEqual(sim.inspect('o'), 5) + sim.step({i: 16}) + self.assertEqual(sim.inspect('o'), 4) + + def test_multiple_ones(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=3, name='o') + o <<= pyrtl.one_hot_to_binary(i) + + sim = pyrtl.Simulation() + sim.step({i: 0b00000101}) + self.assertEqual(sim.inspect('o'), 0) + sim.step({i: 0b11000000}) + self.assertEqual(sim.inspect('o'), 6) + sim.step({i: 0b10010010}) + self.assertEqual(sim.inspect('o'), 1) + + def test_no_ones(self): + i = pyrtl.Input(bitwidth=8, name='i') + o = pyrtl.Output(bitwidth=3, name='o') + o <<= pyrtl.one_hot_to_binary(i) + + sim = pyrtl.Simulation() + sim.step({i: 0b00000000}) + self.assertEqual(sim.inspect('o'), 0) + + if __name__ == "__main__": unittest.main()