diff --git a/docs/blocks.rst b/docs/blocks.rst index 4c0d5d6b..41b5b3ff 100644 --- a/docs/blocks.rst +++ b/docs/blocks.rst @@ -1,9 +1,13 @@ Block and Logic Nets ===================== -:class:`.Block` and :class:`.LogicNet` are lower level PyRTL abstractions. Most -users won't need to understand them, unless they are implementing -:ref:`analysis_and_optimization` passes or modifying PyRTL itself. +:class:`.Block` and :class:`.LogicNet` are lower level PyRTL abstractions for +representing a hardware design. Most users won't need to understand them, +unless they are implementing :ref:`analysis_and_optimization` passes or +modifying PyRTL itself. + +:ref:`gate_graphs` are an alternative representation that makes it easier to +write analysis passes. Blocks ------ @@ -34,3 +38,19 @@ LogicNets .. autoclass:: pyrtl.LogicNet :members: :undoc-members: + +.. _gate_graphs: + +GateGraphs +---------- + +.. automodule:: pyrtl.gate_graph + +.. autoclass:: pyrtl.Gate + :members: + :special-members: __str__ + +.. autoclass:: pyrtl.GateGraph + :members: + :special-members: __init__, __str__ + diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index 375495b9..2b75673a 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -19,6 +19,8 @@ # convenience classes for building hardware from .wire import WireVector, Input, Output, Const, Register +from .gate_graph import GateGraph, Gate + # helper functions from .helperfuncs import ( input_list, @@ -160,6 +162,9 @@ "Output", "Const", "Register", + # gate_graph + "GateGraph", + "Gate", # helperfuncs "input_list", "output_list", diff --git a/pyrtl/core.py b/pyrtl/core.py index 233468f7..7d73ba5d 100644 --- a/pyrtl/core.py +++ b/pyrtl/core.py @@ -581,6 +581,10 @@ def net_connections( This information helps when building a graph representation for the ``Block``. See :func:`net_graph` for an example. + .. note:: + + Consider using :ref:`gate_graphs` instead. + :param include_virtual_nodes: If ``True``, external `sources` (such as an :class:`Inputs` and :class:`Consts`) will be represented as wires that set themselves, and external `sinks` (such as diff --git a/pyrtl/gate_graph.py b/pyrtl/gate_graph.py new file mode 100644 index 00000000..8de6feef --- /dev/null +++ b/pyrtl/gate_graph.py @@ -0,0 +1,1116 @@ +""":class:`GateGraph` is an alternative representation for PyRTL logic. + +.. _gate_motivation: + +Motivation +---------- + +.. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + +PyRTL represents logic internally with :class:`WireVectors<.WireVector>` and +:class:`LogicNets<.LogicNet>`. For example, the following code creates five +:class:`WireVectors<.WireVector>` and two :class:`LogicNets<.LogicNet>`:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> c = pyrtl.Input(name="c", bitwidth=1) + + >>> x = a & b + >>> x.name = "x" + + >>> y = x | c + >>> y.name = "y" + + >>> print(pyrtl.working_block()) + x/1W <-- & -- a/1I, b/1I + y/1W <-- | -- x/1W, c/1I + +The :class:`WireVectors<.WireVector>` and :class:`LogicNets<.LogicNet>` are arranged +like this:: + + ┌──────────────┐ + │ LogicNet "&" │ + │ op: "&" │ ┌────────────────┐ + │ args:────┼───▶│ WireVector "a" │ + │ args:────┼─┐ └────────────────┘ + │ │ │ ┌────────────────┐ + │ │ └─▶│ WireVector "b" │ + │ │ └────────────────┘ + │ │ ┌────────────────┐ + │ dests:───┼───▶│ WireVector "x" │ + └──────────────┘ ┌─▶└────────────────┘ + ┌──────────────┐ │ + │ LogicNet "|" │ │ + │ op: "|" │ │ + │ args:────┼─┘ ┌────────────────┐ + │ args:────┼───▶│ WireVector "c" │ + │ │ └────────────────┘ + │ │ ┌────────────────┐ + │ dests:───┼───▶│ WireVector "y" │ + └──────────────┘ └────────────────┘ + +This data structure is difficult to work with for three reasons: + +1. The arrows do not consistently point from producer to consumer, or from consumer to + producer. For example, there is no arrow from :class:`.WireVector` ``x`` (producer) + to :class:`.LogicNet` ``|`` (consumer). Similarly, there is no arrow from + :class:`.WireVector` ``x`` (consumer) to :class:`.LogicNet` ``&`` (producer). These + missing arrows make it impossible to iteratively traverse the data structure. This + creates a need for methods like :meth:`~.Block.net_connections`, which creates + ``wire_src_dict`` and ``wire_sink_dict`` with the missing pointers. + +2. The data structure is composed of two different classes, :class:`.LogicNet` and + :class:`.WireVector`, and these two classes have completely different interfaces. As + we follow pointers from one class to another, we must keep track of the current + object's class, and interact with it appropriately. + +3. :class:`.WireVector` is part of PyRTL's user interface, but also a key part of + PyRTL's internal representation. This makes :class:`.WireVector` complex and + difficult to modify, because it must implement user-facing features like inferring + bitwidth from assignment, while also maintaining a consistent internal representation + for simulation, analysis, and optimization. + +:class:`GateGraph` is an alternative representation that addresses these issues. A +:class:`GateGraph` is just a collection of :class:`Gates`, so we'll cover +:class:`Gate` first. +""" + +from __future__ import annotations + +from pyrtl.core import Block, LogicNet, working_block +from pyrtl.pyrtlexceptions import PyrtlError +from pyrtl.wire import Const, Input, Register, WireVector + + +class Gate: + """:class:`Gate` is an alternative to PyRTL's default :class:`.LogicNet` and + :class:`.WireVector` representation. + + :class:`Gate` makes it easy to iteratively explore a circuit, while simplifying the + circuit's representation by making everything a :class:`Gate`. A :class:`Gate` is + equivalent to a :class:`.LogicNet` fused with its :attr:`dest<.LogicNet.dests>` + :class:`WireVector`. So this :class:`.LogicNet` and :class:`.WireVector`:: + + ┌──────────────────┐ + │ LogicNet │ ┌───────────────────┐ + │ op: o │ │ WireVector │ + │ args: [x, y] │ │ name: n │ + │ dests:───────┼───▶│ bitwidth: b │ + └──────────────────┘ └───────────────────┘ + + Are equivalent to this :class:`Gate`:: + + ┌─────────────────────┐ + │ Gate │ + │ op: o │ + │ args: [x, y] │ + │ name: n │ + │ bitwidth: b │ + │ dests: [g1, g2] │ + └─────────────────────┘ + + Key differences between the two representations: + + 1. The :class:`Gate`'s :attr:`~Gate.args` ``[x, y]`` are references to other + :class:`Gates`. + + 2. The :class:`.WireVector`'s :attr:`~WireVector.name` and + :attr:`~WireVector.bitwidth` are stored as :attr:`Gate.name` and + :attr:`Gate.bitwidth`. If the :class:`.LogicNet` produces no output, like a + :class:`.MemBlock` write, the :class:`Gate`'s :attr:`~Gate.name` and + :attr:`~Gate.bitwidth` will be ``None``. PyRTL does not have an + :attr:`~.LogicNet.op` with multiple :attr:`.LogicNet.dests`. + + 3. The :class:`Gate` has a new :attr:`Gate.dests` attribute, which has no direct + equivalent in the :class:`.LogicNet`/:class:`.WireVector` representation. + :attr:`Gate.dests` is a list of the :class:`Gates` that use this + :class:`Gate`'s output as one of their :attr:`~Gate.args`. + + :attr:`.LogicNet.dests` and :attr:`Gate.dests` represent slightly different things, + despite having similar names: + + - :attr:`.LogicNet.dests` represents the :class:`LogicNet`'s output wire. It is a + list of :class:`WireVectors<.WireVector>` which hold the :class:`.LogicNet`'s + output. There can only be zero or one :class:`WireVectors<.WireVector>` in + :attr:`.LogicNet.dests`, but that :class:`.WireVector` can be an + :attr:`arg<.LogicNet.args>` to any number of :class:`LogicNets<.LogicNet>`. + + - :attr:`Gate.dests` represents the :class:`Gate`'s users. It is a list of + :class:`Gates` that use the :class:`Gate`'s output as one of their + :attr:`~Gate.args`. There can be any number of :class:`Gates` in + :attr:`Gate.dests`. + + With :class:`Gates`, the example from the :ref:`gate_motivation` section looks + like:: + + ┌─────────────────┐ + │ Gate "a" │ + │ op: "I" │ + │ name: "a" │ + │ bitwidth: 1 │ ┌─────────────────┐ + │ dests:──────┼───▶│ Gate "&" │ + └─────────────────┘◀─┐ │ op: "&" │ + ┌─────────────────┐ └─┼─────args │ + │ Gate "b" │◀───┼─────args │ + │ op: "I" │ │ name: "x" │ ┌─────────────────┐ + │ name: "b" │ │ bitwidth: 1 │ │ Gate "|" │ + │ bitwidth: 1 │ ┌─▶│ dests:──────┼───▶│ op: "|" │ + │ dests:──────┼─┘ └─────────────────┘◀───┼─────args │ + └─────────────────┘ ┌────────────────────────┼─────args │ + ┌─────────────────┐ │ │ name: "y" │ + │ Gate "c" │◀─┘ ┌───────────────────▶│ bitwidth: 1 │ + │ op: "I" │ │ └─────────────────┘ + │ name: "c" │ │ + │ bitwidth: 1 │ │ + │ dests:──────┼──────┘ + └─────────────────┘ + + With a :class:`Gate` representation, it is easy to iteratively traverse the data + structure: + + 1. Forwards, from producer to consumer, by following :attr:`~Gate.dests` references. + + 2. Backwards, from consumer to producer, by following :attr:`~Gate.args` references. + + The :class:`Gate` representation addresses the issues raised in the + :ref:`gate_motivation` section: + + 1. The :class:`Gate` representation is easy to iteratively explore by following + :attr:`~Gate.args` and :attr:`~Gate.dests` references, which are shown as arrows + in the figure above. + + 2. There is only one class in the :class:`Gate` graph, so we don't need to keep + track of the current object's type as we follow arrows in the graph, like we did + with :class:`LogicNet` and :class:`WireVector`. Everything is a :class:`Gate`. + + 3. By decoupling the :class:`Gate` representation from :class:`.WireVector` and + :class:`.LogicNet`, :class:`Gate` specializes in supporting analysis use cases, + without the burden of supporting all of :class:`.WireVector`'s other features. + This significantly simplifies :class:`Gate`'s design and implementation. + + For usage examples, see :class:`GateGraph` and :class:`Gate`'s documentation below. + """ + + op: str + """Operation performed by this ``Gate``. Corresponds to :attr:`.LogicNet.op`. + + For special ``Gates`` created for :class:`Inputs<.Input>`, ``op`` will instead be + the :class:`.Input`'s ``_code``, which is ``I``. + + For special ``Gates`` created for :class:`Consts<.Const>`, ``op`` will instead be + the :class:`.Const`'s ``_code``, which is ``C``. + + See :class:`.LogicNet`'s documentation for a description of all other ``ops``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> _ = ~a + + >>> gate_graph = pyrtl.GateGraph() + >>> gate_a = gate_graph.get_gate("a") + >>> gate_a.op + 'I' + >>> gate_a.dests[0].op + '~' + """ + + op_param: tuple + """Static parameters for the operation. Corresponds to :attr:`.LogicNet.op_param`. + + These are constant parameters, whose values are statically known. These values + generally do not appear as actual values on wires. For example, the bits to select + for the ``s`` bit-slice operation are stored as ``op_params``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=8) + >>> bit_slice = a[1:3] + >>> bit_slice.name = "bit_slice" + + >>> gate_graph = pyrtl.GateGraph() + >>> bit_slice_gate = gate_graph.get_gate("bit_slice") + >>> bit_slice_gate.op_param + (1, 2) + """ + + args: list[Gate] + """Inputs to the operation. Corresponds to :attr:`.LogicNet.args`. + + For each ``Gate`` ``arg`` in ``self.args``, ``self`` is in ``arg.dests``. + + Some special ``Gates`` represent operations without ``args``, like :class:`.Input` + and :class:`.Const`. Such operations will have an empty list of ``args``. + + .. note:: + + The same ``Gate`` may appear multiple times in ``args``. A self-loop + :class:`.Register` ``Gate`` may be its own ``arg``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> c = pyrtl.Input(name="c", bitwidth=1) + >>> abc = pyrtl.concat(a, b, c) + >>> abc.name = "abc" + + >>> gate_graph = pyrtl.GateGraph() + >>> abc_gate = gate_graph.get_gate("abc") + >>> [gate.name for gate in abc_gate.args] + ['a', 'b', 'c'] + """ + + name: str | None + """Name of the operation's output. Corresponds to :attr:`.WireVector.name`. + + Some operations do not have outputs, like :class:`.MemBlock` writes. These + operations will have a ``name`` of ``None``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> ab = a + b + >>> ab.name = "ab" + + >>> gate_graph = pyrtl.GateGraph() + >>> ab_gate = gate_graph.get_gate("ab") + >>> ab_gate.name + 'ab' + """ + + bitwidth: int | None + """Bitwidth of the operation's output. Corresponds to :attr:`.WireVector.bitwidth`. + + Some operations do not have outputs, like :class:`.MemBlock` writes. These + operations will have a ``bitwidth`` of ``None``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> ab = a + b + >>> ab.name = "ab" + + >>> gate_graph = pyrtl.GateGraph() + >>> ab_gate = gate_graph.get_gate("ab") + >>> ab_gate.bitwidth + 2 + """ + + dests: list[Gate] + """:class:`list` of :class:`Gates` that use this operation's output as one of + their :attr:`~Gate.args`. + + For each :class:`Gate` ``dest`` in ``self.dests``, ``self`` is in ``dest.args``. + + .. note:: + + The same :class:`Gate` may appear multiple times in ``dests``. A self-loop + :class:`.Register` ``Gate`` may appear in its own ``dests``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> _ = a + 1 + >>> _ = a - 1 + + >>> gate_graph = pyrtl.GateGraph() + >>> a_gate = gate_graph.get_gate("a") + >>> [gate.op for gate in a_gate.dests] + ['+', '-'] + """ + + is_output: bool + """Indicates if the operation's output is an :class:`.Output`. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Output(name="b", bitwidth=1) + >>> b <<= a + + >>> gate_graph = pyrtl.GateGraph() + >>> a_gate = gate_graph.get_gate("a") + >>> a_gate.is_output + False + >>> b_gate = gate_graph.get_gate("b") + >>> b_gate.is_output + True + """ + + def __init__( + self, + logic_net: LogicNet = None, + wire_vector: WireVector = None, + args: list[Gate] | None = None, + ): + """Create a ``Gate`` from a :class:`.LogicNet` or :class:`.WireVector`. + + ``Gates`` are complicated to construct because they are doubly-linked, and there + may be cycles in the ``Gate`` graph. Most users should not call this constructor + directly, and instead use :class:`GateGraph` to create ``Gates`` from a + :class:`.Block`. + + :param logic_net: :class:`.LogicNet` to create this ``Gate`` from. If + ``logic_net`` is specified, ``wire_vector`` must be ``None``. + + ``logic_net`` must not be a register, where ``logic_net.op == 'r'``. + + Register ``Gates`` are created in two phases by :class:`GateGraph`. In the + first phase, a placeholder ``Gate`` is created from the :class:`.Register`. + In this first phase, the register ``Gate``'s ``op`` is temporarily set to + ``R``, which is the :class:`.Register`'s ``_code``. This placeholder is + needed to resolve other ``Gate``'s references to the register in the second + phase. In the second phase, the register ``Gate``'s remaining fields are + populated from the register's :class:`.LogicNet`. In the second phase, the + register ``Gate``'s ``op`` is changed to ``r``, which is the + :class:`.LogicNet`'s :attr:`~.LogicNet.op`. + + :param wire_vector: :class:`.WireVector` to create this ``Gate`` from. If + ``wire_vector`` is specified, ``logic_net`` must be ``None``. + + ``wire_vector`` must be a :class:`.Const`, :class:`.Input`, or + :class:`.Register`. + + :param args: A :class:`list` of ``Gates`` that are inputs to this ``Gate``. This + corresponds to :attr:`.LogicNet.args`, except that each of a ``Gate``'s + ``args`` is a ``Gate``. + """ + self.op_param = None + if args is None: + self.args = [] + else: + self.args = args + self.name = None + self.bitwidth = None + # ``dests`` will be set up later, by ``GateGraph``. + self.dests = [] + self.is_output = False + + if logic_net is not None: + # Constructing a ``Gate`` from a ``logic_net``. + # + # For ``LogicNets``, set the ``Gate``'s ``op``, ``op_param``, ``name``, + # ``bitwidth``. + if wire_vector is not None: + msg = "Do not pass both logic_net and wire_vector to Gate." + raise PyrtlError(msg) + self.op = logic_net.op + if self.op == "r": + msg = "Registers should be created from a wire_vector, not a logic_net." + raise PyrtlError(msg) + self.op_param = logic_net.op_param + + num_dests = len(logic_net.dests) + if num_dests: + if num_dests > 1: + # The ``Gate`` representation supports at most one ``LogicNet`` + # ``dest``. If more than one ``LogicNet`` ``dest`` is needed in the + # future, concat them together, then split them apart with ``s`` + # bit-selection ``Gates``, or use multiple ``Gates`` with the same + # ``args``. + msg = "LogicNets with more than one dest are not supported" + raise PyrtlError(msg) + dest = logic_net.dests[0] + self.name = dest.name + self.bitwidth = dest.bitwidth + if dest._code == "O": + self.is_output = True + + else: + # Constructing a ``Gate`` from a ``wire_vector``. + # + # For ``Inputs`` and ``Registers``, set the ``Gate``'s ``op`` and ``dest``. + # For ``Consts``, also copy the ``val`` to ``op_param``. + # For ``Registers``, also copy the ``reset_value`` to ``op_param``. + if wire_vector is None: + msg = "Gate must be constructed from a logic_net or a wire_vector." + raise PyrtlError(msg) + + if wire_vector._code not in "CIR": + msg = ( + "Gate must be constructed from a Const, Input or Register " + "wire_vector." + ) + raise PyrtlError(msg) + + self.op = wire_vector._code + self.name = wire_vector.name + self.bitwidth = wire_vector.bitwidth + if self.op == "C": + self.op_param = (wire_vector.val,) + elif self.op == "R": + if not wire_vector.reset_value: + self.op_param = (0,) + else: + self.op_param = (wire_vector.reset_value,) + + def __str__(self) -> str: + """:return: A string representation of this ``Gate``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=8) + >>> bit_slice = a[2:4] + >>> bit_slice.name = "bit_slice" + + >>> gate_graph = pyrtl.GateGraph() + >>> bit_slice_gate = gate_graph.get_gate("bit_slice") + + >>> print(bit_slice_gate) + bit_slice/2 = slice(a/8) [sel=(2, 3)] + + In this sample string representation: + + - :attr:`~Gate.name` is ``bit_slice``. + + - :attr:`~Gate.bitwidth` is ``2``. + + - :attr:`~Gate.op` is ``s``, spelled out as ``slice`` to improve readability. + + - :attr:`~Gate.args` is ``[]``:: + + >>> bit_slice_gate.args[0] is gate_graph.get_gate("a") + True + + - :attr:`~Gate.op_param` is ``(2, 3)``, written as ``sel`` + because a ``slice``'s :attr:`~Gate.op_param` determines the selected bits. + This improves readability by indicating what the :attr:`~Gate.op_param` means + for the :attr:`~Gate.op`. + """ + if self.name is None: + dest = "" + else: + dest_notes = "" + if self.is_output: + dest_notes = " [Output]" + + dest = f"{self.name}/{self.bitwidth}{dest_notes} " + + op_name_map = { + "&": "and", + "|": "or", + "^": "xor", + "n": "nand", + "~": "invert", + "+": "add", + "-": "sub", + "*": "mul", + "=": "eq", + "<": "lt", + ">": "gt", + "w": "", + "x": "", # Multiplexers are printed as ternary operators. + "c": "concat", + "s": "slice", + "r": "reg", + "m": "read", + "@": "write", + "I": "Input", + "C": "Const", + } + if self.name is None: + op = op_name_map[self.op] + else: + op = f"= {op_name_map[self.op]}" + + if not self.args: + args = "" + else: + arg_names = [f"{arg.name}/{arg.bitwidth}" for arg in self.args] + if self.op == "w": + args = arg_names[0] + elif self.op == "x": + args = f"{arg_names[0]} ? {arg_names[2]} : {arg_names[1]}" + elif self.op == "m": + args = f"(addr={arg_names[0]})" + elif self.op == "@": + args = ( + f"(addr={arg_names[0]}, data={arg_names[1]}, enable={arg_names[2]})" + ) + else: + args = f"({', '.join(arg_names)})" + + if self.op_param is None: + op_param = "" + elif self.op == "C": + op_param = f"({self.op_param[0]})" + elif self.op == "s": + op_param = f" [sel={self.op_param}]" + elif self.op == "m" or self.op == "@": + op_param = f" [memid={self.op_param[0]} mem={self.op_param[1].name}]" + elif self.op == "r": + op_param = f" [reset_value={self.op_param[0]}]" + else: + op_param = f" [op_param={self.op_param}]" + + return f"{dest}{op}{args}{op_param}" + + +class GateGraph: + """A :class:`GateGraph` is a collection of :class:`Gates`. + :class:`GateGraph`'s constructor creates :class:`Gates` from a + :class:`.Block`. + + See :ref:`gate_motivation` for more background. + + Users should generally construct :class:`GateGraphs`, rather than + attempting to directly construct individual :class:`Gates`. :class:`Gate` + construction is complex because they are doubly-linked, and the :class:`Gate` graph + may contain cycles. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example + ------- + + Let's build a :class:`GateGraph` for the :ref:`gate_motivation` example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> c = pyrtl.Input(name="c", bitwidth=1) + + >>> x = a & b + >>> x.name = "x" + + >>> y = x | c + >>> y.name = "y" + + >>> gate_graph = pyrtl.GateGraph() + + The :class:`GateGraph` can be printed, revealing five :class:`Gates`:: + + >>> print(gate_graph) + a/1 = Input + b/1 = Input + c/1 = Input + x/1 = and(a/1, b/1) + y/1 = or(x/1, c/1) + + We can retrieve the :attr:`Gate` for input ``a``:: + + >>> a = gate_graph.get_gate("a") + >>> print(a) + a/1 = Input + >>> a.name + 'a' + >>> a.op + 'I' + + We can check ``a``'s :attr:`~Gate.dests` to see that it is an argument to a bitwise + ``&`` operation, with output named ``x``:: + + >>> len(a.dests) + 1 + >>> x = a.dests[0] + >>> print(x) + x/1 = and(a/1, b/1) + >>> x.op + '&' + >>> x.name + 'x' + + We can examine the bitwise ``&``'s :attr:`~Gate.args`, to get references to input + :class:`Gates` ``a`` and ``b``:: + + >>> x.args[0] is a + True + + >>> b = x.args[1] + >>> print(b) + b/1 = Input + >>> b.name + 'b' + >>> b.op + 'I' + + Special :class:`Gates` + ---------------------------- + + Generally, :class:`GateGraph` converts each :class:`.LogicNet` in a :class:`.Block` + to a corresponding :class:`Gate`, but some :class:`WireVectors<.WireVector>` and + :class:`LogicNets<.LogicNet>` are handled differently: + + - An :class:`.Input` :class:`.WireVector` is converted to a special input + :class:`Gate`, with op ``I``. Input :class:`Gates` have no + :attr:`~Gate.args`, and do not correspond to a :class:`.LogicNet`. + + - A :class:`.Const` :class:`.WireVector` is converted to a special const + :class:`Gate`, with op ``C``. Const :class:`Gates` have no + :attr:`~Gate.args`, and do not correspond to a :class:`.LogicNet`. The constant's + value is stored in :attr:`Gate.op_param`. + + - An :class:`.Output` :class:`.WireVector` is handled normally, and will be the + ``dest`` of the :class:`Gate` that defines the :class:`.Output`'s value. That + :class:`Gate` will have its :attr:`~Gate.is_output` attribute set to ``True``. + + - :class:`.Register` :class:`WireVectors<.WireVector>` and + :class:`LogicNets<.LogicNet>` are handled normally, except that the + :class:`.Register`'s ``reset_value`` is stored in :attr:`Gate.op_param`. Register + :class:`Gates` use the register :class:`.LogicNet` :attr:`~.LogicNet.op` + ``r``, not the :class:`.Register` ``_code`` ``R``. + + .. note:: + + Registers can create cycles in the :class:`Gate` graph, because the logic that + defines the register's :attr:`~.Register.next` value (which is the register + :class:`Gate`'s :attr:`~Gate.args`) can depend on the register's current value + (which is the register :class:`Gate`'s :attr:`~Gate.dests`). Watch out for + infinite loops when traversing a :class:`GateGraph` with registers. For example, + if you keep following :attr:`~Gate.dests` references, you may end up back where + you started. + """ + + gates: set[Gate] + """A :class:`set` of all :class:`Gates` in the ``GateGraph``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> c = pyrtl.Input(name="c", bitwidth=1) + >>> x = a & b + >>> x.name = "x" + >>> y = x | c + >>> y.name = "y" + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.gates) + ['a', 'b', 'c', 'x', 'y'] + """ + + consts: set[Gate] + """A :class:`set` of :class:`.Const` :class:`Gates` in the ``GateGraph``. + + These :class:`Gates` provide constant values, with :attr:`~Gate.op` ``C``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> c = pyrtl.Const(name="c", val=0) + >>> d = pyrtl.Const(name="d", val=1) + >>> _ = c + d + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.consts) + ['c', 'd'] + """ + + inputs: set[Gate] + """A :class:`set` of :class:`.Input` :class:`Gates` in the ``GateGraph``. + + These :class:`Gates` provide :class:`.Input` values, with :attr:`~Gate.op` + ``I``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> b = pyrtl.Input(name="b", bitwidth=1) + >>> _ = a & b + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.inputs) + ['a', 'b'] + """ + + outputs: set[Gate] + """A :class:`set` of :class:`.Output` :class:`Gates` in the ``GateGraph``. + + These :class:`Gates` set :class:`.Output` values, with :attr:`~Gate.is_output` + ``True``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> x = pyrtl.Output(name="x") + >>> y = pyrtl.Output(name="y") + >>> x <<= 42 + >>> y <<= 255 + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.outputs) + ['x', 'y'] + """ + + registers: set[Gate] + """A :class:`set` of :class:`.Register` update :class:`Gates` in the + ``GateGraph``. + + These :class:`Gates` set a :class:`.Register`'s value for the next cycle, with + :attr:`~Gate.op` ``r``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> r = pyrtl.Register(name="r", bitwidth=1) + >>> s = pyrtl.Register(name="s", bitwidth=1) + >>> r.next <<= r + 1 + >>> s.next <<= s + 2 + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.registers) + ['r', 's'] + """ + + mem_reads: set[Gate] + """A :class:`set` of :class:`.MemBlock` read :class:`Gates` in the + ``GateGraph``. + + These :class:`Gates` read :class:`MemBlocks<.MemBlock>`, with + :attr:`~Gate.op` ``m``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> mem = pyrtl.MemBlock(name="mem", bitwidth=4, addrwidth=2) + >>> addr = pyrtl.Input(name="addr", bitwidth=2) + >>> mem_read_1 = mem[addr] + >>> mem_read_1.name = "mem_read_1" + >>> mem_read_2 = mem[addr] + >>> mem_read_2.name = "mem_read_2" + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.reads) + ['mem_read_1', 'mem_read_2'] + """ + + mem_writes: set[Gate] + """A :class:`set` of :class:`.MemBlock` write :class:`Gates` in the + ``GateGraph``. + + These :class:`Gates` write :class:`MemBlocks<.MemBlock>`, with + :attr:`~Gate.op` ``@``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> mem = pyrtl.MemBlock(name="mem", bitwidth=4, addrwidth=2) + >>> addr = pyrtl.Input(name="addr", bitwidth=2) + >>> mem[addr] <<= 7 + + >>> gate_graph = pyrtl.GateGraph() + + >>> # MemBlock writes have no name. + >>> [gate.name for gate in gate_graph.mem_writes] + [None] + + >>> [gate.op for gate in gate_graph.mem_writes] + ['@'] + """ + + sources: set[Gate] + """A :class:`set` of ``source`` :class:`Gates` in the ``GateGraph``. + + A ``source`` :class:`Gate`'s output value is known at the beginning of each clock + cycle. :class:`Consts<.Const>`, :class:`Inputs<.Input>`, and + :class:`Registers<.Register>` are ``source`` :class:`Gates`. + + .. note:: + + :class:`Registers<.Register>` are both ``sources`` and :attr:`~GateGraph.sinks`. + As a ``source``, it provides the :class:`.Register`'s value for the current + cycle. As a :attr:`sink`, it determines the + :class:`.Register`'s value for the next cycle. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> c = pyrtl.Const(name="c", bitwidth=1, val=0) + >>> r = pyrtl.Register(name="r", bitwidth=1) + >>> r.next <<= a + c + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.sources) + ['a', 'c', 'r'] + """ + + sinks: set[Gate] + """A :class:`set` of ``sink`` :class:`Gates` in the ``GateGraph``. + + A ``sink`` :class:`Gate`'s output value is known only at the end of each clock + cycle. :class:`Registers<.Register>`, :class:`Outputs<.Output>`, :class:`MemBlock` + writes, and any :class:`Gate` without users (``len(dests) == 0``) are sink + :class:`Gates`. + + .. note:: + + :class:`Registers<.Register>` are both :attr:`~GateGraph.sources` and ``sinks``. + As a :attr:`source`, it provides the :class:`.Register`'s + value for the current cycle. As a ``sink``, it determines the + :class:`.Register`'s value for the next cycle. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> r = pyrtl.Register(name="r", bitwidth=1) + >>> o = pyrtl.Output(name="o", bitwidth=1) + >>> r.next <<= a + 1 + >>> o <<= 1 + >>> sum = a + r + >>> sum.name = "sum" + + >>> gate_graph = pyrtl.GateGraph() + + >>> sorted(gate.name for gate in gate_graph.sinks) + ['o', 'r', 'sum'] + """ + + def __init__(self, block: Block = None): + """Create :class:`Gates` from a :class:`.Block`. + + Most users should call this constructor, rather than attempting to directly + construct individual :class:`Gates`. + + :param block: :class:`.Block` to construct the :class:`GateGraph` from. Defaults + to the :ref:`working_block`. + """ + self.gates = set() + self.consts = set() + self.inputs = set() + self.outputs = set() + self.registers = set() + self.mem_reads = set() + self.mem_writes = set() + self.sources = set() + self.sinks = set() + + block = working_block(block) + block.sanity_check() + + # The ``Gate`` graph is doubly-linked, and may contain cycles, so construction + # is done in two phases. In the first phase, we only construct ``Gates`` for + # ``sources``, which are ``Consts``, ``Inputs``, and ``Registers``. + # + # In this phase, register gates are placeholders. They do not have ``args``, and + # their ``op`` is temporarily ``R``, which is ``Register._code`. These + # placeholders are needed to resolve references to registers in the second + # phase. + # + # ``wire_vector_map`` maps from ``WireVector`` to the corresponding gate. It is + # initially populated with ``Gates`` constructed from ``sources``. + wire_vector_map: dict[WireVector, Gate] = {} + for wire_vector in block.wirevector_subset((Const, Input, Register)): + gate = Gate(wire_vector=wire_vector) + self.gates.add(gate) + self.sources.add(gate) + wire_vector_map[wire_vector] = gate + + if gate.op == "C": + self.consts.add(gate) + elif gate.op == "I": + self.inputs.add(gate) + elif gate.op == "R": + self.registers.add(gate) + + # In the second phase, we construct all remaining ``Gates`` from ``LogicNets``. + # ``Block``'s iterator returns ``LogicNets`` in topological order, so we can be + # sure that each ``LogicNet``'s ``args`` are all in ``wire_vector_map``. + for logic_net in block: + # Find the ``Gates`` corresponding to the ``LogicNet``'s ``args``. + gate_args = [] + for wire_arg in logic_net.args: + gate_arg = wire_vector_map.get(wire_arg) + if gate_arg is None: + msg = f"Missing Gate for wire {wire_arg}" + raise PyrtlError(msg) + gate_args.append(gate_arg) + if logic_net.op == "r": + # Find the placeholder register ``Gate`` we created earlier, and finish + # constructing it. + gate = wire_vector_map[logic_net.dests[0]] + gate.op = "r" + gate.args = gate_args + self.sinks.add(gate) + else: + gate = Gate(logic_net=logic_net, args=gate_args) + self.gates.add(gate) + + # Add the new ``Gate`` as a ``dest`` for its ``args``. + for gate_arg in gate_args: + gate_arg.dests.append(gate) + + # Add the new ``Gate`` to ``wire_vector_map``, so we can resolve future + # references to it. + num_dests = len(logic_net.dests) + if num_dests: + if num_dests > 1: + msg = "LogicNets with more than one dest are not supported" + raise PyrtlError(msg) + dest = logic_net.dests[0] + wire_vector_map[dest] = gate + + if gate.is_output: + self.outputs.add(gate) + if gate.op == "m": + self.mem_reads.add(gate) + elif gate.op == "@": + self.mem_writes.add(gate) + + for gate in self.gates: + if len(gate.dests) == 0: + self.sinks.add(gate) + + def get_gate(self, name: str) -> Gate | None: + """Return the :class:`Gate` whose :attr:`~Gate.name` is ``name``, or ``None`` if + no such :class:`Gate` exists. + + .. warning:: + + :class:`.MemBlock` writes do not produce an output, so they can not be + retrieved with ``get_gate``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=1) + >>> na = ~a + >>> na.name = "na" + + >>> gate_graph = pyrtl.GateGraph() + + >>> a_gate = gate_graph.get_gate("a") + >>> na_gate = gate_graph.get_gate("na") + >>> na_gate.op + '~' + >>> na_gate.args[0] is a_gate + True + + :param name: Name of the :class:`Gate` to find. + + :return: The named :class:`Gate`, or ``None`` if no such :class:`Gate` was + found. + """ + for gate in self.gates: + if gate.name == name: + return gate + return None + + def __str__(self) -> str: + """Return a string representation of the ``GateGraph``. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> a = pyrtl.Input(name="a", bitwidth=2) + >>> b = pyrtl.Input(name="b", bitwidth=2) + >>> sum = a + b + >>> sum.name = "sum" + + >>> gate_graph = pyrtl.GateGraph() + + >>> print(gate_graph) + a/2 = Input + b/2 = Input + sum/3 = add(a/2, b/2) + + :return: A string representation of each :class:`Gate` in the ``GateGraph``, one + :class:`Gate` per line. The :class:`Gates` will be sorted by + name. + """ + sorted_gates = sorted( + self.gates, key=lambda gate: gate.name if gate.name else "~~~" + ) + return "\n".join([str(gate) for gate in sorted_gates]) diff --git a/pyrtl/visualization.py b/pyrtl/visualization.py index c919339e..556ec446 100644 --- a/pyrtl/visualization.py +++ b/pyrtl/visualization.py @@ -40,6 +40,10 @@ def net_graph(block: Block = None, split_state: bool = False): :class:`WireVectors` that are not connected to any nets are not returned as part of the graph. + .. note:: + + Consider using :ref:`gate_graphs` instead. + :param block: :class:`Block` to use (defaults to current :ref:`working_block`). :param split_state: If ``True``, split connections to/from a register update net; this means that registers will be appear as source nodes of the network, and diff --git a/tests/test_gate_graph.py b/tests/test_gate_graph.py new file mode 100644 index 00000000..82bcab4a --- /dev/null +++ b/tests/test_gate_graph.py @@ -0,0 +1,309 @@ +import doctest +import unittest + +import pyrtl + + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.gate_graph) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + + +class TestGateGraph(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + + def test_gate_retrieval(self): + a = pyrtl.Input(name="a", bitwidth=1) + b = pyrtl.Input(name="b", bitwidth=1) + c = pyrtl.Input(name="c", bitwidth=2) + ab = a + b + ab.name = "ab" + abc = ab - c + abc.name = "abc" + + gate_graph = pyrtl.GateGraph() + + self.assertEqual( + sorted([gate.name for gate in gate_graph.gates]), + ["a", "ab", "abc", "b", "c"], + ) + + self.assertEqual( + sorted([gate.name for gate in gate_graph.sources]), ["a", "b", "c"] + ) + + self.assertEqual(sorted([gate.name for gate in gate_graph.sinks]), ["abc"]) + + gate_ab = gate_graph.get_gate("ab") + self.assertEqual(gate_ab.name, "ab") + self.assertEqual(gate_ab.op, "+") + + gate_abc = gate_graph.get_gate("abc") + self.assertEqual(gate_abc.name, "abc") + self.assertEqual(gate_abc.op, "-") + + def test_get_gate(self): + _ = pyrtl.Input(name="a", bitwidth=4) + + gate_graph = pyrtl.GateGraph() + a_gate = gate_graph.get_gate("a") + self.assertEqual(a_gate.name, "a") + + self.assertEqual(gate_graph.get_gate("q"), None) + + def test_select_gate(self): + a = pyrtl.Input(name="a", bitwidth=4) + b = pyrtl.Input(name="b", bitwidth=4) + s = pyrtl.Input(name="s", bitwidth=1) + + output = pyrtl.select(s, a, b) + output.name = "output" + + gate_graph = pyrtl.GateGraph() + select_gate = gate_graph.get_gate("output") + self.assertEqual(select_gate.op, "x") + self.assertEqual(select_gate.op_param, None) + a_gate = gate_graph.get_gate("a") + b_gate = gate_graph.get_gate("b") + s_gate = gate_graph.get_gate("s") + self.assertEqual(select_gate.args, [s_gate, b_gate, a_gate]) + self.assertEqual(select_gate.name, "output") + self.assertEqual(select_gate.bitwidth, 4) + self.assertEqual(select_gate.dests, []) + self.assertEqual(str(select_gate), "output/4 = s/1 ? a/4 : b/4") + + def test_gate_attrs(self): + a = pyrtl.Input(name="a", bitwidth=4) + b = pyrtl.Const(name="b", bitwidth=2, val=1) + bit_slice = a[2:4] + bit_slice.name = "bit_slice" + ab = bit_slice + b + ab.name = "ab" + + output = pyrtl.Output(name="output", bitwidth=3) + bb = b + b + bb.name = "bb" + output <<= bb + + gate_graph = pyrtl.GateGraph() + + a_gate = gate_graph.get_gate("a") + self.assertEqual(a_gate.op, "I") + self.assertEqual(a_gate.args, []) + + b_gate = gate_graph.get_gate("b") + self.assertEqual(b_gate.op, "C") + self.assertEqual(b_gate.op_param, (1,)) + + self.assertEqual(str(b_gate), "b/2 = Const(1)") + + bit_slice_gate = gate_graph.get_gate("bit_slice") + self.assertEqual(bit_slice_gate.op, "s") + + self.assertEqual(bit_slice_gate.op_param, (2, 3)) + + self.assertEqual(bit_slice_gate.args, [a_gate]) + + self.assertEqual(bit_slice_gate.name, "bit_slice") + self.assertEqual(bit_slice_gate.bitwidth, 2) + + self.assertEqual(str(bit_slice_gate), "bit_slice/2 = slice(a/4) [sel=(2, 3)]") + + ab_gate = gate_graph.get_gate("ab") + self.assertEqual(ab_gate.op, "+") + + self.assertEqual(ab_gate.args, [bit_slice_gate, b_gate]) + + self.assertEqual(ab_gate.name, "ab") + self.assertEqual(ab_gate.bitwidth, 3) + self.assertFalse(ab_gate.is_output) + + output_gate = gate_graph.get_gate("output") + self.assertEqual(output_gate.op, "w") + self.assertTrue(output_gate.is_output) + self.assertEqual(str(output_gate), "output/3 [Output] = bb/3") + + self.assertEqual(len(output_gate.args), 1) + output_add_gate = output_gate.args[0] + + self.assertEqual(output_add_gate.args, [b_gate, b_gate]) + + self.assertEqual(len(b_gate.dests), 3) + num_ab_gates = 0 + num_output_add_gates = 0 + for dest_gate in b_gate.dests: + if dest_gate is ab_gate: + num_ab_gates += 1 + elif dest_gate is output_add_gate: + num_output_add_gates += 1 + self.assertEqual(num_ab_gates, 1) + self.assertEqual(num_output_add_gates, 2) + + def test_register_gate_forward(self): + counter = pyrtl.Register(name="counter", bitwidth=3) + one = pyrtl.Const(name="one", bitwidth=3, val=1) + truncated = (counter + one).truncate(3) + truncated.name = "truncated" + counter.next <<= truncated + + gate_graph = pyrtl.GateGraph() + + # Traverse the ``GateGraph`` forward, following ``dests`` references, from + # ``counter``. We should end up back at ``counter``. + counter_gate = gate_graph.get_gate("counter") + self.assertEqual(len(counter_gate.dests), 1) + self.assertEqual( + str(counter_gate), "counter/3 = reg(truncated/3) [reset_value=0]" + ) + + plus_gate = counter_gate.dests[0] + self.assertEqual(plus_gate.op, "+") + self.assertEqual(len(plus_gate.dests), 1) + + slice_gate = plus_gate.dests[0] + self.assertEqual(slice_gate.op, "s") + self.assertEqual(len(slice_gate.dests), 1) + + self.assertEqual(slice_gate.dests[0], counter_gate) + + def test_register_gate_backward(self): + counter = pyrtl.Register(name="counter", bitwidth=3, reset_value=2) + one = pyrtl.Const(name="one", bitwidth=3, val=1) + counter.next <<= counter + one + + gate_graph = pyrtl.GateGraph() + + # Traverse the ``GateGraph`` backward, following ``args`` references, from + # ``counter``. We should end up back at ``counter``. + counter_gate = gate_graph.get_gate("counter") + self.assertEqual(len(counter_gate.args), 1) + self.assertEqual(counter_gate.op_param, (2,)) + + # Implicit truncation from 4-bit sum to 3-bit register input. + slice_gate = counter_gate.args[0] + self.assertEqual(slice_gate.op, "s") + self.assertEqual(len(slice_gate.args), 1) + + plus_gate = slice_gate.args[0] + self.assertEqual(plus_gate.op, "+") + self.assertEqual(len(plus_gate.args), 2) + + self.assertEqual(plus_gate.args[0], counter_gate) + + def test_register_self_loop(self): + """Test a register that sets its next value directly from itself. + + This is an unusual case that creates a self-loop in the ``GateGraph``. + """ + r = pyrtl.Register(name="r", bitwidth=1) + r.next <<= r + + gate_graph = pyrtl.GateGraph() + r_gate = gate_graph.get_gate("r") + self.assertEqual(r_gate.args[0], r_gate) + self.assertEqual(r_gate.dests[0], r_gate) + self.assertEqual(str(r_gate), "r/1 = reg(r/1) [reset_value=0]") + + def test_memblock(self): + mem = pyrtl.MemBlock(name="mem", bitwidth=8, addrwidth=2) + + write_addr = pyrtl.Input(name="write_addr", bitwidth=2) + write_data = pyrtl.Input(name="write_data", bitwidth=8) + write_enable = pyrtl.Input(name="write_enable", bitwidth=1) + mem[write_addr] <<= pyrtl.MemBlock.EnabledWrite( + data=write_data, enable=write_enable + ) + + read_addr = pyrtl.Input(name="read_addr", bitwidth=2) + read_data = mem[read_addr] + read_data.name = "read_data" + + gate_graph = pyrtl.GateGraph() + + read_addr_gate = gate_graph.get_gate("read_addr") + read_gate = gate_graph.get_gate("read_data") + self.assertEqual(read_gate.op, "m") + self.assertEqual(read_gate.args, [read_addr_gate]) + self.assertEqual(read_gate.op_param, (mem.id, mem)) + self.assertEqual(read_gate.bitwidth, 8) + self.assertEqual( + str(read_gate), + f"read_data/8 = read(addr=read_addr/2) [memid={mem.id} mem=mem]", + ) + + write_addr_gate = gate_graph.get_gate("write_addr") + write_data_gate = gate_graph.get_gate("write_data") + write_enable_gate = gate_graph.get_gate("write_enable") + write_gate = write_data_gate.dests[0] + self.assertEqual(write_gate.op, "@") + self.assertEqual(read_gate.op_param, (mem.id, mem)) + self.assertEqual( + write_gate.args, [write_addr_gate, write_data_gate, write_enable_gate] + ) + self.assertEqual(write_gate.name, None) + self.assertEqual(write_gate.bitwidth, None) + self.assertEqual(write_gate.dests, []) + self.assertEqual( + str(write_gate), + "write(addr=write_addr/2, data=write_data/8, enable=write_enable/1) " + f"[memid={mem.id} mem=mem]", + ) + + def test_gate_sets(self): + a = pyrtl.Input(name="a", bitwidth=1) + b = pyrtl.Input(name="b", bitwidth=1) + + c = pyrtl.Const(name="c", bitwidth=1, val=0) + d = pyrtl.Const(name="d", bitwidth=1, val=1) + + x = pyrtl.Output(name="x", bitwidth=1) + y = pyrtl.Output(name="y", bitwidth=1) + + r = pyrtl.Register(name="r", bitwidth=1) + s = pyrtl.Register(name="s", bitwidth=1) + + mem = pyrtl.MemBlock(name="mem", bitwidth=1, addrwidth=1) + + x <<= a + c + + r.next <<= r + c + s.next <<= r + d + + mem[a] <<= pyrtl.MemBlock.EnabledWrite(data=c, enable=d) + read = mem[b] + read.name = "read" + y <<= read + d + + gate_graph = pyrtl.GateGraph() + + self.assertEqual(sorted(gate.name for gate in gate_graph.inputs), ["a", "b"]) + self.assertEqual(sorted(gate.name for gate in gate_graph.consts), ["c", "d"]) + self.assertEqual(sorted(gate.name for gate in gate_graph.outputs), ["x", "y"]) + self.assertEqual(sorted(gate.name for gate in gate_graph.registers), ["r", "s"]) + self.assertEqual(sorted(gate.name for gate in gate_graph.mem_reads), ["read"]) + mem_writes = gate_graph.mem_writes + # Check the MemBlock write. + self.assertEqual(len(mem_writes), 1) + write_gate = next(iter(mem_writes)) + self.assertTrue(write_gate is not None) + self.assertEqual(write_gate.op, "@") + # MemBlock write has no name. + self.assertEqual(write_gate.name, None) + + self.assertEqual( + sorted(gate.name for gate in gate_graph.sources), + ["a", "b", "c", "d", "r", "s"], + ) + sinks = set(gate_graph.sinks) + self.assertEqual( + sorted(str(gate.name) for gate in sinks), ["None", "r", "s", "x", "y"] + ) + + +if __name__ == "__main__": + unittest.main()