Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,5 @@ Encoders and Decoders
---------------------

.. autofunction:: pyrtl.helperfuncs.one_hot_to_binary
.. autofunction:: pyrtl.helperfuncs.binary_to_one_hot

1 change: 1 addition & 0 deletions pyrtl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .helperfuncs import wire_struct
from .helperfuncs import wire_matrix
from .helperfuncs import one_hot_to_binary
from .helperfuncs import binary_to_one_hot

from .corecircuits import and_all_bits
from .corecircuits import or_all_bits
Expand Down
42 changes: 41 additions & 1 deletion pyrtl/helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from .core import working_block, _NameIndexer, _get_debug_mode, Block
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select
from .corecircuits import (
as_wires,
rtl_all,
rtl_any,
concat,
concat_list,
select,
shift_left_logical
)

# -----------------------------------------------------------------
# ___ __ ___ __ __
Expand Down Expand Up @@ -1715,3 +1723,35 @@ def one_hot_to_binary(w) -> WireVector:
already_found = already_found | w[i]

return pos


def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector:
'''Takes an input representing a bit position and returns a WireVector
with that bit position set to 1 and the others to 0.

:param bit_position: WireVector, WireVector-like object, or something that can be converted
into a :py:class:`.Const` (in accordance with the :py:func:`.as_wires()`
required input). Example inputs: ``0b10``, ``0b1000``, ``4``.
:param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector.
:return: WireVector with the bit position given by the input set to 1 and all other bits
set to 0 (bit position 0 being the least significant bit).

If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1,
a ``0`` WireVector of size max_bitwidth will be returned.

Examples::

binary_to_onehot(0) # returns 0b01
binary_to_onehot(3) # returns 0b1000
binary_to_onehot(0b100) # returns 0b10000
'''

bit_position = as_wires(bit_position)

if max_bitwidth is not None:
bitwidth = max_bitwidth
else:
bitwidth = 2 ** len(bit_position)

# Need to dynamically set the appropriate bit position since bit_position may not be a Const
return shift_left_logical(Const(1, bitwidth=bitwidth), bit_position)
42 changes: 42 additions & 0 deletions tests/test_helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,5 +1814,47 @@ def test_no_ones(self):
self.assertEqual(sim.inspect('o'), 0)


class TestBinaryToOneHot(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()

def test_simple_binary_to_one_hot(self):
bit_position = pyrtl.Input(bitwidth=8, name='bit_position')
one_hot = pyrtl.Output(name='one_hot')
one_hot <<= pyrtl.binary_to_one_hot(bit_position)

self.assertEqual(one_hot.bitwidth, 256)

sim = pyrtl.Simulation()
sim.step({bit_position: 0})
self.assertEqual(sim.inspect('one_hot'), 0b01)
sim.step({bit_position: 2})
self.assertEqual(sim.inspect('one_hot'), 0b0100)
sim.step({bit_position: 5})
self.assertEqual(sim.inspect('one_hot'), 0b00100000)
sim.step({bit_position: 12})
self.assertEqual(sim.inspect('one_hot'), 0b0001000000000000)
sim.step({bit_position: 15})
self.assertEqual(sim.inspect('one_hot'), 0b1000000000000000)

# Tests with the max_bitwidth set
def test_with_max_bitwidth(self):
bit_position = pyrtl.Input(bitwidth=8, name='bit_position')
one_hot = pyrtl.Output(name='one_hot')
one_hot <<= pyrtl.binary_to_one_hot(bit_position, max_bitwidth=4)

self.assertEqual(one_hot.bitwidth, 4)

sim = pyrtl.Simulation()
sim.step({bit_position: 0})
self.assertEqual(sim.inspect('one_hot'), 0b0001)
sim.step({bit_position: 3})
self.assertEqual(sim.inspect('one_hot'), 0b1000)

# The max_bitwidth set is not enough for a bit position of 4
sim.step({bit_position: 4})
self.assertEqual(sim.inspect('one_hot'), 0b0000)


if __name__ == "__main__":
unittest.main()