Skip to content

Commit 829bee5

Browse files
committed
Add support for creator power
1 parent dbf6425 commit 829bee5

File tree

9 files changed

+182
-23
lines changed

9 files changed

+182
-23
lines changed

mautrix/appservice/api/intent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,8 @@ async def _ensure_has_power_level_for(
710710
if not await self.state_store.has_power_levels_cached(room_id):
711711
# TODO add option to not try to fetch power levels from server
712712
await self.get_power_levels(room_id, ignore_cache=True, ensure_joined=False)
713+
if not await self.state_store.has_create_cached(room_id):
714+
await self.get_state_event(room_id, EventType.ROOM_CREATE, format="event")
713715
if not await self.state_store.has_power_level(room_id, self.mxid, event_type):
714716
# TODO implement something better
715717
raise IntentError(

mautrix/client/api/events.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import Awaitable
8+
from typing import Awaitable, Literal, overload
99
import json
1010

1111
from mautrix.api import Method, Path
@@ -168,12 +168,41 @@ async def get_event_context(
168168
)
169169
return EventContext.deserialize(resp)
170170

171+
@overload
171172
async def get_state_event(
172173
self,
173174
room_id: RoomID,
174175
event_type: EventType,
175176
state_key: str = "",
176-
) -> StateEventContent:
177+
*,
178+
format: Literal["content"] = "content",
179+
) -> StateEventContent: ...
180+
@overload
181+
async def get_state_event(
182+
self,
183+
room_id: RoomID,
184+
event_type: EventType,
185+
state_key: str = "",
186+
*,
187+
format: Literal["event"],
188+
) -> StateEvent: ...
189+
@overload
190+
async def get_state_event(
191+
self,
192+
room_id: RoomID,
193+
event_type: EventType,
194+
state_key: str = "",
195+
*,
196+
format: str = "content",
197+
) -> StateEventContent | StateEvent: ...
198+
async def get_state_event(
199+
self,
200+
room_id: RoomID,
201+
event_type: EventType,
202+
state_key: str = "",
203+
*,
204+
format: str = "content",
205+
) -> StateEventContent | StateEvent:
177206
"""
178207
Looks up the contents of a state event in a room. If the user is joined to the room then the
179208
state is taken from the current state of the room. If the user has left the room then the
@@ -185,18 +214,27 @@ async def get_state_event(
185214
room_id: The ID of the room to look up the state in.
186215
event_type: The type of state to look up.
187216
state_key: The key of the state to look up. Defaults to empty string.
217+
format: The format of the state event to return. Defaults to "content", which only returns
218+
the content of the state event. If set to "event", the full event is returned.
219+
See https://github.com/matrix-org/matrix-spec/issues/1047 for more info.
188220
189221
Returns:
190222
The state event.
191223
"""
192224
content = await self.api.request(
193225
Method.GET,
194226
Path.v3.rooms[room_id].state[event_type][state_key],
227+
query_params={"format": format} if format != "content" else None,
195228
metrics_method="getStateEvent",
196229
)
197230
content["__mautrix_event_type"] = event_type
198231
try:
199-
return StateEvent.deserialize_content(content)
232+
if format == "content":
233+
return StateEvent.deserialize_content(content)
234+
elif format == "event":
235+
return StateEvent.deserialize(content)
236+
else:
237+
return content
200238
except SerializerError as e:
201239
raise MatrixResponseError("Invalid state event in response") from e
202240

mautrix/client/state_store/abstract.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@ async def set_power_levels(
121121
) -> None:
122122
pass
123123

124+
@abstractmethod
125+
async def has_create_cached(self, room_id: RoomID) -> bool:
126+
pass
127+
128+
@abstractmethod
129+
async def get_create(self, room_id: RoomID) -> StateEvent | None:
130+
pass
131+
132+
@abstractmethod
133+
async def set_create(self, event: StateEvent) -> None:
134+
pass
135+
124136
@abstractmethod
125137
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
126138
pass
@@ -135,7 +147,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
135147

136148
@abstractmethod
137149
async def set_encryption_info(
138-
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any]
150+
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
139151
) -> None:
140152
pass
141153

@@ -149,6 +161,8 @@ async def update_state(self, evt: StateEvent) -> None:
149161
await self.set_member(evt.room_id, UserID(evt.state_key), evt.content)
150162
elif evt.type == EventType.ROOM_ENCRYPTION:
151163
await self.set_encryption_info(evt.room_id, evt.content)
164+
elif evt.type == EventType.ROOM_CREATE and evt.sender:
165+
await self.set_create(evt)
152166

