Skip to content

Commit bf4ace3

Browse files
committed
Review
1 parent 07f9169 commit bf4ace3

File tree

4 files changed

+149
-130
lines changed

4 files changed

+149
-130
lines changed

src/rembus/admin.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,38 @@ def upsert_twin(lst, twin):
1313
else:
1414
lst.append(twin)
1515

16+
17+
def add_exposer(router, twin, topic):
18+
logger.debug("[%s] adding [%s] exposer for topic [%s]", router, twin, topic)
19+
if topic not in router.exposers:
20+
router.exposers[topic] = []
21+
upsert_twin(router.exposers[topic], twin)
22+
23+
def remove_exposer(router, twin, topic):
24+
if topic in router.exposers:
25+
if twin in router.exposers[topic]:
26+
router.exposers[topic].remove(twin)
27+
logger.debug(
28+
"[%s] removed [%s] exposer for topic [%s]", router, twin, topic
29+
)
30+
31+
32+
def add_subscriber(router, twin, topic):
33+
logger.debug(
34+
"[%s] adding [%s] subscriber for topic [%s]", router, twin, topic
35+
)
36+
if topic not in router.subscribers:
37+
router.subscribers[topic] = []
38+
upsert_twin(router.subscribers[topic], twin)
39+
40+
def remove_subscriber(router, twin, topic):
41+
if topic in router.subscribers:
42+
if twin in router.subscribers[topic]:
43+
router.subscribers[topic].remove(twin)
44+
logger.debug(
45+
"[%s] removed [%s] subscriber for topic [%s]", router, twin, topic
46+
)
47+
1648
async def admin_command(msg: rp.AdminMsg):
1749
"""Handle admin commands"""
1850
twin = msg.twin
@@ -25,32 +57,12 @@ async def admin_command(msg: rp.AdminMsg):
2557
router = twin.router
2658
cmd = msg.data[rp.COMMAND]
2759
if cmd == rp.ADD_IMPL:
28-
logger.debug(
29-
"[%s] adding [%s] exposer for topic [%s]", router, twin, topic
30-
)
31-
if topic not in router.exposers:
32-
router.exposers[topic] = []
33-
34-
upsert_twin(router.exposers[topic], twin)
35-
60+
add_exposer(router, twin, topic)
3661
elif cmd == rp.REMOVE_IMPL:
37-
if topic in router.exposers:
38-
if twin in router.exposers[topic]:
39-
router.exposers[topic].remove(twin)
40-
logger.info("[%s] removed [%s] exposer for topic [%s]", router, twin, topic)
41-
logger.info("exposers: %s", router.exposers)
62+
remove_exposer(router, twin, topic)
4263
elif cmd == rp.ADD_INTEREST:
43-
logger.debug(
44-
"[%s] adding [%s] subscriber for topic [%s]", router, twin, topic
45-
)
46-
if topic not in router.subscribers:
47-
router.subscribers[topic] = []
48-
49-
upsert_twin(router.subscribers[topic], twin)
50-
64+
add_subscriber(router, twin, topic)
5165
elif cmd == rp.REMOVE_INTEREST:
52-
if topic in router.subscribers:
53-
if twin in router.subscribers[topic]:
54-
router.subscribers[topic].remove(twin)
66+
remove_subscriber(router, twin, topic)
5567

5668
await twin.response(rp.STS_OK, msg)

src/rembus/core.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,23 @@ async def _shutdown(self):
407407
if self.db is not None:
408408
self.db.close()
409409

410+
async def _broadcast(self, msg):
411+
twin = msg.twin
412+
data = rp.tag2df(msg.data)
413+
try:
414+
if msg.topic in self.handler:
415+
await self.evaluate(msg.twin, msg.topic, data)
416+
417+
subs = self.subscribers.get(msg.topic, [])
418+
for t in subs:
419+
if t != twin:
420+
# Do not send back to publisher.
421+
await t.send(msg)
422+
423+
except Exception as e: # pylint: disable=broad-exception-caught
424+
logger.warning("[%s] error in method invocation: %s", self, e)
425+
traceback.print_exc()
426+
410427
async def _pubsub_msg(self, msg: rp.PubSubMsg):
411428
msg.recvts = int(time.time())
412429
twin = msg.twin
@@ -423,26 +440,11 @@ async def _pubsub_msg(self, msg: rp.PubSubMsg):
423440
# Save the message id to guarantee exactly one delivery.
424441
twin.ackdf[msg.id] = int(time.time())
425442

426-
data = rp.tag2df(msg.data)
427-
428443
if self.db is not None:
429444
# save the message into msg_cache
430445
self.append_message(msg)
431446

