Skip to content

Commit 92ca83e

Browse files
committed
Save exposers and subscribers
1 parent ae1824a commit 92ca83e

File tree

5 files changed

+175
-47
lines changed

5 files changed

+175
-47
lines changed

.flake8

Lines changed: 0 additions & 6 deletions
This file was deleted.

pyrightconfig.json

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/rembus/core.py

Lines changed: 167 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import cbor2
2222
from cryptography.hazmat.primitives import hashes
2323
from cryptography.hazmat.primitives.asymmetric import padding, rsa, ec
24+
import polars as pl
2425
import websockets
2526
import rembus.protocol as rp
2627
import rembus.settings as rs
@@ -377,8 +378,11 @@ def _builtins(self):
377378
self.handler["uptime"] = lambda *_: self.uptime()
378379

379380
async def init_twin(self, uid: RbURL, enc: int, isserver: bool):
380-
"""Create and start a Twin"""
381+
"""
382+
Create and start a Twin for the component that connects to a server.
383+
"""
381384
cmp = Twin(uid, bottom_router(self), not isserver, enc)
385+
cmp.start()
382386
if not uid.isrepl():
383387
self.id_twin[uid.twkey] = cmp
384388

@@ -441,7 +445,7 @@ async def _pubsub_msg(self, msg: rp.PubSubMsg):
441445
return
442446
else:
443447
# Save the message id to guarantee exactly one delivery.
444-
##twin.ackdf[msg.id] = int(time.time())
448+
# twin.ackdf[msg.id] = int(time.time())
445449
twin.ackdf[msg.id] = rp.timestamp()
446450

447451
if self.db is not None:
@@ -594,6 +598,7 @@ async def _client_receiver(self, ws):
594598
"""Receive messages from the client component."""
595599
url = RbURL()
596600
twin = Twin(url, bottom_router(self), False)
601+
twin.start()
597602
self.id_twin[url.twkey] = twin
598603

599604
twin.socket = ws
@@ -636,7 +641,7 @@ async def _update_twin(self, twin, identity):
636641
twin.rid = identity
637642
self.id_twin[twin.twkey] = twin
638643
if twin.db is not None:
639-
load_mark(twin)
644+
load_twin(twin)
640645

