Compare commits

..

4 Commits

9 changed files with 110 additions and 54 deletions

View File

@ -49,7 +49,7 @@ from ..types.request import (
CommandResultDoneValue CommandResultDoneValue
) )
from .types import PortalChannelInfo, UserInfoUnion from .types import PortalChannelInfo, UserInfoUnion, ChannelProps
from .errors import InvalidAccessToken from .errors import InvalidAccessToken
from .error_helper import raise_unsuccessful_response from .error_helper import raise_unsuccessful_response
@ -109,7 +109,7 @@ class Client:
Obtain a session token by logging in with user-provided credentials. Obtain a session token by logging in with user-provided credentials.
Must have first called register_device with these credentials. Must have first called register_device with these credentials.
""" """
# NOTE Actually returns a LoginData object, but this only needs an OAuthCredential # NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req) return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)
# endregion # endregion
@ -207,34 +207,34 @@ class Client:
) )
return profile_req_struct.profile return profile_req_struct.profile
async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo: async def get_portal_channel_info(self, channel_props: ChannelProps) -> PortalChannelInfo:
return await self._api_user_request_result( return await self._api_user_request_result(
PortalChannelInfo, PortalChannelInfo,
"get_portal_channel_info", "get_portal_channel_info",
channel_id=channel_id.serialize() channel_props=channel_props.serialize(),
) )
async def get_participants(self, channel_id: Long) -> list[UserInfoUnion]: async def get_participants(self, channel_props: ChannelProps) -> list[UserInfoUnion]:
return await self._api_user_request_result( return await self._api_user_request_result(
ResultListType(UserInfoUnion), ResultListType(UserInfoUnion),
"get_participants", "get_participants",
channel_id=channel_id.serialize() channel_props=channel_props.serialize()
) )
async def get_chats(self, channel_id: Long, sync_from: Long | None, limit: int | None) -> list[Chatlog]: async def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> list[Chatlog]:
return await self._api_user_request_result( return await self._api_user_request_result(
ResultListType(Chatlog), ResultListType(Chatlog),
"get_chats", "get_chats",
channel_id=channel_id.serialize(), channel_props=channel_props.serialize(),
sync_from=sync_from.serialize() if sync_from else None, sync_from=sync_from.serialize() if sync_from else None,
limit=limit limit=limit
) )
async def send_message(self, channel_id: Long, text: str) -> Chatlog: async def send_message(self, channel_props: ChannelProps, text: str) -> Chatlog:
return await self._api_user_request_result( return await self._api_user_request_result(
Chatlog, Chatlog,
"send_message", "send_message",
channel_id=channel_id.serialize(), channel_props=channel_props.serialize(),
text=text text=text
) )

View File

@ -21,7 +21,9 @@ from attr import dataclass
from mautrix.types import SerializableAttrs, JSON, deserializer from mautrix.types import SerializableAttrs, JSON, deserializer
from ..types.bson import Long
from ..types.channel.channel_info import NormalChannelInfo from ..types.channel.channel_info import NormalChannelInfo
from ..types.channel.channel_type import ChannelType
from ..types.openlink.open_channel_info import OpenChannelInfo from ..types.openlink.open_channel_info import OpenChannelInfo
from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo
@ -56,3 +58,9 @@ class PortalChannelInfo(SerializableAttrs):
participants: list[UserInfoUnion] participants: list[UserInfoUnion]
# TODO Image # TODO Image
channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller
@dataclass
class ChannelProps(SerializableAttrs):
id: Long
type: ChannelType

View File