432-
try:
433-
if msg.topic in self.handler:
434-
await self.evaluate(msg.twin, msg.topic, data)
435-
436-
subs = self.subscribers.get(msg.topic, [])
437-
for t in subs:
438-
if t != twin:
439-
# Do not send back to publisher.
440-
await t.send(msg)
441-
442-
except Exception as e: # pylint: disable=broad-exception-caught
443-
logger.warning("[%s] error in method invocation: %s", self, e)
444-
traceback.print_exc()
445-
return
447+
await self._broadcast(msg)
446448

447449
def append_message(self, msg: rp.PubSubMsg):
448450
"""Append a message to the message cache."""

src/rembus/db.py

Lines changed: 71 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def append(con: duckdb.DuckDBPyConnection, tabledef, msgs):
342342
for msg in msgs:
343343
try:
344344
values = msg.data
345-
# key_value format
345+
# key_value
346346
if fmt == "key_value":
347347
obj = getobj(topic, values)
348348
set_default(msg, tabledef, obj, add_nullable=True)
@@ -356,7 +356,7 @@ def append(con: duckdb.DuckDBPyConnection, tabledef, msgs):
356356

357357
fields = [obj[f] for f in tblfields]
358358

359-
# dataframe format
359+
# dataframe
360360
elif fmt == "dataframe":
361361
df = tag2df(values[0])
362362
if list(df.columns) != tblfields:
@@ -424,90 +424,66 @@ def schema_to_polars(tabledef: Table):
424424
return schema
425425

426426

427-
def upsert(con: duckdb.DuckDBPyConnection, tabledef, messages):
428-
tname = tabledef.table
429-
fmt = tabledef.format
430-
col_names = columns(tabledef) + list(tabledef.extras.values())
431-
indexes = list(tabledef.keys)
432-
records = []
433-
dataframes = []
427+
def add_extras(table, msg, obj):
428+
if "recv_ts" in table.extras:
429+
obj[table.extras["recv_ts"]] = msg.recvts
430+
if "slot" in table.extras:
431+
obj[table.extras["slot"]] = msg.slot
434432

435-
# Build final DataFrame (tdf)
436-
for msg in messages:
437-
try:
438-
values = msg.data
439433

440-
# key_value
441-
if fmt == "key_value":
442-
obj = getobj(tname, values)
443-
set_default(msg, tabledef, obj, add_nullable=True)
434+
def handle_key_value(msg, table, col_names, records, tname):
435+
obj = getobj(tname, msg.data)
436+
set_default(msg, table, obj, add_nullable=True)
437+
add_extras(table, msg, obj)
444438

445-
if "recv_ts" in tabledef.extras:
446-
obj[tabledef.extras["recv_ts"]] = msg.recvts
447-
if "slot" in tabledef.extras:
448-
obj[tabledef.extras["slot"]] = msg.slot
439+
if all(k in obj for k in col_names):
440+
records.append(obj)
441+
else:
442+
logger.warning(
443+
"[%s] unsaved %s missing required fields %s", tname, obj, col_names
444+
)
449445

450-
# Ensure all fields exist
451-
if all(k in obj for k in col_names):
452-
records.append(obj)
453-
else:
454-
logger.warning(
455-
"[%s] unsaved %s missing required fields %s",
456-
tname,
457-
obj,
458-
col_names,
459-
)
460-
continue
461446

462-
# dataframe
463-
elif fmt == "dataframe":
464-
df = tag2df(values[0])
465-
df = df_extras(tabledef, df, msg)
466-
if list(df.columns) == col_names:
467-
dataframes.append(df)
468-
else:
469-
logger.warning("[%s] unsaved df (mismatched fields)", tname)
470-
continue
447+
def handle_dataframe(msg, table, col_names, dataframes, tname):
448+
df = tag2df(msg.data[0])
449+
df = df_extras(table, df, msg)
471450

472-
# default format
473-
else:
474-
vals = values
475-
if len(vals) == len(tabledef.columns):
476-
extra_vals = extras(tabledef, msg)
477-
all_vals = vals + extra_vals
478-
records.append(all_vals)
479-
else:
480-
logger.warning(
481-
"[%s] unsaved %s with mismatched fields", tname, vals
482-
)
483-
continue
484-
except Exception as e:
485-
logger.error("[upsert] %s: %s", tname, e)
451+
if list(df.columns) == col_names:
452+
dataframes.append(df)
453+
else:
454+
logger.warning("[%s] unsaved df (mismatched fields)", tname)
455+
456+
457+
def handle_default(msg, table, col_names, records, tname):
458+
vals = msg.data
459+
if len(vals) != len(table.columns):
460+
logger.warning("[%s] unsaved %s with mismatched fields", tname, vals)
461+
return
462+
463+
extra_vals = extras(table, msg)
464+
records.append(vals + extra_vals)
465+
466+
467+
def execute_upsert(con, table, col_names, indexes, records, dataframes):
468+
tname = table.table
486469

487-
tdf = pl.DataFrame(records, schema=schema_to_polars(tabledef), orient="row")
470+
tdf = pl.DataFrame(records, schema=schema_to_polars(table), orient="row")
488471
tdf = pl.concat([tdf, *dataframes], how="vertical")
472+
489473
if tdf.is_empty():
490474
return
491475