641646
async def _verify_signature(self, msg: rp.AttestationMsg):
642647
"""Verify the signature of the attestation message."""
@@ -772,7 +777,7 @@ def __init__(
772777
self.isreactive: bool = False
773778
self.msg_from: dict[str, float] = {}
774779
self.mark: int = 0
775-
self.start()
780+
# self.start()
776781

777782
def __str__(self):
778783
return f"{self.uid.id}"
@@ -863,10 +868,10 @@ async def _reconnect(self):
863868
async def _shutdown(self):
864869
"""Cleanup logic when shutting down the twin."""
865870
logger.debug("[%s] twin shutdown", self)
866-
871+
867872
if self.db is not None:
868-
save_mark(self)
869-
873+
save_twin(self)
874+
870875
if self.isclient or self.uid.isrepl():
871876
await self._shutdown_router()
872877

@@ -1411,31 +1416,164 @@ async def component(
14111416

14121417
return handle
14131418

1414-
def load_mark(twin):
1415-
result =twin.db.sql("""
1416-
SELECT mark
1417-
FROM mark
1418-
WHERE twin = ? AND name = ?
1419-
""", params=[twin.rid, twin.router.id]).fetchone()
1419+
1420+
def sync_table(
1421+
db, table_name: str, current_df: pl.DataFrame, new_df: pl.DataFrame
1422+
):
1423+
"""
1424+
Synchronize a db table by removing rows not in `new_df`
1425+
and adding rows not in `current_df`.
1426+
"""
1427+
fields = current_df.columns
1428+
1429+
conds = ["df.name = t.name"]
1430+
for field in fields:
1431+
conds.append(f"df.{field} = t.{field}")
1432+
cond_str = " AND ".join(conds)
1433+
1434+
# Find rows in current_df that are not in new_df (rows to delete)
1435+
diff_df = current_df.join(new_df, on=fields, how="anti")
1436+
1437+
if not diff_df.is_empty():
1438+
# Delete rows that exist in current but not in new
1439+
db.sql(f"""
1440+
DELETE FROM {table_name} t WHERE EXISTS (
1441+
SELECT 1 FROM diff_df WHERE {cond_str}
1442+
)
1443+
""")
1444+
1445+
# Find rows in new_df that are not in current_df (rows to insert)
1446+
diff_df = new_df.join(current_df, on=fields, how="anti")
1447+
1448+
if not diff_df.is_empty():
1449+
db.sql(f"INSERT INTO {table_name} SELECT * FROM diff_df")
1450+
1451+
1452+
def sync_twin(
1453+
db, router_name: str, twin_name: str, table_name: str, new_df: pl.DataFrame
1454+
):
1455+
"""Synchronize a twin's data in a specific table."""
1456+
current_df = db.sql(
1457+
f"""
1458+
SELECT * FROM {table_name}
1459+
WHERE name = ? AND twin = ?
1460+
""",
1461+
params=[router_name, twin_name],
1462+
).pl()
1463+
1464+
sync_table(db, table_name, current_df, new_df)
1465+
1466+
1467+
def sync_cfg(db, router_name: str, table_name: str, new_df: pl.DataFrame):
1468+
"""Synchronize configuration data for a router."""
1469+
current_df = db.sql(
1470+
f"""
1471+
SELECT * FROM {table_name}
1472+
WHERE name = ?
1473+
""",
1474+
params=[router_name],
1475+
).pl()
1476+
1477+
sync_table(db, table_name, current_df, new_df)
1478+
1479+
1480+
def save_twin(twin):
1481+
"""
1482+
Save twin data (subscribers, exposers, and marks) to the database.
1483+
"""
1484+
router = twin.router
1485+
tid = twin.rid
1486+
name = router.id
1487+
db = router.db
1488+
1489+
# Save subscriber data
1490+
if twin.msg_from:
1491+
current_df = pl.DataFrame(
1492+
{
1493+
"name": [name] * len(twin.msg_from),
1494+
"twin": [tid] * len(twin.msg_from),
1495+
"topic": list(twin.msg_from.keys()),
1496+
"msg_from": list(twin.msg_from.values()),
1497+
}
1498+
)
1499+
sync_twin(db, name, tid, "subscriber", current_df)
1500+
1501+
# Save exposer data
1502+
exposed_topics = exposed_topics_for_twin(router, twin)
1503+
if exposed_topics:
1504+
current_df = pl.DataFrame(
1505+
{
1506+
"name": [name] * len(exposed_topics),
1507+
"twin": [tid] * len(exposed_topics),
1508+
"topic": exposed_topics,
1509+
}
1510+
)
1511+
sync_twin(db, name, tid, "exposer", current_df)
1512+
1513+
# Save mark data
1514+
current_df = pl.DataFrame(
1515+
{"name": [name], "twin": [tid], "mark": [twin.mark]}
1516+
)
1517+
sync_twin(db, name, tid, "mark", current_df)
1518+
1519+
1520+
def exposed_topics_for_twin(router, twin) -> List[str]:
1521+
"""
1522+
Get list of topics exposed by this twin.
1523+
"""
1524+
exposed = []
1525+
for topic, twins in router.exposers.items():
1526+
if twin in twins:
1527+
exposed.append(topic)
1528+
return exposed
1529+
1530+
1531+
def load_twin(twin):
1532+
"""
1533+
Load twin data (subscribers, exposers, and marks) from the database.
1534+
"""
1535+
router = twin.router
1536+
name = router.id
1537+
tid = twin.rid
1538+
db = router.db
1539+
1540+
df = db.sql(
1541+
"SELECT topic, msg_from FROM subscriber WHERE name = ? AND twin = ?",
1542+
params=[name, tid],
1543+
).pl()
1544+
1545+
if not df.is_empty():
1546+
twin.msg_from = dict(
1547+
zip(df["topic"].to_list(), df["msg_from"].to_list())
1548+
)
1549+
1550+
# Update router's subscribers
1551+
for topic in df["topic"].to_list():
1552+
if topic not in router.subscribers:
1553+
router.subscribers[topic] = []
1554+
if twin not in router.subscribers[topic]:
1555+
router.subscribers[topic].append(twin)
1556+
1557+
df = db.sql(
1558+
"SELECT topic FROM exposer WHERE name = ? AND twin = ?",
1559+
params=[name, tid],
1560+
).pl()
1561+
1562+
if not df.is_empty():
1563+
# Update router's exposers
1564+
for topic in df["topic"].to_list():
1565+
if topic not in router.exposers:
1566+
router.exposers[topic] = []
1567+
if twin not in router.exposers[topic]:
1568+
router.exposers[topic].append(twin)
1569+
1570+
# Load mark data
1571+
result = db.sql(
1572+
"SELECT mark FROM mark WHERE name = ? AND twin = ?", params=[name, tid]
1573+
).fetchone()
14201574

14211575
if result:
14221576
twin.mark = result[0]
14231577
logger.debug("[%s] loaded mark %s", twin, twin.mark)
14241578
else:
14251579
logger.debug("[%s] no mark found in database", twin)
1426-
1427-
def save_mark(twin):
1428-
db = twin.db
1429-
mark = twin.mark
1430-
tid = twin.rid
1431-
name = twin.router.id
1432-
db.sql("""
1433-
MERGE INTO mark AS m
1434-
USING (SELECT ? AS name, ? AS twin, ? AS mark) AS t
1435-
ON m.name = t.name AND m.twin = t.twin
1436-
WHEN MATCHED THEN
1437-
UPDATE SET mark = t.mark
1438-
WHEN NOT MATCHED THEN
1439-
INSERT (name, twin, mark) VALUES (t.name, t.twin, t.mark)
1440-
""", params=[name, tid, mark])
1441-
logger.debug("[%s] saved mark %s", twin, mark)

src/rembus/db.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def msg_table(router, msg: PubSubMsg):
608608

609609
def save_data_at_rest(router):
610610
"""Save cached messages to the database."""
611-
#logger.debug("[%s] saving %d messages", router, len(router.msg_cache))
611+
# logger.debug("[%s] saving %d messages", router, len(router.msg_cache))
612612
msgs = router.msg_cache
613613
if not msgs:
614614
return
@@ -626,20 +626,22 @@ def save_data_at_rest(router):
626626
router.msg_cache.clear()
627627
router.msg_topic_cache.clear()
628628

629+
629630
async def send_messages(twin, df, ts):
630631
r = twin.router
631632
for (name, recv, slot, qos, uid, topic, data) in df.iter_rows():
632-
if (recv>twin.mark and
633-
recv>ts-twin.msg_from.get(topic, 0) and
633+
if (recv > twin.mark and
634+
recv > ts-twin.msg_from.get(topic, 0) and
634635
topic in r.subscribers and
635-
twin in r.subscribers[topic]):
636+
twin in r.subscribers[topic]):
636637
payload = cbor2.loads(data)
637638
if payload:
638639
await twin.publish(topic, *payload, slot=slot, qos=qos)
639640
else:
640641
await twin.publish(topic, slot=slot, qos=qos)
641642
twin.mark = recv
642643

644+
643645
async def send_data_at_rest(msg, max_period=3600_000_000_000):
644646
twin = msg.twin
645647
r = twin.router
@@ -652,6 +654,7 @@ async def send_data_at_rest(msg, max_period=3600_000_000_000):
652654
f"SELECT * FROM message WHERE name='{r.id}' AND recv>={ts-max_period}").pl()
653655
await send_messages(twin, df, ts)
654656

657+
655658
def build_message_batch(broker_id: str, msgs: list):
656659
"""Build a PyArrow Table from a list of message tuples."""
657660

tests/unit/test_protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ async def test_rembus_messages():
4848
"""Test the string representation of rembus protocol messages."""
4949
router = rc.Router("broker")
5050
twin = rc.Twin(rc.RbURL("twin"), router)
51+
twin.start()
5152
for msg in [
5253
rp.AttestationMsg(id=1, cid="cid", signature=b"signature"),
5354
rp.IdentityMsg(id=2, cid="cid"),

0 commit comments

Comments
 (0)