@ -22,8 +22,7 @@ from ..oauth import OAuthCredential
@dataclass @dataclass
class AuthLoginData(OAuthCredential): class LoginData(OAuthCredential):
"""aka LoginData"""
countryIso: str countryIso: str
countryCode: str countryCode: str
accountId: int accountId: int
@ -94,7 +93,7 @@ class KnownAuthStatusCode(IntEnum):
__all__ = [ __all__ = [
"AuthLoginData", "LoginData",
"LoginForm", "LoginForm",
"TokenLoginForm", "TokenLoginForm",
"KnownAuthStatusCode", "KnownAuthStatusCode",

View File

@ -67,8 +67,7 @@ class ChannelData(SerializableAttrs, Generic[T]):
@dataclass @dataclass
class ChannelLoginData(SerializableAttrs, Generic[T]): class LoginData(SerializableAttrs, Generic[T]):
"""aka non-auth LoginData"""
lastUpdate: int lastUpdate: int
channel: T channel: T
@ -88,6 +87,6 @@ __all__ = [
"ChannelInfo", "ChannelInfo",
"NormalChannelInfo", "NormalChannelInfo",
"ChannelData", "ChannelData",
"ChannelLoginData", "LoginData",
"NormalChannelData", "NormalChannelData",
] ]

View File

@ -20,28 +20,28 @@ from attr import dataclass
from mautrix.types import SerializableAttrs, JSON, deserializer from mautrix.types import SerializableAttrs, JSON, deserializer
from ..bson import Long from ..bson import Long
from ..channel.channel_info import ChannelLoginData, NormalChannelData from ..channel.channel_info import LoginData, NormalChannelData
from ..openlink.open_channel_info import OpenChannelData from ..openlink.open_channel_info import OpenChannelData
ChannelLoginDataItem = NewType("ChannelLoginDataItem", ChannelLoginData[Union[NormalChannelData, OpenChannelData]]) LoginDataItem = NewType("LoginDataItem", LoginData[Union[NormalChannelData, OpenChannelData]])
@deserializer(ChannelLoginDataItem) @deserializer(LoginDataItem)
def deserialize_channel_login_data_item(data: JSON) -> ChannelLoginDataItem: def deserialize_channel_login_data_item(data: JSON) -> LoginDataItem:
channel_data = data["channel"] channel_data = data["channel"]
if "linkId" in channel_data: if "linkId" in channel_data:
data["channel"] = OpenChannelData.deserialize(channel_data) data["channel"] = OpenChannelData.deserialize(channel_data)
else: else:
data["channel"] = NormalChannelData.deserialize(channel_data) data["channel"] = NormalChannelData.deserialize(channel_data)
return ChannelLoginData.deserialize(data) return LoginData.deserialize(data)
setattr(ChannelLoginDataItem, "deserialize", deserialize_channel_login_data_item) setattr(LoginDataItem, "deserialize", deserialize_channel_login_data_item)
@dataclass @dataclass
class LoginResult(SerializableAttrs): class LoginResult(SerializableAttrs):
"""Return value of TalkClient.login""" """Return value of TalkClient.login"""
channelList: list[ChannelLoginDataItem] channelList: list[LoginDataItem]
userId: Long userId: Long
lastChannelId: Long lastChannelId: Long
lastTokenId: Long lastTokenId: Long

View File

@ -52,7 +52,7 @@ from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.chat.chat import Chatlog from .kt.types.chat.chat import Chatlog
from .kt.client.types import UserInfoUnion, PortalChannelInfo from .kt.client.types import UserInfoUnion, PortalChannelInfo, ChannelProps
from .kt.client.errors import CommandException from .kt.client.errors import CommandException
if TYPE_CHECKING: if TYPE_CHECKING:
@ -194,6 +194,13 @@ class Portal(DBPortal, BasePortal):
raise ValueError(f"Non-direct chat portal should have no sender, but has sender {self._kt_sender}") raise ValueError(f"Non-direct chat portal should have no sender, but has sender {self._kt_sender}")
return self._kt_sender return self._kt_sender
@property
def channel_props(self) -> ChannelProps:
return ChannelProps(
id=self.ktid,
type=self.kt_type
)
@property @property
def main_intent(self) -> IntentAPI: def main_intent(self) -> IntentAPI:
if not self._main_intent: if not self._main_intent:
@ -245,7 +252,7 @@ class Portal(DBPortal, BasePortal):
) -> PortalChannelInfo: ) -> PortalChannelInfo:
if not info: if not info:
self.log.debug("Called update_info with no info, fetching channel info...") self.log.debug("Called update_info with no info, fetching channel info...")
info = await source.client.get_portal_channel_info(self.ktid) info = await source.client.get_portal_channel_info(self.channel_props)
changed = False changed = False
if not self.is_direct: if not self.is_direct:
changed = any( changed = any(
@ -400,7 +407,7 @@ class Portal(DBPortal, BasePortal):
async def _update_participants(self, source: u.User, participants: list[UserInfoUnion] | None = None) -> bool: async def _update_participants(self, source: u.User, participants: list[UserInfoUnion] | None = None) -> bool:
if participants is None: if participants is None:
self.log.debug("Called _update_participants with no participants, fetching them now...") self.log.debug("Called _update_participants with no participants, fetching them now...")
participants = await source.client.get_participants(self.ktid) participants = await source.client.get_participants(self.channel_props)
changed = False changed = False
if not self._main_intent: if not self._main_intent:
assert self.is_direct, "_main_intent for non-direct chat portal should have been set already" assert self.is_direct, "_main_intent for non-direct chat portal should have been set already"
@ -736,7 +743,7 @@ class Portal(DBPortal, BasePortal):
converted = await matrix_to_kakaotalk(message, self.mxid, self.log) converted = await matrix_to_kakaotalk(message, self.mxid, self.log)
try: try:
chatlog = await sender.client.send_message( chatlog = await sender.client.send_message(
self.ktid, self.channel_props,
text=converted.text, text=converted.text,
# TODO # TODO
#mentions=converted.mentions, #mentions=converted.mentions,
@ -959,7 +966,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug(f"Backfilling history through {source.mxid}") self.log.debug(f"Backfilling history through {source.mxid}")
self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}") self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
messages = await source.client.get_chats( messages = await source.client.get_chats(
channel_info.channelId, self.channel_props,
after_log_id, after_log_id,
limit limit
) )
@ -989,7 +996,6 @@ class Portal(DBPortal, BasePortal):
# TODO Save kt_sender in DB instead? Depends on if DM channels are shared... # TODO Save kt_sender in DB instead? Depends on if DM channels are shared...
user = await u.User.get_by_ktid(self.kt_receiver) user = await u.User.get_by_ktid(self.kt_receiver)
assert user, f"Found no user for this portal's receiver of {self.kt_receiver}" assert user, f"Found no user for this portal's receiver of {self.kt_receiver}"
# TODO Should this backfill? Useful for forgotten channels
await self._update_participants(user) await self._update_participants(user)
else: else:
self.log.debug("Not setting _main_intent of new direct chat until after checking participant list") self.log.debug("Not setting _main_intent of new direct chat until after checking participant list")

View File

@ -83,7 +83,7 @@ class RPCClient:
await self.request("register", peer_id=self.config["appservice.address"]) await self.request("register", peer_id=self.config["appservice.address"])
async def disconnect(self) -> None: async def disconnect(self) -> None:
assert self._writer is not None if self._writer is not None:
self._writer.write_eof() self._writer.write_eof()
await self._writer.drain() await self._writer.drain()
self._writer = None self._writer = None

View File

@ -42,7 +42,7 @@ from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult from .kt.types.client.client_session import LoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
@ -449,19 +449,19 @@ class User(DBUser, BaseUser):
sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels) sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING) await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
self.log.debug(f"Syncing {sync_count} of {num_channels} channels...") self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
for channel_item in login_result.channelList[:sync_count]: for login_data in login_result.channelList[:sync_count]:
try: try:
await self._sync_channel(channel_item) await self._sync_channel(login_data)
except AuthenticationRequired: except AuthenticationRequired:
raise raise
except Exception: except Exception:
self.log.exception(f"Failed to sync channel {channel_item.channel.channelId}") self.log.exception(f"Failed to sync channel {login_data.channel.channelId}")
await self.update_direct_chats() await self.update_direct_chats()
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None: async def _sync_channel(self, login_data: LoginDataItem) -> None:
channel_data = channel_item.channel channel_data = login_data.channel
self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {channel_item.lastUpdate})") self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {login_data.lastUpdate})")
channel_info = channel_data.info channel_info = channel_data.info
if isinstance(channel_data, NormalChannelData): if isinstance(channel_data, NormalChannelData):
channel_data: NormalChannelData channel_data: NormalChannelData
@ -499,7 +499,7 @@ class User(DBUser, BaseUser):
kt_receiver=self.ktid, kt_receiver=self.ktid,
kt_type=channel_info.type kt_type=channel_info.type
) )
portal_info = await self.client.get_portal_channel_info(channel_info.channelId) portal_info = await self.client.get_portal_channel_info(portal.channel_props)
portal_info.channel_info = channel_info portal_info.channel_info = channel_info
if not portal.mxid: if not portal.mxid:
await portal.create_matrix_room(self, portal_info) await portal.create_matrix_room(self, portal_info)

