Skip to content

Commit e7c2a83

Browse files
committed
Add retroactive feature
1 parent ac2e19c commit e7c2a83

File tree

11 files changed

+211
-85
lines changed

11 files changed

+211
-85
lines changed

src/rembus/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
RembusError,
1414
CBOR,
1515
JSON,
16+
LastReceived,
17+
Now,
1618
QOS0,
1719
QOS1,
1820
QOS2,

src/rembus/admin.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import rembus.protocol as rp
3+
#import rembus.db as rdb
34

45
logger = logging.getLogger(__name__)
56

@@ -29,12 +30,14 @@ def remove_exposer(router, twin, topic):
2930
)
3031

3132

32-
def add_subscriber(router, twin, topic):
33+
def add_subscriber(router, twin, topic, msgfrom):
3334
logger.debug(
3435
"[%s] adding [%s] subscriber for topic [%s]", router, twin, topic
3536
)
3637
if topic not in router.subscribers:
3738
router.subscribers[topic] = []
39+
40+
twin.msg_from[topic] = msgfrom
3841
upsert_twin(router.subscribers[topic], twin)
3942

4043
def remove_subscriber(router, twin, topic):
@@ -45,6 +48,12 @@ def remove_subscriber(router, twin, topic):
4548
"[%s] removed [%s] subscriber for topic [%s]", router, twin, topic
4649
)
4750

51+
async def reactive(router, twin, status: bool):
52+
logger.debug("[%s] reactive: %s", twin, status)
53+
twin.isreactive = status
54+
if status:
55+
await router.inbox.put(rp.SendDataAtRest(twin))
56+
4857
async def admin_command(msg: rp.AdminMsg):
4958
"""Handle admin commands"""
5059
twin = msg.twin
@@ -61,8 +70,9 @@ async def admin_command(msg: rp.AdminMsg):
6170
elif cmd == rp.REMOVE_IMPL:
6271
remove_exposer(router, twin, topic)
6372
elif cmd == rp.ADD_INTEREST:
64-
add_subscriber(router, twin, topic)
73+
add_subscriber(router, twin, topic, msg.data[rp.MSG_FROM])
6574
elif cmd == rp.REMOVE_INTEREST:
6675
remove_subscriber(router, twin, topic)
67-
76+
elif cmd == rp.REACTIVE_CMD:
77+
await reactive(router, twin, msg.data[rp.STATUS])
6878
await twin.response(rp.STS_OK, msg)