153167
async def get_membership(self, room_id: RoomID, user_id: UserID) -> Membership:
154168
member = await self.get_member(room_id, user_id)
@@ -172,4 +186,7 @@ async def has_power_level(
172186
room_levels = await self.get_power_levels(room_id)
173187
if not room_levels:
174188
return None
175-
return room_levels.get_user_level(user_id) >= room_levels.get_event_level(event_type)
189+
create_event = await self.get_create(room_id)
190+
return room_levels.get_user_level(user_id, create_event) >= room_levels.get_event_level(
191+
event_type
192+
)

mautrix/client/state_store/asyncpg/store.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
RoomEncryptionStateEventContent,
1717
RoomID,
1818
Serializable,
19+
StateEvent,
1920
UserID,
2021
)
2122
from mautrix.util.async_db import Database, Scheme
@@ -223,6 +224,29 @@ async def set_power_levels(
223224
json.dumps(content.serialize() if isinstance(content, Serializable) else content),
224225
)
225226

227+
async def has_create_cached(self, room_id: RoomID) -> bool:
228+
return bool(
229+
await self.db.fetchval(
230+
"SELECT create_event IS NOT NULL FROM mx_room_state WHERE room_id=$1", room_id
231+
)
232+
)
233+
234+
async def get_create(self, room_id: RoomID) -> StateEvent | None:
235+
create_event_json = await self.db.fetchval(
236+
"SELECT create_event FROM mx_room_state WHERE room_id=$1", room_id
237+
)
238+
if create_event_json is None:
239+
return None
240+
return StateEvent.parse_json(create_event_json)
241+
242+
async def set_create(self, event: StateEvent) -> None:
243+
await self.db.execute(
244+
"INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) "
245+
"ON CONFLICT (room_id) DO UPDATE SET create_event=$2",
246+
event.room_id,
247+
json.dumps(event.serialize() if isinstance(event, Serializable) else event),
248+
)
249+
226250
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
227251
return bool(
228252
await self.db.fetchval(

mautrix/client/state_store/asyncpg/upgrade.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
)
1515

1616

17-
@upgrade_table.register(description="Latest revision", upgrades_to=3)
18-
async def upgrade_blank_to_v3(conn: Connection, scheme: Scheme) -> None:
17+
@upgrade_table.register(description="Latest revision", upgrades_to=4)
18+
async def upgrade_blank_to_v4(conn: Connection, scheme: Scheme) -> None:
1919
await conn.execute(
2020
"""CREATE TABLE mx_room_state (
2121
room_id TEXT PRIMARY KEY,
2222
is_encrypted BOOLEAN,
2323
has_full_member_list BOOLEAN,
2424
encryption TEXT,
25-
power_levels TEXT
25+
power_levels TEXT,
26+
create_event TEXT,
2627
)"""
2728
)
2829
membership_check = ""
@@ -69,3 +70,8 @@ async def upgrade_v3(conn: Connection) -> None:
6970
WHERE mx_room_state.encryption IS NULL
7071
"""
7172
)
73+
74+
75+
@upgrade_table.register(description="Add create event to room state cache")
76+
async def upgrade_v4(conn: Connection) -> None:
77+
await conn.execute("ALTER TABLE mx_room_state ADD COLUMN create_event TYPE TEXT")

mautrix/client/state_store/file.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
PowerLevelStateEventContent,
1616
RoomEncryptionStateEventContent,
1717
RoomID,
18+
StateEvent,
1819
UserID,
1920
)
2021
from mautrix.util.file_store import Filer, FileStore
@@ -65,3 +66,7 @@ async def set_power_levels(
6566
) -> None:
6667
await super().set_power_levels(room_id, content)
6768
self._time_limited_flush()
69+
70+
async def set_create(self, event: StateEvent) -> None:
71+
await super().set_create(event)
72+
self._time_limited_flush()

mautrix/client/state_store/memory.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
PowerLevelStateEventContent,
1515
RoomEncryptionStateEventContent,
1616
RoomID,
17+
StateEvent,
1718
UserID,
1819
)
1920

@@ -25,19 +26,22 @@ class SerializedStateStore(TypedDict):
2526
full_member_list: dict[RoomID, bool]
2627
power_levels: dict[RoomID, Any]
2728
encryption: dict[RoomID, Any]
29+
create: dict[RoomID, Any]
2830

2931

3032
class MemoryStateStore(StateStore):
3133
members: dict[RoomID, dict[UserID, Member]]
3234
full_member_list: dict[RoomID, bool]
3335
power_levels: dict[RoomID, PowerLevelStateEventContent]
3436
encryption: dict[RoomID, RoomEncryptionStateEventContent | None]
37+
create: dict[RoomID, StateEvent]
3538

3639
def __init__(self) -> None:
3740
self.members = {}
3841
self.full_member_list = {}
3942
self.power_levels = {}
4043
self.encryption = {}
44+
self.create = {}
4145

4246
def serialize(self) -> SerializedStateStore:
4347
"""
@@ -58,6 +62,7 @@ def serialize(self) -> SerializedStateStore:
5862
room_id: (content.serialize() if content is not None else None)
5963
for room_id, content in self.encryption.items()
6064
},
65+
"create": {room_id: evt.serialize() for room_id, evt in self.create.items()},
6166
}
6267