View File

@ -14,6 +14,7 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>. // along with this program. If not, see <https://www.gnu.org/licenses/>.
import { Long } from "bson" import { Long } from "bson"
import { import {
AuthApiClient, AuthApiClient,
OAuthApiClient, OAuthApiClient,
@ -22,10 +23,12 @@ import {
KnownAuthStatusCode, KnownAuthStatusCode,
util, util,
} from "node-kakao" } from "node-kakao"
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("node-kakao/dist/talk").TalkChannelList} TalkChannelList */
/** @typedef {import("node-kakao").ChannelType} ChannelType */
import chat from "node-kakao/chat" import chat from "node-kakao/chat"
const { KnownChatType } = chat const { KnownChatType } = chat
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("./clientmanager.js").default} ClientManager} */
import { emitLines, promisify } from "./util.js" import { emitLines, promisify } from "./util.js"
@ -67,6 +70,30 @@ class UserClient {
return userClient return userClient
} }
/**
* @param {Object} channel_props
* @param {Long} channel_props.id
* @param {ChannelType} channel_props.type
*/
async getChannel(channel_props) {
let channel = this.#talkClient.channelList.get(channel_props.id)
if (channel) {
return channel
} else {
const channelList = getChannelListForType(
this.#talkClient.channelList,
channel_props.type
)
const res = await channelList.addChannel({
channelId: channel_props.id,
})
if (!res.success) {
throw new Error(`Unable to add ${channel_props.type} channel ${channel_props.id}`)
}
return res.result
}
}
close() { close() {
this.#talkClient.close() this.#talkClient.close()
} }
@ -82,20 +109,21 @@ class UserClient {
export default class PeerClient { export default class PeerClient {
/** /**
* @param {ClientManager} manager * @param {import("./clientmanager.js").default} manager
* @param {import("net").Socket} socket * @param {import("net").Socket} socket
* @param {number} connID * @param {number} connID
* @param {Map<string, UserClient>} userClients
*/ */
constructor(manager, socket, connID) { constructor(manager, socket, connID) {
this.manager = manager this.manager = manager
this.socket = socket this.socket = socket
this.connID = connID this.connID = connID
this.stopped = false this.stopped = false
this.notificationID = 0 this.notificationID = 0
this.maxCommandID = 0 this.maxCommandID = 0
this.peerID = null this.peerID = null
/** @type {Map<string, UserClient>} */
this.userClients = new Map() this.userClients = new Map()
} }
@ -335,11 +363,11 @@ export default class PeerClient {
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
* @param {Long} req.channel_id * @param {Object} req.channel_props
*/ */
getPortalChannelInfo = async (req) => { getPortalChannelInfo = async (req) => {
const userClient = this.#getUser(req.mxid) const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id) const talkChannel = await userClient.getChannel(req.channel_props)
const res = await talkChannel.updateAll() const res = await talkChannel.updateAll()
if (!res.success) return res if (!res.success) return res
@ -354,24 +382,24 @@ export default class PeerClient {
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
* @param {Long} req.channel_id * @param {Object} req.channel_props
*/ */
getParticipants = async (req) => { getParticipants = async (req) => {
const userClient = this.#getUser(req.mxid) const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.getChannel(req.channel_id) const talkChannel = await userClient.getChannel(req.channel_props)
return await talkChannel.getAllLatestUserInfo() return await talkChannel.getAllLatestUserInfo()
} }
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
* @param {Long} req.channel_id * @param {Object} req.channel_props
* @param {Long?} req.sync_from * @param {Long?} req.sync_from
* @param {Number?} req.limit * @param {Number?} req.limit
*/ */
getChats = async (req) => { getChats = async (req) => {
const userClient = this.#getUser(req.mxid) const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id) const talkChannel = await userClient.getChannel(req.channel_props)
const res = await talkChannel.getChatListFrom(req.sync_from) const res = await talkChannel.getChatListFrom(req.sync_from)
if (res.success && 0 < req.limit && req.limit < res.result.length) { if (res.success && 0 < req.limit && req.limit < res.result.length) {
@ -383,12 +411,12 @@ export default class PeerClient {
/** /**
* @param {Object} req * @param {Object} req
* @param {string} req.mxid * @param {string} req.mxid
* @param {Long} req.channel_id * @param {Object} req.channel_props
* @param {string} req.text * @param {string} req.text
*/ */
sendMessage = async (req) => { sendMessage = async (req) => {
const userClient = this.#getUser(req.mxid) const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id) const talkChannel = await userClient.getChannel(req.channel_props)
return await talkChannel.sendChat({ return await talkChannel.sendChat({
type: KnownChatType.TEXT, type: KnownChatType.TEXT,
@ -444,6 +472,7 @@ export default class PeerClient {
this.log("Ignoring old request", req.id) this.log("Ignoring old request", req.id)
return return
} }
this.log("Received request", req.id, "with command", req.command)
this.maxCommandID = req.id this.maxCommandID = req.id
let handler let handler
if (!this.peerID) { if (!this.peerID) {
@ -514,3 +543,18 @@ export default class PeerClient {
return value return value
} }
} }
/**
* @param {TalkChannelList} channelList
* @param {ChannelType} channelType
*/
function getChannelListForType(channelList, channelType) {
switch (channelType) {
case "OM":
case "OD":
return channelList.open
default:
return channelList.normal
}
}