diff --git a/src/matrix_bot/client.py b/src/matrix_bot/client.py index 3fee1d6..ae64e2c 100644 --- a/src/matrix_bot/client.py +++ b/src/matrix_bot/client.py @@ -14,6 +14,11 @@ from typing import ( Union ) from .async_utils import Aobject +from .invite_policy import ( + DeclineAll, + InvitePolicy, + WhiteList +) from .utils import ( Room, RoomAlias, @@ -27,10 +32,6 @@ class Client(Aobject): Connect to the matrix server and handle interactions with the server. - whitelist_rooms: dict of the rooms where the bot is allowed to connect, indexed - by id (the name starting with '!'). If set to None, the bot connect to - all room where it is invited. - /!\ The client is initialized asyncronously: `client = await Client(...)` """ @@ -41,14 +42,14 @@ class Client(Aobject): __sync_token:Optional[str] __sync_token_queue: asyncio.Queue[str] __invite_queue: asyncio.Queue[tuple[RoomId, nio.responses.InviteInfo]] - whitelist_rooms: Optional[dict[RoomId, Room]] + __invite_policy: InvitePolicy async def __init__( self, username: str, homeserver: str, password: str, - whitelist_rooms_names: Optional[list[Union[RoomAlias, RoomId]]]=None, + invite_policy: Optional[InvitePolicy]=None, sync_token_file:str="./.sync_token", ): """ @@ -56,9 +57,7 @@ class Client(Aobject): username: the username used by the bot homeserver: the matrix home server of the bot (expl: "https://matrix.org") password: the password of the user - whitelist_rooms: the list of the rooms where the bot is allowed to connect - (given by room id (expl: '!xxx:matrix.org') of room alias (expl: - '#xxx:matrix.org')) + invite_policy: the policy to apply to invitation, default is to decline all. sync_token_file: the file where is stored the last sync token received from the sync responses. This token avoid reloadind all the history of the bot each time we start it. @@ -77,18 +76,12 @@ class Client(Aobject): self.__sync_token = None self.__sync_token_queue = asyncio.Queue() self.__invite_queue = asyncio.Queue() + self.__invite_policy = invite_policy or DeclineAll() resp = await self.__client.login(password) if isinstance(resp, nio.responses.LoginError): raise RuntimeError(f"Fail to connect: {resp.message}") log.info("logged in") - self.whitelist_rooms = None - if whitelist_rooms_names: - self.whitelist_rooms = {} - rooms = await asyncio.gather(*(self.resolve_room(room_name) for room_name in whitelist_rooms_names)) - for room in rooms: - self.whitelist_rooms[room.id] = room # room uniqueness is handled by self.resolve_room - async def resolve_room( self, room_name: Union[RoomAlias, RoomId] @@ -131,6 +124,15 @@ class Client(Aobject): self.__rooms_by_aliases[room_name] = room return room + def set_invite_policy( + self, + invite_policy: InvitePolicy + ): + """ + Set the invite policy. + """ + self.__invite_policy = invite_policy + async def send_message( self, room: Union[RoomAlias, RoomId], @@ -160,17 +162,7 @@ class Client(Aobject): room_id, invite_info = await self.__invite_queue.get() accept_invite = False - if self.whitelist_rooms is not None: - if room_id not in self.whitelist_rooms: - log.warning(f"Received invite for {room_id}, but room_id is not in the white list.") - else: - accept_invite = True - log.info(f"Received invite for {room_id}: invite accepted.") - else: - accept_invite = True - log.info(f"Received invite for {room_id}: invite accepted.") - - if accept_invite: + if (await self.__invite_policy.accept_invite(room_id, invite_info)): result = await self.__client.join(room_id) if isinstance(result, nio.JoinError): log.warning(f"Error while joinning room {room_id}: {result.message}") diff --git a/src/matrix_bot/invite_policy.py b/src/matrix_bot/invite_policy.py index 4c1588f..03a47b4 100644 --- a/src/matrix_bot/invite_policy.py +++ b/src/matrix_bot/invite_policy.py @@ -1,14 +1,32 @@ +from __future__ import annotations """ InvitePolicy class: InvitePolicy object are use to chose whether to accept or decline an invite to a room. """ +import asyncio +import nio +import logging + from abc import ( ABC, abstractmethod ) -import nio +from typing import ( + Union, + TYPE_CHECKING +) +if TYPE_CHECKING: + from .client import Client +from .utils import ( + Room, + RoomAlias, + RoomId +) +from .async_utils import Aobject + +log = logging.getLogger(__name__) class InvitePolicy(ABC): """ @@ -16,10 +34,16 @@ class InvitePolicy(ABC): """ @abstractmethod - await def accept_invite( + async def accept_invite( self, + room_id: RoomId, invite: nio.responses.InviteInfo )->bool: + """ + Test if the invit must be accepted of declined. + Async because the policy might want to do exotic + stuff, like, idk, send a pm to someone to ask confirmation + """ pass class DeclineAll(InvitePolicy): @@ -27,8 +51,9 @@ class DeclineAll(InvitePolicy): Decline all invitations. """ - await def accept_invite( + async def accept_invite( self, + room_id: RoomId, invite: nio.responses.InviteInfo )->bool: return False @@ -38,8 +63,51 @@ class AcceptAll(InvitePolicy): Accept all invitations. """ - await def accept_invite( + async def accept_invite( self, + room_id: RoomId, invite: nio.responses.InviteInfo )->bool: return True + +class WhiteList(InvitePolicy, Aobject): + """ + Accept invite for whitelisted room. + This policy cannot be set during the initialization of the + client because we need the client to initialize this object: + use `client.set_invite_policy(policy)` + """ + + whitelisted_rooms: dict[RoomId, RoomAlias] + + async def __init__( + self, + client: Client, + whitelisted_rooms_names: list[Union[RoomAlias, RoomId]] + ): + """ + client: the matrix client (cf client.py) + whitelisted_rooms: the list of the rooms where the bot is allowed to connect + (given by room id (expl: '!xxx:matrix.org') or room alias (expl: + '#xxx:matrix.org')) + """ + self.whitelisted_rooms = {} + rooms = await asyncio.gather(*(client.resolve_room(room_name) for room_name in whitelisted_rooms_names)) + for room in rooms: + self.whitelisted_rooms[room.id] = room + + async def accept_invite( + self, + room_id: RoomId, + invite: nio.responses.InviteInfo + )->bool: + """ + Accept invite from whitelisted room, reject the other. + """ + if room_id in self.whitelisted_rooms: + log.info(f"Invite from {room_id} accepted.") + return True + else: + log.warning(f"Invite from {room_id} declined: not in whitelist.") + return False + diff --git a/tests/test.py b/tests/test.py index d24a5b1..335c136 100644 --- a/tests/test.py +++ b/tests/test.py @@ -7,6 +7,11 @@ import logging import os from dotenv import load_dotenv from matrix_bot.client import Client +from matrix_bot.invite_policy import ( + AcceptAll, + DeclineAll, + WhiteList +) from getpass import getpass logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) @@ -17,9 +22,11 @@ async def main(): os.environ["MUSER"], os.environ["HOMESERVER"], os.environ["PASSWD"], - os.environ["ROOMS"].split(",") ) - room_name = os.environ["ROOMS"].split(",")[0] + withelisted_room_names = os.environ["ROOMS"].split(",") + whitelist_policy = await WhiteList(client, withelisted_room_names) + room_name = withelisted_room_names[0] + client.set_invite_policy(whitelist_policy) await client.send_message(room_name, "Beware of Greeks bearing gifts") await client.run()