492476
if indexes:
493-
tdf = (
494-
tdf.sort(indexes) # Sort so last row is last
495-
.group_by(indexes)
496-
.agg([pl.all().last()]) # Take last row of each group
497-
)
498-
499-
# logger.debug("[%s] upserting dataframe:\n%s", tname, tdf)
477+
tdf = tdf.sort(indexes).group_by(indexes).agg([pl.all().last()])
500478

501479
con.register("df_view", tdf)
502480

503-
conds = [f"df_view.{k} = {tname}.{k}" for k in indexes]
504-
cond_str = " AND ".join(conds)
505-
481+
cond_str = " AND ".join(f"df_view.{k} = {tname}.{k}" for k in indexes)
506482
col_list = ", ".join(col_names)
507483
val_list = ", ".join(f"df_view.{c}" for c in col_names)
508484

509-
update_columns = [c for c in col_names if c not in indexes]
510-
update_list = ", ".join(f"{c} = df_view.{c}" for c in update_columns)
485+
update_cols = [c for c in col_names if c not in indexes]
486+
update_list = ", ".join(f"{c} = df_view.{c}" for c in update_cols)
511487

512488
sql = f"""
513489
MERGE INTO {tname}
@@ -521,6 +497,33 @@ def upsert(con: duckdb.DuckDBPyConnection, tabledef, messages):
521497
con.unregister("df_view")
522498

523499

500+
def upsert(con: duckdb.DuckDBPyConnection, table, messages):
501+
"""Insert/update a record in a table with keywords."""
502+
tname = table.table
503+
fmt = table.format
504+
col_names = columns(table) + list(table.extras.values())
505+
indexes = list(table.keys)
506+
507+
records = []
508+
dataframes = []
509+
510+
for msg in messages:
511+
try:
512+
if fmt == "key_value":
513+
handle_key_value(msg, table, col_names, records, tname)
514+
515+
elif fmt == "dataframe":
516+
handle_dataframe(msg, table, col_names, dataframes, tname)
517+
518+
else:
519+
handle_default(msg, table, col_names, records, tname)
520+
521+
except Exception as e: # pylint: disable=broad-except
522+
logger.error("[upsert] %s: %s", tname, e)
523+
524+
execute_upsert(con, table, col_names, indexes, records, dataframes)
525+
526+
524527
def expand(msg: PubSubMsg):
525528
"""
526529
Given a message `msg` containing a topic string

src/rembus/keyspace.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,28 @@ async def unsubscribe_handler(self, msg):
7474
if el in self.spaces[retopic]:
7575
self.spaces[retopic].remove(el)
7676

77+
async def broadcast(self, topic, msg, space_twins):
78+
"""Brodcast the message to all subscribed spaces."""
79+
to_remove = []
80+
for space_twin in space_twins:
81+
pattern = space_twin.space
82+
logger.debug(
83+
"[keyspace] publish to %s: %s, %s",
84+
pattern,
85+
topic,
86+
msg.data,
87+
)
88+
twid = space_twin.twid
89+
if twid in self.broker.id_twin:
90+
tw = self.broker.id_twin[twid]
91+
if tw.isopen():
92+
await tw.publish(pattern, topic, *msg.data)
93+
else:
94+
# cleanup, the twin has gone
95+
to_remove.append(space_twin)
96+
for el in to_remove:
97+
space_twins.remove(el)
98+
7799
async def publish_interceptor(self, msg):
78100
"""Parse the topic and dispatch to all twins subscribed to spaces
79101
with regex matching the topic"""
@@ -84,35 +106,15 @@ async def publish_interceptor(self, msg):
84106
if m is not None:
85107
unsealed = True
86108

87-
# m.groups() == captures in Julia
88109
for capture in m.groups():
89110
if "@" in capture:
90111
# a regex chunk matched a verbatim chunk → reject
91112
unsealed = False
92113
break
93114

94115
if unsealed:
95-
to_remove = []
96-
# for pattern, twid in space_twins:
97-
for space_twin in space_twins:
98-
pattern = space_twin.space
99-
logger.debug(
100-
"[keyspace] publish to %s: %s, %s",
101-
pattern,
102-
topic,
103-
msg.data,
104-
)
105-
twid = space_twin.twid
106-
if twid in self.broker.id_twin:
107-
tw = self.broker.id_twin[twid]
108-
if tw.isopen():
109-
await tw.publish(pattern, topic, *msg.data)
110-
else:
111-
# cleanup, the twin has gone
112-
to_remove.append(space_twin)
113-
114-
for el in to_remove:
115-
space_twins.remove(el)
116+
await self.broadcast(topic, msg, space_twins)
117+
116118

117119
async def _task_impl(self) -> None:
118120
"""Override in subclasses for supervised task impl."""

0 commit comments

Comments
 (0)