Stuff for multi-user support

This commit is contained in:
Andrew Ferrazzutti 2021-02-28 01:50:14 -05:00
parent 6766ef5d55
commit 3438179d48
2 changed files with 41 additions and 13 deletions

View File

@ -46,3 +46,12 @@ async def upgrade_v1(conn: Connection) -> None:
UNIQUE (mxid, mx_room) UNIQUE (mxid, mx_room)
)""") )""")
@upgrade_table.register(description="Multi-user support")
async def upgrade_v2(conn: Connection) -> None:
mid_exists = await conn.fetchval(
'SELECT EXISTS(SELECT FROM information_schema.columns '
'WHERE table_name="user" AND column_name="mid"')
if not mid_exists:
await conn.execute("ALTER TABLE 'user' ADD COLUMN mid TEXT")

View File

@ -13,8 +13,9 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, ClassVar, TYPE_CHECKING from typing import Optional, List, TYPE_CHECKING, ClassVar
from asyncpg import Record
from attr import dataclass from attr import dataclass
from mautrix.types import UserID, RoomID from mautrix.types import UserID, RoomID
@ -28,22 +29,40 @@ class User:
db: ClassVar[Database] = fake_db db: ClassVar[Database] = fake_db
mxid: UserID mxid: UserID
mid: Optional[str]
notice_room: Optional[RoomID] notice_room: Optional[RoomID]
async def insert(self) -> None: @classmethod
q = ('INSERT INTO "user" (mxid, notice_room) ' def _from_row(cls, row: Optional[Record]) -> Optional['User']:
'VALUES ($1, $2)') if row is None:
await self.db.execute(q, self.mxid, self.notice_room) return None
return cls(**row)
async def update(self) -> None: @classmethod
await self.db.execute('UPDATE "user" SET notice_room=$2 WHERE mxid=$1', async def all_logged_in(cls) -> List['User']:
self.mxid, self.notice_room) rows = await cls.db.fetch('SELECT mxid, mid, notice_room FROM "user" '
"WHERE mid IS NOT NULL")
return [cls._from_row(row) for row in rows]
@classmethod
async def get_by_mid(cls, mid: str) -> Optional['User']:
q = 'SELECT mxid, mid, notice_room FROM "user" WHERE mid=$1'
row = await cls.db.fetchrow(q, mid)
return cls._from_row(row)
@classmethod @classmethod
async def get_by_mxid(cls, mxid: UserID) -> Optional['User']: async def get_by_mxid(cls, mxid: UserID) -> Optional['User']:
q = ("SELECT mxid, notice_room " q = 'SELECT mxid, mid, notice_room FROM "user" WHERE mxid=$1'
'FROM "user" WHERE mxid=$1')
row = await cls.db.fetchrow(q, mxid) row = await cls.db.fetchrow(q, mxid)
if not row: return cls._from_row(row)
return None
return cls(**row) async def insert(self) -> None:
q = 'INSERT INTO "user" (mxid, mid, notice_room) VALUES ($1, $2, $3)'
await self.db.execute(q, self.mxid, self.mid, self.notice_room)
async def delete(self) -> None:
await self.db.execute('DELETE FROM "user" WHERE mxid=$1', self.mxid)
async def update(self) -> None:
await self.db.execute('UPDATE "user" SET mid=$2, notice_room=$3 WHERE mxid=$1',
self.mxid, self.mid, self.notice_room)