src/rembus/core.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def __init__(
340340
self.config = rs.Config(name)
341341
self.policy = Policy(policy)
342342
self.owners = rs.load_tenants(self)
343-
self.start_ts = time.time()
343+
self.start_ts = rp.timestamp()
344344
self.msg_cache: list[rp.PubSubMsg] = []
345345
self.msg_topic_cache: dict[str, List[rp.PubSubMsg]] = {}
346346
self.tables: dict[str, rdb.Table] = {}
@@ -361,9 +361,10 @@ def __repr__(self):
361361

362362
def isconnected(self, rid: str) -> bool:
363363
"""Check if a component with the given rid is connected."""
364-
for tk in self.id_twin:
364+
for tk, twin in self.id_twin.items():
365365
if tk.startswith(rid + "@"):
366-
return True
366+
if twin.isopen():
367+
return True
367368
return False
368369

369370
def uptime(self) -> str:
@@ -419,14 +420,16 @@ async def _broadcast(self, msg):
419420
if t != twin:
420421
# Do not send back to publisher.
421422
await t.send(msg)
423+
t.mark = msg.recvts
422424

423425
except Exception as e: # pylint: disable=broad-exception-caught
424426
logger.warning("[%s] error in method invocation: %s", self, e)
425427
traceback.print_exc()
426428

427429
async def _pubsub_msg(self, msg: rp.PubSubMsg):
428-
msg.recvts = int(time.time())
429430
twin = msg.twin
431+
ts = rp.timestamp()
432+
msg.recvts = ts
430433
qos = msg.flags & rp.QOS2
431434
if qos > rp.QOS0 and msg.id:
432435
if twin.socket:
@@ -438,7 +441,8 @@ async def _pubsub_msg(self, msg: rp.PubSubMsg):
438441
return
439442
else:
440443
# Save the message id to guarantee exactly one delivery.
441-
twin.ackdf[msg.id] = int(time.time())
444+
##twin.ackdf[msg.id] = int(time.time())
445+
twin.ackdf[msg.id] = rp.timestamp()
442446

443447
if self.db is not None:
444448
# save the message into msg_cache
@@ -526,7 +530,7 @@ async def _handle_identity(self, msg: rp.IdentityMsg) -> None:
526530
await self._auth_identity(msg)
527531

528532
async def _handle_attestation(self, msg: rp.AttestationMsg) -> None:
529-
sts = self._verify_signature(msg)
533+
sts = await self._verify_signature(msg)
530534
await msg.twin.response(sts, msg)
531535

532536
async def _handle_admin(self, msg: rp.AdminMsg) -> None:
@@ -554,6 +558,8 @@ async def _task_impl(self) -> None:
554558
case "save_messages":
555559
# save messages to the database periodically
556560
rdb.save_data_at_rest(self)
561+
case rp.SendDataAtRest():
562+
await rdb.send_data_at_rest(msg)
557563
case rp.PubSubMsg():
558564
await self._pubsub_msg(msg)
559565
case rp.RpcReqMsg():
@@ -587,11 +593,8 @@ async def evaluate(self, twin, topic: str, data: Any) -> Any:
587593
async def _client_receiver(self, ws):
588594
"""Receive messages from the client component."""
589595
url = RbURL()
590-
if url.twkey in self.id_twin:
591-
twin = self.id_twin[url.twkey]
592-
else:
593-
twin = Twin(url, bottom_router(self), False)
594-
self.id_twin[url.twkey] = twin
596+
twin = Twin(url, bottom_router(self), False)
597+
self.id_twin[url.twkey] = twin
595598

596599
twin.socket = ws
597600
await twin.twin_receiver()
@@ -627,13 +630,15 @@ def _needs_auth(self, cid: str):
627630
except FileNotFoundError:
628631
return False
629632

630-
def _update_twin(self, twin, identity):
633+
async def _update_twin(self, twin, identity):
631634
logger.debug("[%s] setting name: [%s]", twin, identity)
632635
self.id_twin.pop(twin.twkey, twin)
633636
twin.rid = identity
634637
self.id_twin[twin.twkey] = twin
638+
if twin.db is not None:
639+
load_mark(twin)
635640

636-
def _verify_signature(self, msg: rp.AttestationMsg):
641+
async def _verify_signature(self, msg: rp.AttestationMsg):
637642
"""Verify the signature of the attestation message."""
638643
twin = msg.twin
639644
cid = msg.cid
@@ -654,7 +659,7 @@ def _verify_signature(self, msg: rp.AttestationMsg):
654659
elif isinstance(pubkey, ec.EllipticCurvePublicKey):
655660
pubkey.verify(signature, plain, ec.ECDSA(hashes.SHA256()))
656661

657-
self._update_twin(twin, msg.cid)
662+
await self._update_twin(twin, msg.cid)
658663
return rp.STS_OK
659664
except Exception as e: # pylint: disable=broad-exception-caught
660665
logger.error("verification failed: %s (%s)", e, type(e))
@@ -680,7 +685,7 @@ async def _auth_identity(self, msg: rp.IdentityMsg):
680685
# component is provisioned, send the challenge
681686
response = self._challenge(msg)
682687
else:
683-
self._update_twin(twin, identity)
688+
await self._update_twin(twin, identity)
684689
response = rp.ResMsg(id=msg.id, status=rp.STS_OK)
685690

686691
await twin.send(response)
@@ -764,6 +769,9 @@ def __init__(
764769
self.reconnect_task: Optional[asyncio.Task[None]] = None
765770
self.ackdf: dict[int, int] = {} # msgid => ts
766771
self.handler["phase"] = lambda: "CLOSED"
772+
self.isreactive: bool = False
773+
self.msg_from: dict[str, float] = {}
774+
self.mark: int = 0
767775
self.start()
768776

769777
def __str__(self):
@@ -790,6 +798,7 @@ def rid(self):
790798

791799
@rid.setter
792800
def rid(self, rid: str):
801+
self.uid.hasname = True
793802
self.uid.id = rid
794803

795804
@property
@@ -854,7 +863,10 @@ async def _reconnect(self):
854863
async def _shutdown(self):
855864
"""Cleanup logic when shutting down the twin."""
856865
logger.debug("[%s] twin shutdown", self)
857-
866+
867+
if self.db is not None:
868+
save_mark(self)
869+
858870
if self.isclient or self.uid.isrepl():
859871
await self._shutdown_router()
860872

@@ -1113,7 +1125,7 @@ async def publish(self, topic: str, *data: Any, **kwargs):
11131125
"""Publish a message to the specified topic."""
11141126

11151127
slot = kwargs.get("slot", None)
1116-
qos = kwargs.get("qos", rp.QOS0)
1128+
qos = kwargs.get("qos", rp.QOS0) & rp.QOS2
11171129
if qos == rp.QOS0:
11181130
msg = rp.PubSubMsg(topic=topic, data=data, slot=slot)
11191131
msg.twin = self
@@ -1261,7 +1273,7 @@ async def unreactive(self):
12611273
async def subscribe(
12621274
self,
12631275
fn: Callable[..., Any],
1264-
retroactive: bool = False,
1276+
msgfrom: float = rp.Now,
12651277
topic: Optional[str] = None,
12661278
):
12671279
"""
@@ -1270,17 +1282,14 @@ async def subscribe(
12701282
if topic is None:
12711283
topic = fn.__name__
12721284

1273-
await self.setting(topic, rp.ADD_INTEREST, {"retroactive": retroactive})
1285+
await self.setting(topic, rp.ADD_INTEREST, {"msg_from": msgfrom})
12741286
self.router.handler[topic] = fn
12751287
return self
12761288

12771289
async def unsubscribe(self, fn: Callable[..., Any] | str):
12781290
"""
12791291
Unsubscribe the function from the corresponding topic.
12801292
"""
1281-
# if topic is None:
1282-
# topic = fn.__name__
1283-
12841293
if isinstance(fn, str):
12851294
topic = fn
12861295
else:
@@ -1306,9 +1315,6 @@ async def unexpose(
13061315
"""
13071316
Unexpose the function as a remote procedure call(RPC) handler.
13081317
"""
1309-
# if topic is None:
1310-
# topic = fn.__name__
1311-
13121318
if isinstance(fn, str):
13131319
topic = fn
13141320
else:
@@ -1404,3 +1410,32 @@ async def component(
14041410
await router.init_twin(RbURL(netlink), enc, isserver)
14051411

14061412
return handle
1413+
1414+
def load_mark(twin):
1415+
result =twin.db.sql("""
1416+
SELECT mark
1417+
FROM mark
1418+
WHERE twin = ? AND name = ?
1419+
""", params=[twin.rid, twin.router.id]).fetchone()
1420+
1421+
if result:
1422+
twin.mark = result[0]
1423+
logger.debug("[%s] loaded mark %s", twin, twin.mark)
1424+
else:
1425+
logger.debug("[%s] no mark found in database", twin)
1426+
1427+
def save_mark(twin):
1428+
db = twin.db
1429+
mark = twin.mark
1430+
tid = twin.rid
1431+
name = twin.router.id
1432+
db.sql("""
1433+
MERGE INTO mark AS m
1434+
USING (SELECT ? AS name, ? AS twin, ? AS mark) AS t
1435+
ON m.name = t.name AND m.twin = t.twin
1436+
WHEN MATCHED THEN
1437+
UPDATE SET mark = t.mark
1438+
WHEN NOT MATCHED THEN
1439+
INSERT (name, twin, mark) VALUES (t.name, t.twin, t.mark)
1440+
""", params=[name, tid, mark])
1441+
logger.debug("[%s] saved mark %s", twin, mark)

src/rembus/db.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import pyarrow as pa
1515
from pydantic import BaseModel, Field, model_validator
1616
from rembus.settings import broker_dir, rembus_dir
17-
from rembus.protocol import tag2df, df2tag, PubSubMsg
17+
from rembus.protocol import tag2df, df2tag, timestamp, PubSubMsg
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -116,7 +116,7 @@ def reset_db(broker_name):
116116

117117
elif dl_url.startswith("sqlite"):
118118
logger.debug("removing db %s", db)
119-
os.remove(db)
119+
Path(db).unlink(missing_ok=True)
120120
else:
121121
broker_ducklake = Path(rembus_dir()) / f"{broker_name}.ducklake"
122122
broker_ducklake.unlink(True)
@@ -151,7 +151,7 @@ def init_db(router, schema):
151151
qos UTINYINT,
152152
uid UBIGINT,
153153
topic TEXT NOT NULL,
154-
data TEXT
154+
data BLOB
155155
)""",
156156
"""
157157
CREATE TABLE IF NOT EXISTS exposer (
@@ -608,9 +608,7 @@ def msg_table(router, msg: PubSubMsg):
608608

609609
def save_data_at_rest(router):
610610
"""Save cached messages to the database."""
611-
logger.debug(
612-
"[save_data_at_rest] saving %d messages", len(router.msg_cache)
613-
)
611+
#logger.debug("[%s] saving %d messages", router, len(router.msg_cache))
614612
msgs = router.msg_cache
615613
if not msgs:
616614
return
@@ -628,6 +626,27 @@ def save_data_at_rest(router):
628626
router.msg_cache.clear()
629627
router.msg_topic_cache.clear()
630628

629+
async def send_messages(twin, df, ts):
630+
r = twin.router
631+
for (name, recv, slot, qos, uid, topic, data) in df.iter_rows():
632+
if recv>=ts and topic in r.subscribers and twin in r.subscribers[topic]:
633+
payload = cbor2.loads(data)
634+
if payload:
635+
await twin.publish(topic, *payload, slot=slot, qos=qos)
636+
else:
637+
await twin.publish(topic, slot=slot, qos=qos)
638+
639+
async def send_data_at_rest(msg, max_period=3600000000000):
640+
twin = msg.twin
641+
r = twin.router
642+
db = twin.db
643+
644+
ts = timestamp() - max_period
645+
if twin.uid.hasname:
646+
logger.debug("[%s] sending data at rest", twin)
647+
df = db.execute(
648+
f"SELECT * FROM message WHERE name='{r.id}' AND recv>={ts}").pl()
649+
await send_messages(twin, df, ts)
631650

632651
def build_message_batch(broker_id: str, msgs: list):
633652
"""Build a PyArrow Table from a list of message tuples."""

src/rembus/protocol.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from typing import Any, List
1010
import json
11+
import time
1112
import cbor2
1213
from cryptography.hazmat.primitives import serialization
1314
from cryptography.hazmat.primitives.asymmetric import rsa, ec
@@ -30,6 +31,9 @@
3031
CBOR = 0
3132
JSON = 1
3233

34+
Now = 0.0
35+
LastReceived = float("inf")
36+
3337
UInt8 = conint(ge=0, le=255)
3438

3539
QOS0 = UInt8(0x00)
@@ -83,11 +87,16 @@
8387

8488
BROKER_CONFIG = "__config__"
8589
COMMAND = "cmd"
90+
MSG_FROM = "msg_from"
8691
ADD_INTEREST = "subscribe"
8792
REMOVE_INTEREST = "unsubscribe"
8893
ADD_IMPL = "expose"
8994
REMOVE_IMPL = "unexpose"
95+
REACTIVE_CMD = "reactive"
96+
STATUS = "status"
9097

98+
def timestamp():
99+
return time.time_ns()
91100

92101
def msgid():
93102
"""Return an array of MSGID_SZ random bytes."""
@@ -116,6 +125,9 @@ def bytes2id(byte_data: bytearray) -> int:
116125
# Convert bytes to integer (assuming little-endian, adjust if big-endian)
117126
return from_bytes(byte_data)
118127

128+
class SendDataAtRest:
129+
def __init__(self, twin):
130+
self.twin = twin
119131

120132
class RembusException(Exception):
121133
"""Base class for all Rembus exceptions."""

0 commit comments

Comments
 (0)