diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 294346a..2e55a75 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -21,7 +21,10 @@ def __init__( ): super().__init__(node_name, session, clock, channel, data_collector) self._handler = handler - self._queryable = self._session.declare_queryable(self._channel, self._on_query) + self._queryable = self._session.declare_queryable(self._channel, + self._on_query, + complete=False) + print(f"Declared queryable on channel: {self._channel}") def core_registration(self): print("..todo: register with ark core..") @@ -29,22 +32,26 @@ def core_registration(self): def _on_query(self, query: zenoh.Query) -> None: # If we were closed, ignore queries if not self._active: + print("Received query on closed Queryable, ignoring") return try: # Zenoh query may or may not include a payload. # For your use-case, the request is always in query.value (bytes) - raw = bytes(query.value) if query.value is not None else b"" + raw = bytes(query.payload) if query.payload is not None else b"" if not raw: + print("Received query with no payload, ignoring") return # nothing to do req_env = Envelope() req_env.ParseFromString(raw) # Decode request protobuf - req_type = msgs.get(req_env.payload_msg_type) + # req_type = msgs.get(req_env.payload_msg_type) + req_type = msgs.get(req_env.msg_type) if req_type is None: # Unknown message type: ignore (or reply error later) + print(f"Unknown message type '{req_env.msg_type}' in query, ignoring") return req_msg = req_type() @@ -60,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None: resp_env.sent_seq_index = self._seq_index resp_env.src_node_name = self._node_name resp_env.channel = self._channel + resp_env.msg_type = resp_msg.DESCRIPTOR.full_name + resp_env.payload = resp_msg.SerializeToString() self._seq_index += 1 - resp_env = Envelope.pack(self._node_name, self._clock, resp_msg) - query.reply(resp_env.SerializeToString()) + with query: + query.reply(query.key_expr, resp_env.SerializeToString()) if self._data_collector: self._data_collector.append(req_env.SerializeToString()) @@ -73,4 +82,9 @@ def _on_query(self, query: zenoh.Query) -> None: except Exception: # Keep it minimal: don't kill the zenoh callback thread # You can add logging here if desired + print("Error processing query:") + # write the traceback to stdout for debugging + import traceback + traceback.print_exc() + return diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py index c6d4586..88a6e8a 100644 --- a/src/ark/comm/querier.py +++ b/src/ark/comm/querier.py @@ -3,6 +3,7 @@ from google.protobuf.message import Message from ark.data.data_collector import DataCollector from ark.comm.end_point import EndPoint +from ark_msgs.registry import msgs class Querier(EndPoint): @@ -11,12 +12,16 @@ def __init__( self, node_name: str, session: zenoh.Session, + query_target, clock, channel: str, data_collector: DataCollector | None, ): super().__init__(node_name, session, clock, channel, data_collector) - self._querier = self._session.declare_querier(self._channel) + self._querier = self._session.declare_querier(self._channel, + target=query_target) + print(f"Declared querier on channel: {self._channel}") + self._query_selector = zenoh.Selector(self._channel) def core_registration(self): print("..todo: register with ark core..") @@ -48,18 +53,21 @@ def query( else: raise TypeError("req must be a protobuf Message or bytes") - replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout) + replies = self._querier.get(payload=req_env.SerializeToString()) for reply in replies: if reply.ok is None: continue resp_env = Envelope() - resp_env.ParseFromString(bytes(reply.ok)) + resp_env.ParseFromString(bytes(reply.ok.payload)) resp_env.dst_node_name = self._node_name resp_env.recv_timestamp = self._clock.now() - resp = resp_env.extract_message() + try: + resp = resp_env.extract_message() + except Exception as e: + continue self._seq_index += 1 @@ -69,11 +77,6 @@ def query( return resp - else: - raise TimeoutError( - f"No OK reply received for query on '{self._channel}' within {timeout}s" - ) - def close(self): super().close() self._querier.undeclare() diff --git a/src/ark/diff/__init__.py b/src/ark/diff/__init__.py new file mode 100644 index 0000000..6777145 --- /dev/null +++ b/src/ark/diff/__init__.py @@ -0,0 +1 @@ +from ark.diff.variable import Variable diff --git a/src/ark/diff/variable.py b/src/ark/diff/variable.py new file mode 100644 index 0000000..3330c97 --- /dev/null +++ b/src/ark/diff/variable.py @@ -0,0 +1,80 @@ +import torch +from ark_msgs import Value + + +class Variable: + + def __init__(self, name, value, mode, variables_registry, lock, clock, create_queryable_fn): + self.name = name + self.mode = mode + self._variables_registry = variables_registry + self._lock = lock + self._clock = clock + self._grads = {} # input vars: {output_name: grad_value} + + if mode == "input": + self._tensor = torch.tensor(value, requires_grad=True) + self._history = {} + self._replay_tensor = None + else: + self._tensor = None + self._computation_ts = clock.now() + self._replay_fn = None + for inp_name, inp_var in variables_registry.items(): + if inp_var.mode == "input": + grad_channel = f"grad/{inp_name}/{name}" + + def _make_handler(iv, ov_name, reg, lk): + def handler(_req): + out_var = reg.get(ov_name) + if _req.timestamp != 0 and out_var._replay_fn: + val, grad = out_var._replay_fn(_req.timestamp, iv.name, ov_name) + return Value(val=val, grad=grad, timestamp=_req.timestamp) + with lk: + val = float(out_var._tensor.detach()) if out_var and out_var._tensor is not None else 0.0 + grad = iv._grads.get(ov_name, 0.0) + ts = out_var._computation_ts if out_var else 0 + return Value(val=val, grad=grad, timestamp=ts) + return handler + + create_queryable_fn(grad_channel, _make_handler(inp_var, name, variables_registry, self._lock)) + + def snapshot(self, ts): + """Record current tensor value at clock timestamp ts.""" + self._history[ts] = float(self._tensor.detach()) + + def at(self, ts): + """Return a fresh requires_grad tensor from history at ts.""" + val = self._history[ts] + self._replay_tensor = torch.tensor(val, requires_grad=True) + return self._replay_tensor + + @property + def tensor(self): + return self._tensor + + @tensor.setter + def tensor(self, value): + if self.mode == "output": + self._tensor = value + self._compute_and_store_grads() + else: + self._tensor.data = value.data if isinstance(value, torch.Tensor) else torch.tensor(value) + + def _is_last_output(self): + output_names = [k for k, v in self._variables_registry.items() if v.mode == "output"] + return output_names and output_names[-1] == self.name + + def _compute_and_store_grads(self): + if self._tensor is None or not self._tensor.requires_grad: + return + with self._lock: + for var in self._variables_registry.values(): + if var.mode == "input" and var._tensor.grad is not None: + var._tensor.grad.zero_() + self._tensor.backward(retain_graph=not self._is_last_output()) + for var in self._variables_registry.values(): + if var.mode == "input": + grad = float(var._tensor.grad) if var._tensor.grad is not None else 0.0 + var._grads[self.name] = grad + self._computation_ts = self._clock.now() diff --git a/src/ark/node.py b/src/ark/node.py index 6ba6c6e..1a01706 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -1,5 +1,7 @@ import json import time +import threading +import torch import zenoh from ark.time.clock import Clock from ark.time.rate import Rate @@ -10,6 +12,8 @@ from ark.comm.queriable import Queryable from ark.data.data_collector import DataCollector from ark.core.registerable import Registerable +from ark.diff.variable import Variable +from ark_msgs import VariableInfo class BaseNode(Registerable): @@ -22,7 +26,8 @@ def __init__( sim: bool = False, collect_data: bool = False, ): - self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + # self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + self._z_cfg = z_cfg self._session = zenoh.open(self._z_cfg) self._env_name = env_name self._node_name = node_name @@ -36,6 +41,9 @@ def __init__( self._subs = {} self._queriers = {} self._queriables = {} + self._variables = {} + self._grad_lock = threading.Lock() + self._registry_pub = self.create_publisher("ark/vars/register") self._session.declare_subscriber(f"{env_name}/reset", self._on_reset) @@ -73,17 +81,19 @@ def create_subscriber(self, channel, callback) -> Subscriber: self._subs[channel] = sub return sub - def create_querier(self, channel, timeout=10.0) -> Querier: + def create_querier(self, channel, target, timeout=10.0) -> Querier: querier = Querier( self._node_name, self._session, + target, self._clock, channel, self._data_collector, - timeout, + # timeout, ) querier.core_registration() self._queriers[channel] = querier + # print session and channelinfo for debugging return querier def create_queryable(self, channel, handler) -> Queryable: @@ -99,6 +109,35 @@ def create_queryable(self, channel, handler) -> Queryable: self._queriables[channel] = queryable return queryable + def create_variable(self, name, value, mode="input"): + """Create a differentiable variable. + + For "output" mode, queryables are created on "grad/{input_name}/{name}" + for each existing input variable. Setting the tensor triggers an eager + backward pass that caches gradients into each input variable. + + Args: + name: Variable identifier, used in channel names. + value: Initial scalar value for the underlying tensor. + mode: "input" or "output". + """ + var = Variable(name, value, mode, self._variables, self._grad_lock, self._clock, self.create_queryable) + self._variables[name] = var + + if mode == "output": + grad_channels = [ + f"grad/{inp_name}/{name}" + for inp_name, v in self._variables.items() + if v.mode == "input" + ] + self._registry_pub.publish(VariableInfo( + output_name=name, + node_name=self._node_name, + grad_channels=grad_channels, + )) + + return var + def create_rate(self, hz: float): rate = Rate(self._clock, hz) self._rates.append(rate) diff --git a/src/ark/scripts/core.py b/src/ark/scripts/core.py index 8760b94..638dc76 100644 --- a/src/ark/scripts/core.py +++ b/src/ark/scripts/core.py @@ -1,6 +1,68 @@ -import sys +import argparse +import time +import zenoh +from ark.node import BaseNode +from ark_msgs import VariableInfo + + +class RegistryNode(BaseNode): + + def __init__(self, cfg): + super().__init__("ark", "registry", cfg) + self._var_registry: dict[str, VariableInfo] = {} + self.create_subscriber("ark/vars/register", self._on_register) + + def _on_register(self, msg: VariableInfo): + name = msg.output_name + self._var_registry[name] = msg + channel = f"ark/vars/{name}" + if channel not in self._queriables: + def _make_handler(n): + def handler(_req): + return self._var_registry[n] + return handler + self.create_queryable(channel, _make_handler(name)) + print(f"Registered output variable '{name}' from node '{msg.node_name}' " + f"with channels: {list(msg.grad_channels)}") + + def core_registration(self): + pass + + def close(self): + super().close() def main(): - print(">>Ark core<<") - print(sys.executable) + parser = argparse.ArgumentParser( + prog="ark-core", description="Ark central registry" + ) + parser.add_argument("--mode", "-m", dest="mode", + choices=["peer", "client"], type=str) + parser.add_argument("--connect", "-e", dest="connect", + metavar="ENDPOINT", action="append", type=str) + parser.add_argument("--listen", "-l", dest="listen", + metavar="ENDPOINT", action="append", type=str) + args = parser.parse_args() + + cfg = zenoh.Config() + if args.mode: + import json + cfg.insert_json5("mode", json.dumps(args.mode)) + if args.connect: + import json + cfg.insert_json5("connect/endpoints", json.dumps(args.connect)) + if args.listen: + import json + cfg.insert_json5("listen/endpoints", json.dumps(args.listen)) + + node = RegistryNode(cfg) + print("Ark registry running.") + try: + node.spin() + except KeyboardInterrupt: + print("Shutting down registry.") + node.close() + + +if __name__ == "__main__": + main() diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py new file mode 100644 index 0000000..e8b3777 --- /dev/null +++ b/test/ad_plotter_sub.py @@ -0,0 +1,170 @@ +import time +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Value, VariableInfo + +import argparse +import zenoh +import common_example as common + + +class AutodiffPlotterNode(BaseNode): + def __init__(self, cfg, target): + super().__init__("env", "autodiff_plotter", cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.pos_x_ts, self.pos_y_ts = [], [] + self._grad_queriers = {} # channel -> Querier + self._grad_data = {} # channel -> [float] + self._grad_ts = {} # channel -> [int] + self.create_subscriber("x", self.on_x) + self.create_subscriber("y", self.on_y) + self._discover_grad_channels(["x", "y"], target) + + def _discover_grad_channels(self, output_names, target, timeout=5.0): + for out in output_names: + disc = self.create_querier(f"ark/vars/{out}", target=target) + deadline = time.time() + timeout + while time.time() < deadline: + try: + resp = disc.query(VariableInfo()) + if isinstance(resp, VariableInfo): + for ch in resp.grad_channels: + self._grad_queriers[ch] = self.create_querier(ch, target=target) + self._grad_data[ch] = [] + self._grad_ts[ch] = [] + break + except Exception: + time.sleep(0.2) + + def on_x(self, msg: Value): + self.pos_x.append(msg.val) + self.pos_x_ts.append(msg.timestamp) + + def on_y(self, msg: Value): + self.pos_y.append(msg.val) + self.pos_y_ts.append(msg.timestamp) + + def fetch_grads(self): + req = Value() + for ch, querier in self._grad_queriers.items(): + try: + resp = querier.query(req) + if isinstance(resp, Value): + self._grad_data[ch].append(resp.grad) + self._grad_ts[ch].append(resp.timestamp) + except Exception: + pass + + def fetch_grads_at(self, ts): + req = Value(timestamp=ts) + results = {} + for ch, querier in self._grad_queriers.items(): + try: + resp = querier.query(req) + if isinstance(resp, Value): + results[ch] = (resp.val, resp.grad) + except Exception: + pass + return results + + +def main(): + parser = argparse.ArgumentParser(description="Autodiff Plotter Node") + common.add_config_arguments(parser) + parser.add_argument( + "--target", + "-t", + dest="target", + choices=["ALL", "BEST_MATCHING", "ALL_COMPLETE", "NONE"], + default="BEST_MATCHING", + type=str, + help="The target queryables of the query.", + ) + parser.add_argument( + "--timeout", + "-o", + dest="timeout", + default=10.0, + type=float, + help="The query timeout", + ) + + args = parser.parse_args() + conf = common.get_config_from_args(args) + + target = { + "ALL": zenoh.QueryTarget.ALL, + "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, + "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, + }.get(args.target) + + node = AutodiffPlotterNode(conf, target) + threading.Thread(target=node.spin, daemon=True).start() + + fig, (ax_pos, ax_grad, ax_replay) = plt.subplots(1, 3, figsize=(18, 5)) + ax_pos.set_title("Position") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_grad.set_title("Gradients (live)") + ax_grad.set_xlabel("sim time (s)") + ax_grad.set_ylabel("grad") + + colors = plt.cm.tab10.colors + grad_lines = {} + for i, ch in enumerate(node._grad_queriers): + (line,) = ax_grad.plot([], [], color=colors[i % 10], label=ch) + grad_lines[ch] = line + ax_grad.legend() + + ax_replay.set_title("Gradients (replay)") + ax_replay.set_xlabel("sim time (s)") + ax_replay.set_ylabel("grad") + replay_data = {ch: [] for ch in node._grad_queriers} + replay_ts = {ch: [] for ch in node._grad_queriers} + replay_lines = {} + for i, ch in enumerate(node._grad_queriers): + (line,) = ax_replay.plot([], [], color=colors[i % 10], label=ch) + replay_lines[ch] = line + ax_replay.legend() + + def update(frame): + node.fetch_grads() + + # Replay: query gradient at a historical timestamp + if len(node.pos_x_ts) > 10: + historical_ts = node.pos_x_ts[-10] + results = node.fetch_grads_at(historical_ts) + for ch, (val, grad) in results.items(): + replay_data[ch].append(grad) + replay_ts[ch].append(historical_ts) + + n = min(len(node.pos_x), len(node.pos_y)) + line_pos.set_data(node.pos_x[:n], node.pos_y[:n]) + ax_pos.relim() + ax_pos.autoscale_view() + for ch, line in grad_lines.items(): + data = node._grad_data[ch] + times = [t / 1e9 for t in node._grad_ts[ch]] + line.set_data(times[: len(data)], data) + ax_grad.relim() + ax_grad.autoscale_view() + for ch, line in replay_lines.items(): + data = replay_data[ch] + times = [t / 1e9 for t in replay_ts[ch]] + line.set_data(times[: len(data)], data) + ax_replay.relim() + ax_replay.autoscale_view() + return line_pos, *grad_lines.values(), *replay_lines.values() + + ani = animation.FuncAnimation(fig, update, interval=50, blit=False) + plt.tight_layout() + plt.show() + node.close() + + +if __name__ == "__main__": + main() diff --git a/test/common.py b/test/common.py index b26b213..0a5d7bc 100644 --- a/test/common.py +++ b/test/common.py @@ -1 +1,17 @@ -z_cfg = {"mode": "peer", "connect": {"endpoints": ["udp/127.0.0.1:7447"]}} +listen_cfg = { + "mode": "peer", + "listen": { + "endpoints": ["tcp/0.0.0.0:7447"]}, +} +connect_cfg = { + "mode": "peer", + "connect": { + "endpoints": ["tcp/127.0.0.1:7447"] + } +} +z_cfg = { + "mode": "peer", + # "connect": { + # "endpoints":["udp/127.0.0.1:7447"] + # } +} diff --git a/test/common_example.py b/test/common_example.py new file mode 100644 index 0000000..0c1eea3 --- /dev/null +++ b/test/common_example.py @@ -0,0 +1,83 @@ +import argparse +import json + +import zenoh + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--mode", + "-m", + dest="mode", + choices=["peer", "client"], + type=str, + help="The zenoh session mode.", + ) + parser.add_argument( + "--connect", + "-e", + dest="connect", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to connect to.", + ) + parser.add_argument( + "--listen", + "-l", + dest="listen", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to listen on.", + ) + parser.add_argument( + "--config", + "-c", + dest="config", + metavar="FILE", + type=str, + help="A configuration file.", + ) + parser.add_argument( + "--no-multicast-scouting", + dest="no_multicast_scouting", + default=False, + action="store_true", + help="Disable multicast scouting.", + ) + parser.add_argument( + "--cfg", + dest="cfg", + metavar="CFG", + default=[], + action="append", + type=str, + help="Allows arbitrary configuration changes as column-separated KEY:VALUE pairs. Where KEY must be a valid config path and VALUE must be a valid JSON5 string that can be deserialized to the expected type for the KEY field. Example: --cfg='transport/unicast/max_links:2'.", + ) + + +def get_config_from_args(args) -> zenoh.Config: + conf = ( + zenoh.Config.from_file(args.config) + if args.config is not None + else zenoh.Config() + ) + if args.mode is not None: + conf.insert_json5("mode", json.dumps(args.mode)) + if args.connect is not None: + conf.insert_json5("connect/endpoints", json.dumps(args.connect)) + if args.listen is not None: + conf.insert_json5("listen/endpoints", json.dumps(args.listen)) + if args.no_multicast_scouting: + conf.insert_json5("scouting/multicast/enabled", json.dumps(False)) + + for c in args.cfg: + try: + [key, value] = c.split(":", 1) + except: + print(f"`--cfg` argument: expected KEY:VALUE pair, got {c}") + raise + conf.insert_json5(key, value) + + return conf diff --git a/test/diff_variable_pub.py b/test/diff_variable_pub.py new file mode 100644 index 0000000..de18f63 --- /dev/null +++ b/test/diff_variable_pub.py @@ -0,0 +1,88 @@ +from ark.node import BaseNode +from ark_msgs import Value +import argparse +import common_example as common +import torch + +HZ = 50 +DT = 1.0 / HZ + + +class LineVariableNode(BaseNode): + + def __init__(self, cfg): + super().__init__("env", "line_var_pub", cfg, sim=True) + self.x_pub = self.create_publisher("x") + self.y_pub = self.create_publisher("y") + + # Output variables auto-create grad queryables: + # grad/v/x, grad/v/y, grad/m/x, grad/m/y, grad/c/x, grad/c/y + self.v = self.create_variable("v", 0.0, mode="input") + self.m = self.create_variable("m", 0.0, mode="input") + self.c = self.create_variable("c", 0.0, mode="input") + self.x = self.create_variable("x", 0.0, mode="output") + self.y = self.create_variable("y", 0.0, mode="output") + + self.create_subscriber("param/v", lambda msg: self.v.tensor.data.fill_(msg.val)) + self.create_subscriber("param/m", lambda msg: self.m.tensor.data.fill_(msg.val)) + self.create_subscriber("param/c", lambda msg: self.c.tensor.data.fill_(msg.val)) + + self.x._replay_fn = self._replay_grad + self.y._replay_fn = self._replay_grad + + self.create_stepper(HZ, self.step) + + def forward(self, ts, replay=False): + """Compute outputs from inputs at a given timestamp. + + Builds the computation graph parameterised by ts so that + gradients can later be evaluated at arbitrary times. + When replay=True, uses historical input values at ts. + """ + if replay: + v, m, c = self.v.at(ts), self.m.at(ts), self.c.at(ts) + else: + v, m, c = self.v.tensor, self.m.tensor, self.c.tensor + + t_val = torch.tensor(ts / 1e9, requires_grad=False) + x = v * t_val + y = m * x + c + return x, y + + def _replay_grad(self, ts, input_name, output_name): + x, y = self.forward(ts, replay=True) + outputs = {'x': x, 'y': y} + inp_var = self._variables[input_name] + (grad,) = torch.autograd.grad(outputs[output_name], inp_var._replay_tensor, retain_graph=True, allow_unused=True) + return float(outputs[output_name].detach()), float(grad) if grad is not None else 0.0 + + def step(self, ts): + x, y = self.forward(ts) + + # Setting output tensors triggers eager backward and caches gradients + self.x.tensor = x + self.y.tensor = y + + # Snapshot input values at this timestamp + self.v.snapshot(ts) + self.m.snapshot(ts) + self.c.snapshot(ts) + + self.x_pub.publish(Value(val=float(self.x.tensor.detach()), timestamp=ts)) + self.y_pub.publish(Value(val=float(self.y.tensor.detach()), timestamp=ts)) + + +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser( + prog="diff_variable_pub", description="Differentiable variable publisher" + ) + common.add_config_arguments(parser) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = LineVariableNode(conf) + node.spin() + except KeyboardInterrupt: + print("Shutting down diff variable publisher.") + node.close() diff --git a/test/gradient_exp.md b/test/gradient_exp.md new file mode 100644 index 0000000..c7f35e1 --- /dev/null +++ b/test/gradient_exp.md @@ -0,0 +1,79 @@ +# Gradient Experiment + +Demonstrates differentiable simulation using ark framework with distributed parameter publishing. A `ParamPublisherNode` publishes parameter values (`v`, `m`, `c`), a `LineVariableNode` subscribes to those parameters, computes position on a line (`y = m*x + c`, `x = v*t`) with autograd gradients, and an `AutodiffPlotterNode` subscribes to position and queries gradients in real time. + +## Architecture + +``` +ParamPublisherNode LineVariableNode AutodiffPlotterNode + publishes: subscribes to: subscribes to: + param/v (Value) ──► param/v, param/m, param/c position + param/m (Value) computes: queries: + param/c (Value) x = v*t, y = m*x + c grad/v/x, grad/m/y + publishes: plots: + position (Translation) trajectory + gradients + serves queryables: vs sim time + grad/{v,m,c}/{x,y} +``` + +## Key Concepts + +- **`create_variable(name, value, mode="input", fields=...)`** on `BaseNode`: + - Creates a `torch.tensor` with `requires_grad=True` + - Auto-subscribes on `param/{name}` to receive values from other nodes + - Auto-creates gradient queryables at `grad/{name}/{field}` for each field +- **`update_variable(name, grad_dict)`**: Caches gradients after `backward()`, served by queryables + +## Prerequisites + +- Install ark framework and dependencies (`zenoh`, `torch`, `matplotlib`, `ark_msgs`) +- Run all commands from the `test/` directory + +## Running the Experiment + +Open four separate terminals. All commands are run from the `test/` directory. + +### Shell 1 — Sim Clock + +Drives simulated time for all sim-enabled nodes. + +```bash +cd test +python simstep.py +``` + +### Shell 2 — Parameter Publisher + +Publishes fixed parameter values: `v=1.0`, `m=0.5`, `c=0.0`. + +```bash +cd test +python param_publisher.py +``` + +### Shell 3 — Diff Variable Publisher + +Subscribes to parameters, computes position and gradients, publishes position, serves gradient queryables. + +```bash +cd test +python diff_variable_pub.py +``` + +### Shell 4 — Autodiff Plotter + +Subscribes to position and queries gradients, then plots both against simulation time. + +```bash +cd test +python ad_plotter_sub.py +``` + +## What to Expect + +- **Shell 1** prints real elapsed time and sim time advancing each tick. +- **Shell 2** publishes parameter values at 10Hz (no output by default). +- **Shell 3** prints computed gradients (`dx/dv`, `dy/dm`) each step. +- **Shell 4** opens a matplotlib window with two plots: + - **Left**: Position trajectory (x vs y), autoscaling. + - **Right**: Gradients vs simulation time (dx/dv in green, dy/dm in magenta), autoscaling. diff --git a/test/param_publisher.py b/test/param_publisher.py new file mode 100644 index 0000000..ff439ea --- /dev/null +++ b/test/param_publisher.py @@ -0,0 +1,39 @@ +from ark.node import BaseNode +from ark_msgs import Value +import argparse +import common_example as common + +HZ = 10 + + +class ParamPublisherNode(BaseNode): + + def __init__(self, cfg): + super().__init__("env", "param_pub", cfg, sim=True) + self.pub_v = self.create_publisher("param/v") + self.pub_m = self.create_publisher("param/m") + self.pub_c = self.create_publisher("param/c") + self.rate = self.create_rate(HZ) + + def spin(self): + while True: + self.pub_v.publish(Value(val=1.0)) + self.pub_m.publish(Value(val=0.5)) + self.pub_c.publish(Value(val=0.0)) + self.rate.sleep() + + +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser( + prog="param_publisher", description="Publishes parameter values" + ) + common.add_config_arguments(parser) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = ParamPublisherNode(conf) + node.spin() + except KeyboardInterrupt: + print("Shutting down param publisher.") + node.close() diff --git a/test/plotter_subsriber.py b/test/plotter_subsriber.py new file mode 100644 index 0000000..2477962 --- /dev/null +++ b/test/plotter_subsriber.py @@ -0,0 +1,47 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, dTranslation +from common import z_cfg +class SubscriberPlotterNode(BaseNode): + def __init__(self): + super().__init__("env", "plotter", z_cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.vel_x, self.vel_y = [], [] + self.create_subscriber("position", self.on_position) + self.create_subscriber("velocity", self.on_velocity) + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + def on_velocity(self, msg: dTranslation): + self.vel_x.append(msg.x) + self.vel_y.append(msg.y) +def main(): + node = SubscriberPlotterNode() + threading.Thread(target=node.spin, daemon=True).start() + fig, (ax_pos, ax_vel) = plt.subplots(1, 2, figsize=(10, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-1.5, 1.5) + ax_pos.set_ylim(-1.5, 1.5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_vel.set_title("Velocity (dTranslation)") + ax_vel.set_xlabel("dx") + ax_vel.set_ylabel("dy") + ax_vel.set_xlim(-5, 5) + ax_vel.set_ylim(-5, 5) + ax_vel.set_aspect("equal") + (line_vel,) = ax_vel.plot([], [], "r-") + def update(frame): + line_pos.set_data(node.pos_x, node.pos_y) + line_vel.set_data(node.vel_x, node.vel_y) + return line_pos, line_vel + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() +if __name__ == "__main__": + main() diff --git a/test/simstep.py b/test/simstep.py new file mode 100644 index 0000000..48f1be2 --- /dev/null +++ b/test/simstep.py @@ -0,0 +1,22 @@ +from ark.time.simtime import SimTime +from common import z_cfg +import json +import zenoh +import time + + +def main(): + z_config = zenoh.Config.from_json5(json.dumps(z_cfg)) + with zenoh.open(z_config) as z: + sim_time = SimTime(z, "clock", 1000) + sim_time.reset() + start_time = time.time() + while True: + sim_time.tick() + elapsed = time.time() - start_time + sim_elapsed = sim_time._sim_time_ns / 1e9 + print(f"Real: {elapsed:.2f} s | Sim: {sim_elapsed:.3f} s") + + +if __name__ == "__main__": + main()