6368
def deserialize(self, data: SerializedStateStore) -> None:
@@ -84,6 +89,9 @@ def deserialize(self, data: SerializedStateStore) -> None:
8489
)
8590
for room_id, content in data["encryption"].items()
8691
}
92+
self.create = {
93+
room_id: StateEvent.deserialize(evt) for room_id, evt in data["create"].items()
94+
}
8795

8896
async def get_member(self, room_id: RoomID, user_id: UserID) -> Member | None:
8997
try:
@@ -176,6 +184,17 @@ async def set_power_levels(
176184
content = PowerLevelStateEventContent.deserialize(content)
177185
self.power_levels[room_id] = content
178186

187+
async def has_create_cached(self, room_id: RoomID) -> bool:
188+
return room_id in self.create
189+
190+
async def get_create(self, room_id: RoomID) -> StateEvent | None:
191+
return self.create.get(room_id)
192+
193+
async def set_create(self, event: StateEvent | dict[str, Any]) -> None:
194+
if not isinstance(event, StateEvent):
195+
event = StateEvent.deserialize(event)
196+
self.create[event.room_id] = event
197+
179198
async def has_encryption_info_cached(self, room_id: RoomID) -> bool:
180199
return room_id in self.encryption
181200

mautrix/client/store_updater.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8+
from typing import Literal
89
import asyncio
910

1011
from mautrix.errors import MForbidden, MNotFound
@@ -196,20 +197,28 @@ async def send_state_event(
196197
return event_id
197198

198199
async def get_state_event(
199-
self, room_id: RoomID, event_type: EventType, state_key: str = ""
200-
) -> StateEventContent:
201-
event = await super().get_state_event(room_id, event_type, state_key)
200+
self,
201+
room_id: RoomID,
202+
event_type: EventType,
203+
state_key: str = "",
204+
*,
205+
format: str = "content",
206+
) -> StateEventContent | StateEvent:
207+
event = await super().get_state_event(room_id, event_type, state_key, format=format)
202208
if self.state_store:
203-
fake_event = StateEvent(
204-
type=event_type,
205-
room_id=room_id,
206-
event_id=EventID(""),
207-
sender=UserID(""),
208-
state_key=state_key,
209-
timestamp=0,
210-
content=event,
211-
)
212-
await self.state_store.update_state(fake_event)
209+
if isinstance(event, StateEvent):
210+
await self.state_store.update_state(event)
211+
else:
212+
fake_event = StateEvent(
213+
type=event_type,
214+
room_id=room_id,
215+
event_id=EventID(""),
216+
sender=UserID(""),
217+
state_key=state_key,
218+
timestamp=0,
219+
content=event,
220+
)
221+
await self.state_store.update_state(fake_event)
213222
return event
214223

215224
async def get_joined_members(self, room_id: RoomID) -> dict[UserID, Member]:

mautrix/types/event/state.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,19 @@ class PowerLevelStateEventContent(SerializableAttrs):
4141
ban: int = 50
4242
redact: int = 50
4343

44-
def get_user_level(self, user_id: UserID) -> int:
44+
def get_user_level(
45+
self,
46+
user_id: UserID,
47+
create: Optional["StateEvent"] = None,
48+
) -> int:
49+
if (
50+
create
51+
and create.content.supports_creator_power
52+
and (user_id == create.sender or user_id in (create.content.additional_creators or []))
53+
):
54+
# This is really meant to be infinity, but involving floats would be annoying,
55+
# so we use an integer larger than the maximum power level (2^53-1) instead.
56+
return 2**60 - 1
4557
return int(self.users.get(user_id, self.users_default))
4658

4759
def set_user_level(self, user_id: UserID, level: int) -> None:
@@ -50,7 +62,16 @@ def set_user_level(self, user_id: UserID, level: int) -> None:
5062
else:
5163
self.users[user_id] = level
5264

53-
def ensure_user_level(self, user_id: UserID, level: int) -> bool:
65+
def ensure_user_level(
66+
self, user_id: UserID, level: int, create: Optional["StateEvent"] = None
67+
) -> bool:
68+
if (
69+
create
70+
and create.content.supports_creator_power
71+
and (user_id == create.sender or user_id in (create.content.additional_creators or []))
72+
):
73+
# Don't try to set creator power levels
74+
return False
5475
if self.get_user_level(user_id) != level:
5576
self.set_user_level(user_id, level)
5677
return True
@@ -193,6 +214,24 @@ class RoomCreateStateEventContent(SerializableAttrs):
193214
federate: bool = field(json="m.federate", omit_default=True, default=True)
194215
predecessor: Optional[RoomPredecessor] = None
195216
type: Optional[RoomType] = None
217+
additional_creators: Optional[List[UserID]] = None
218+
219+
@property
220+
def supports_creator_power(self) -> bool:
221+
return self.room_version not in (
222+
"",
223+
"1",
224+
"2",
225+
"3",
226+
"4",
227+
"5",
228+
"6",
229+
"7",
230+
"8",
231+
"9",
232+
"10",
233+
"11",
234+
)
196235

197236

198237
@dataclass

0 commit comments

Comments
 (0)