|
21 | 21 | import cbor2 |
22 | 22 | from cryptography.hazmat.primitives import hashes |
23 | 23 | from cryptography.hazmat.primitives.asymmetric import padding, rsa, ec |
| 24 | +import polars as pl |
24 | 25 | import websockets |
25 | 26 | import rembus.protocol as rp |
26 | 27 | import rembus.settings as rs |
@@ -377,8 +378,11 @@ def _builtins(self): |
377 | 378 | self.handler["uptime"] = lambda *_: self.uptime() |
378 | 379 |
|
379 | 380 | async def init_twin(self, uid: RbURL, enc: int, isserver: bool): |
380 | | - """Create and start a Twin""" |
| 381 | + """ |
| 382 | + Create and start a Twin for the component that connects to a server. |
| 383 | + """ |
381 | 384 | cmp = Twin(uid, bottom_router(self), not isserver, enc) |
| 385 | + cmp.start() |
382 | 386 | if not uid.isrepl(): |
383 | 387 | self.id_twin[uid.twkey] = cmp |
384 | 388 |
|
@@ -441,7 +445,7 @@ async def _pubsub_msg(self, msg: rp.PubSubMsg): |
441 | 445 | return |
442 | 446 | else: |
443 | 447 | # Save the message id to guarantee exactly one delivery. |
444 | | - ##twin.ackdf[msg.id] = int(time.time()) |
| 448 | + # twin.ackdf[msg.id] = int(time.time()) |
445 | 449 | twin.ackdf[msg.id] = rp.timestamp() |
446 | 450 |
|
447 | 451 | if self.db is not None: |
@@ -594,6 +598,7 @@ async def _client_receiver(self, ws): |
594 | 598 | """Receive messages from the client component.""" |
595 | 599 | url = RbURL() |
596 | 600 | twin = Twin(url, bottom_router(self), False) |
| 601 | + twin.start() |
597 | 602 | self.id_twin[url.twkey] = twin |
598 | 603 |
|
599 | 604 | twin.socket = ws |
@@ -636,7 +641,7 @@ async def _update_twin(self, twin, identity): |
636 | 641 | twin.rid = identity |
637 | 642 | self.id_twin[twin.twkey] = twin |
638 | 643 | if twin.db is not None: |
639 | | - load_mark(twin) |
| 644 | + load_twin(twin) |
640 | 645 |
|
641 | 646 | async def _verify_signature(self, msg: rp.AttestationMsg): |
642 | 647 | """Verify the signature of the attestation message.""" |
@@ -772,7 +777,7 @@ def __init__( |
772 | 777 | self.isreactive: bool = False |
773 | 778 | self.msg_from: dict[str, float] = {} |
774 | 779 | self.mark: int = 0 |
775 | | - self.start() |
| 780 | + # self.start() |
776 | 781 |
|
777 | 782 | def __str__(self): |
778 | 783 | return f"{self.uid.id}" |
@@ -863,10 +868,10 @@ async def _reconnect(self): |
863 | 868 | async def _shutdown(self): |
864 | 869 | """Cleanup logic when shutting down the twin.""" |
865 | 870 | logger.debug("[%s] twin shutdown", self) |
866 | | - |
| 871 | + |
867 | 872 | if self.db is not None: |
868 | | - save_mark(self) |
869 | | - |
| 873 | + save_twin(self) |
| 874 | + |
870 | 875 | if self.isclient or self.uid.isrepl(): |
871 | 876 | await self._shutdown_router() |
872 | 877 |
|
@@ -1411,31 +1416,164 @@ async def component( |
1411 | 1416 |
|
1412 | 1417 | return handle |
1413 | 1418 |
|
1414 | | -def load_mark(twin): |
1415 | | - result =twin.db.sql(""" |
1416 | | - SELECT mark |
1417 | | - FROM mark |
1418 | | - WHERE twin = ? AND name = ? |
1419 | | - """, params=[twin.rid, twin.router.id]).fetchone() |
| 1419 | + |
def sync_table(
    db, table_name: str, current_df: pl.DataFrame, new_df: pl.DataFrame
):
    """
    Synchronize a db table by removing rows not in `new_df`
    and adding rows not in `current_df`.

    Rows are compared on every column of `current_df`, so a row whose
    value changed is handled as "delete the old row, insert the new one".

    Args:
        db: connection exposing `sql(query)`; presumably a DuckDB
            connection, since the queries below reference a local
            DataFrame by its variable name -- confirm against the caller.
        table_name: table to update. It is interpolated into the SQL text,
            so it must be a trusted identifier, never user input.
        current_df: rows currently stored in the table (the slice being
            synchronized); its columns define the comparison key.
        new_df: rows that must be stored once the call returns; it must
            provide every column of `current_df`.
    """
    fields = current_df.columns

    # Correlate each table row `t` with the staged frame, aliased `df`.
    # Every column takes part in the match, so no column needs to be
    # special-cased.
    cond_str = " AND ".join(f"df.{field} = t.{field}" for field in fields)
    column_list = ", ".join(fields)

    # NOTE: the SQL below refers to the local variable `diff_df` by name,
    # so the variable must keep this exact name and stay in this scope.

    # Find rows in current_df that are not in new_df (rows to delete)
    diff_df = current_df.join(new_df, on=fields, how="anti")

    if not diff_df.is_empty():
        # The frame has to be aliased as `df`: the predicates in cond_str
        # are qualified with that alias.
        db.sql(f"""
            DELETE FROM {table_name} t WHERE EXISTS (
                SELECT 1 FROM diff_df AS df WHERE {cond_str}
            )
        """)

    # Find rows in new_df that are not in current_df (rows to insert)
    diff_df = new_df.join(current_df, on=fields, how="anti")

    if not diff_df.is_empty():
        # Name the columns explicitly so the insert does not depend on the
        # frame and the table sharing the same column order.
        db.sql(
            f"INSERT INTO {table_name} ({column_list}) "
            f"SELECT {column_list} FROM diff_df"
        )
| 1450 | + |
| 1451 | + |
def sync_twin(
    db, router_name: str, twin_name: str, table_name: str, new_df: pl.DataFrame
):
    """
    Bring the rows of `table_name` owned by one twin in line with `new_df`.

    The rows stored for the (`router_name`, `twin_name`) pair are read back
    and handed to `sync_table` together with `new_df`.
    """
    # `table_name` is interpolated, the two key values are bound parameters.
    query = f"""
        SELECT * FROM {table_name}
        WHERE name = ? AND twin = ?
        """
    stored_df = db.sql(query, params=[router_name, twin_name]).pl()

    sync_table(db, table_name, stored_df, new_df)
| 1465 | + |
| 1466 | + |
def sync_cfg(db, router_name: str, table_name: str, new_df: pl.DataFrame):
    """
    Bring the rows of `table_name` owned by one router in line with `new_df`.

    The rows stored under `router_name` are read back and handed to
    `sync_table` together with `new_df`.
    """
    # `table_name` is interpolated, the router name is a bound parameter.
    query = f"""
        SELECT * FROM {table_name}
        WHERE name = ?
        """
    stored_df = db.sql(query, params=[router_name]).pl()

    sync_table(db, table_name, stored_df, new_df)
| 1478 | + |
| 1479 | + |
def _clear_twin_rows(db, table_name: str, router_name: str, twin_name: str):
    """Delete every row of `table_name` stored for the given twin."""
    # `table_name` only ever comes from literals in this module.
    db.sql(
        f"DELETE FROM {table_name} WHERE name = ? AND twin = ?",
        params=[router_name, twin_name],
    )


def save_twin(twin):
    """
    Save twin data (subscribers, exposers, and marks) to the database.

    For each of the `subscriber`, `exposer` and `mark` tables the rows
    belonging to this twin are made to match the twin's in-memory state.
    When the twin has no subscriptions (or exposes no topic) its stored
    rows are deleted, so that a later `load_twin` does not bring back
    subscriptions or exposers the twin no longer has.

    Args:
        twin: the twin to persist; reads `twin.router`, `twin.rid`,
            `twin.msg_from` and `twin.mark`. The router must have a
            database attached (`router.db`).
    """
    router = twin.router
    tid = twin.rid
    name = router.id
    db = router.db

    # Save subscriber data
    if twin.msg_from:
        topics = list(twin.msg_from)
        subscriber_df = pl.DataFrame(
            {
                "name": [name] * len(topics),
                "twin": [tid] * len(topics),
                "topic": topics,
                "msg_from": list(twin.msg_from.values()),
            }
        )
        sync_twin(db, name, tid, "subscriber", subscriber_df)
    else:
        # An empty frame carries no usable column types for the diff, so
        # the stored rows are removed with a plain DELETE instead.
        _clear_twin_rows(db, "subscriber", name, tid)

    # Save exposer data
    exposed_topics = exposed_topics_for_twin(router, twin)
    if exposed_topics:
        exposer_df = pl.DataFrame(
            {
                "name": [name] * len(exposed_topics),
                "twin": [tid] * len(exposed_topics),
                "topic": exposed_topics,
            }
        )
        sync_twin(db, name, tid, "exposer", exposer_df)
    else:
        _clear_twin_rows(db, "exposer", name, tid)

    # Save mark data: always exactly one row per twin
    mark_df = pl.DataFrame(
        {"name": [name], "twin": [tid], "mark": [twin.mark]}
    )
    sync_twin(db, name, tid, "mark", mark_df)
| 1518 | + |
| 1519 | + |
def exposed_topics_for_twin(router, twin) -> List[str]:
    """
    Return the topics whose exposer list in `router.exposers` contains
    `twin`, in the iteration order of that mapping.
    """
    return [
        topic
        for topic, exposing_twins in router.exposers.items()
        if twin in exposing_twins
    ]
| 1529 | + |
| 1530 | + |
def load_twin(twin):
    """
    Restore a twin's persisted state from the router database.

    Reads the `subscriber`, `exposer` and `mark` rows stored for this twin
    (keyed by router id and `twin.rid`) and applies them:

    - `twin.msg_from` is replaced by the stored topic -> msg_from mapping
      and the twin is registered in `router.subscribers` for each topic;
    - the twin is registered in `router.exposers` for each stored topic;
    - `twin.mark` is set to the stored mark, when one exists.
    """
    router = twin.router
    db = router.db
    name = router.id
    tid = twin.rid

    def register(registry, topics):
        """Add the twin to each topic's list, creating missing lists."""
        for topic in topics:
            members = registry.setdefault(topic, [])
            if twin not in members:
                members.append(twin)

    # Subscriptions: topic -> msg_from value, plus router-side registration
    subscriber_df = db.sql(
        "SELECT topic, msg_from FROM subscriber WHERE name = ? AND twin = ?",
        params=[name, tid],
    ).pl()
    if not subscriber_df.is_empty():
        topics = subscriber_df["topic"].to_list()
        twin.msg_from = dict(zip(topics, subscriber_df["msg_from"].to_list()))
        register(router.subscribers, topics)

    # Exposed topics: router-side registration only
    exposer_df = db.sql(
        "SELECT topic FROM exposer WHERE name = ? AND twin = ?",
        params=[name, tid],
    ).pl()
    if not exposer_df.is_empty():
        register(router.exposers, exposer_df["topic"].to_list())

    # Mark: at most one row is read for the twin
    row = db.sql(
        "SELECT mark FROM mark WHERE name = ? AND twin = ?", params=[name, tid]
    ).fetchone()
    if not row:
        logger.debug("[%s] no mark found in database", twin)
        return

    twin.mark = row[0]
    logger.debug("[%s] loaded mark %s", twin, twin.mark)
1426 | | - |
1427 | | -def save_mark(twin): |
1428 | | - db = twin.db |
1429 | | - mark = twin.mark |
1430 | | - tid = twin.rid |
1431 | | - name = twin.router.id |
1432 | | - db.sql(""" |
1433 | | - MERGE INTO mark AS m |
1434 | | - USING (SELECT ? AS name, ? AS twin, ? AS mark) AS t |
1435 | | - ON m.name = t.name AND m.twin = t.twin |
1436 | | - WHEN MATCHED THEN |
1437 | | - UPDATE SET mark = t.mark |
1438 | | - WHEN NOT MATCHED THEN |
1439 | | - INSERT (name, twin, mark) VALUES (t.name, t.twin, t.mark) |
1440 | | - """, params=[name, tid, mark]) |
1441 | | - logger.debug("[%s] saved mark %s", twin, mark) |
|
0 commit comments