use a modular invite policy

master
histausse 3 years ago
parent 19cc22a479
commit 89120454ad
Signed by: histausse
GPG Key ID: 67486F107F62E9E9

@ -14,6 +14,11 @@ from typing import (
Union Union
) )
from .async_utils import Aobject from .async_utils import Aobject
from .invite_policy import (
DeclineAll,
InvitePolicy,
WhiteList
)
from .utils import ( from .utils import (
Room, Room,
RoomAlias, RoomAlias,
@ -27,10 +32,6 @@ class Client(Aobject):
Connect to the matrix server and handle interactions with the Connect to the matrix server and handle interactions with the
server. server.
whitelist_rooms: dict of the rooms where the bot is allowed to connect, indexed
by id (the name starting with '!'). If set to None, the bot connect to
all room where it is invited.
/!\ The client is initialized asyncronously: `client = await Client(...)` /!\ The client is initialized asyncronously: `client = await Client(...)`
""" """
@ -41,14 +42,14 @@ class Client(Aobject):
__sync_token:Optional[str] __sync_token:Optional[str]
__sync_token_queue: asyncio.Queue[str] __sync_token_queue: asyncio.Queue[str]
__invite_queue: asyncio.Queue[tuple[RoomId, nio.responses.InviteInfo]] __invite_queue: asyncio.Queue[tuple[RoomId, nio.responses.InviteInfo]]
whitelist_rooms: Optional[dict[RoomId, Room]] __invite_policy: InvitePolicy
async def __init__( async def __init__(
self, self,
username: str, username: str,
homeserver: str, homeserver: str,
password: str, password: str,
whitelist_rooms_names: Optional[list[Union[RoomAlias, RoomId]]]=None, invite_policy: Optional[InvitePolicy]=None,
sync_token_file:str="./.sync_token", sync_token_file:str="./.sync_token",
): ):
""" """
@ -56,9 +57,7 @@ class Client(Aobject):
username: the username used by the bot username: the username used by the bot
homeserver: the matrix home server of the bot (expl: "https://matrix.org") homeserver: the matrix home server of the bot (expl: "https://matrix.org")
password: the password of the user password: the password of the user
whitelist_rooms: the list of the rooms where the bot is allowed to connect invite_policy: the policy to apply to invitation, default is to decline all.
(given by room id (expl: '!xxx:matrix.org') of room alias (expl:
'#xxx:matrix.org'))
sync_token_file: the file where is stored the last sync token received from sync_token_file: the file where is stored the last sync token received from
the sync responses. This token avoid reloadind all the history of the the sync responses. This token avoid reloadind all the history of the
bot each time we start it. bot each time we start it.
@ -77,18 +76,12 @@ class Client(Aobject):
self.__sync_token = None self.__sync_token = None
self.__sync_token_queue = asyncio.Queue() self.__sync_token_queue = asyncio.Queue()
self.__invite_queue = asyncio.Queue() self.__invite_queue = asyncio.Queue()
self.__invite_policy = invite_policy or DeclineAll()
resp = await self.__client.login(password) resp = await self.__client.login(password)
if isinstance(resp, nio.responses.LoginError): if isinstance(resp, nio.responses.LoginError):
raise RuntimeError(f"Fail to connect: {resp.message}") raise RuntimeError(f"Fail to connect: {resp.message}")
log.info("logged in") log.info("logged in")
self.whitelist_rooms = None
if whitelist_rooms_names:
self.whitelist_rooms = {}
rooms = await asyncio.gather(*(self.resolve_room(room_name) for room_name in whitelist_rooms_names))
for room in rooms:
self.whitelist_rooms[room.id] = room # room uniqueness is handled by self.resolve_room
async def resolve_room( async def resolve_room(
self, self,
room_name: Union[RoomAlias, RoomId] room_name: Union[RoomAlias, RoomId]
@ -131,6 +124,15 @@ class Client(Aobject):
self.__rooms_by_aliases[room_name] = room self.__rooms_by_aliases[room_name] = room
return room return room
def set_invite_policy(
self,
invite_policy: InvitePolicy
):
"""
Set the invite policy.
"""
self.__invite_policy = invite_policy
async def send_message( async def send_message(
self, self,
room: Union[RoomAlias, RoomId], room: Union[RoomAlias, RoomId],
@ -160,17 +162,7 @@ class Client(Aobject):
room_id, invite_info = await self.__invite_queue.get() room_id, invite_info = await self.__invite_queue.get()
accept_invite = False accept_invite = False
if self.whitelist_rooms is not None: if (await self.__invite_policy.accept_invite(room_id, invite_info)):
if room_id not in self.whitelist_rooms:
log.warning(f"Received invite for {room_id}, but room_id is not in the white list.")
else:
accept_invite = True
log.info(f"Received invite for {room_id}: invite accepted.")
else:
accept_invite = True
log.info(f"Received invite for {room_id}: invite accepted.")
if accept_invite:
result = await self.__client.join(room_id) result = await self.__client.join(room_id)
if isinstance(result, nio.JoinError): if isinstance(result, nio.JoinError):
log.warning(f"Error while joinning room {room_id}: {result.message}") log.warning(f"Error while joinning room {room_id}: {result.message}")

@ -1,14 +1,32 @@
from __future__ import annotations
""" """
InvitePolicy class: InvitePolicy class:
InvitePolicy object are use to chose whether to accept or decline InvitePolicy object are use to chose whether to accept or decline
an invite to a room. an invite to a room.
""" """
import asyncio
import nio
import logging
from abc import ( from abc import (
ABC, ABC,
abstractmethod abstractmethod
) )
import nio from typing import (
Union,
TYPE_CHECKING
)
if TYPE_CHECKING:
from .client import Client
from .utils import (
Room,
RoomAlias,
RoomId
)
from .async_utils import Aobject
log = logging.getLogger(__name__)
class InvitePolicy(ABC): class InvitePolicy(ABC):
""" """
@ -16,10 +34,16 @@ class InvitePolicy(ABC):
""" """
@abstractmethod @abstractmethod
await def accept_invite( async def accept_invite(
self, self,
room_id: RoomId,
invite: nio.responses.InviteInfo invite: nio.responses.InviteInfo
)->bool: )->bool:
"""
Test if the invit must be accepted of declined.
Async because the policy might want to do exotic
stuff, like, idk, send a pm to someone to ask confirmation
"""
pass pass
class DeclineAll(InvitePolicy): class DeclineAll(InvitePolicy):
@ -27,8 +51,9 @@ class DeclineAll(InvitePolicy):
Decline all invitations. Decline all invitations.
""" """
await def accept_invite( async def accept_invite(
self, self,
room_id: RoomId,
invite: nio.responses.InviteInfo invite: nio.responses.InviteInfo
)->bool: )->bool:
return False return False
@ -38,8 +63,51 @@ class AcceptAll(InvitePolicy):
Accept all invitations. Accept all invitations.
""" """
await def accept_invite( async def accept_invite(
self, self,
room_id: RoomId,
invite: nio.responses.InviteInfo invite: nio.responses.InviteInfo
)->bool: )->bool:
return True return True
class WhiteList(InvitePolicy, Aobject):
"""
Accept invite for whitelisted room.
This policy cannot be set during the initialization of the
client because we need the client to initialize this object:
use `client.set_invite_policy(policy)`
"""
whitelisted_rooms: dict[RoomId, RoomAlias]
async def __init__(
self,
client: Client,
whitelisted_rooms_names: list[Union[RoomAlias, RoomId]]
):
"""
client: the matrix client (cf client.py)
whitelisted_rooms: the list of the rooms where the bot is allowed to connect
(given by room id (expl: '!xxx:matrix.org') or room alias (expl:
'#xxx:matrix.org'))
"""
self.whitelisted_rooms = {}
rooms = await asyncio.gather(*(client.resolve_room(room_name) for room_name in whitelisted_rooms_names))
for room in rooms:
self.whitelisted_rooms[room.id] = room
async def accept_invite(
self,
room_id: RoomId,
invite: nio.responses.InviteInfo
)->bool:
"""
Accept invite from whitelisted room, reject the other.
"""
if room_id in self.whitelisted_rooms:
log.info(f"Invite from {room_id} accepted.")
return True
else:
log.warning(f"Invite from {room_id} declined: not in whitelist.")
return False

@ -7,6 +7,11 @@ import logging
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from matrix_bot.client import Client from matrix_bot.client import Client
from matrix_bot.invite_policy import (
AcceptAll,
DeclineAll,
WhiteList
)
from getpass import getpass from getpass import getpass
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
@ -17,9 +22,11 @@ async def main():
os.environ["MUSER"], os.environ["MUSER"],
os.environ["HOMESERVER"], os.environ["HOMESERVER"],
os.environ["PASSWD"], os.environ["PASSWD"],
os.environ["ROOMS"].split(",")
) )
room_name = os.environ["ROOMS"].split(",")[0] withelisted_room_names = os.environ["ROOMS"].split(",")
whitelist_policy = await WhiteList(client, withelisted_room_names)
room_name = withelisted_room_names[0]
client.set_invite_policy(whitelist_policy)
await client.send_message(room_name, "Beware of Greeks bearing gifts") await client.send_message(room_name, "Beware of Greeks bearing gifts")
await client.run() await client.run()

Loading…
Cancel
Save