Compare commits

...

4 Commits

9 changed files with 110 additions and 54 deletions

View File

@ -49,7 +49,7 @@ from ..types.request import (
CommandResultDoneValue
)
from .types import PortalChannelInfo, UserInfoUnion
from .types import PortalChannelInfo, UserInfoUnion, ChannelProps
from .errors import InvalidAccessToken
from .error_helper import raise_unsuccessful_response
@ -109,7 +109,7 @@ class Client:
Obtain a session token by logging in with user-provided credentials.
Must have first called register_device with these credentials.
"""
# NOTE Actually returns a LoginData object, but this only needs an OAuthCredential
# NOTE Actually returns an auth LoginData, but this only needs an OAuthCredential
return await cls._api_request_result(OAuthCredential, "login", is_secret=True, **req)
# endregion
@ -207,34 +207,34 @@ class Client:
)
return profile_req_struct.profile
async def get_portal_channel_info(self, channel_id: Long) -> PortalChannelInfo:
async def get_portal_channel_info(self, channel_props: ChannelProps) -> PortalChannelInfo:
return await self._api_user_request_result(
PortalChannelInfo,
"get_portal_channel_info",
channel_id=channel_id.serialize()
channel_props=channel_props.serialize(),
)
async def get_participants(self, channel_id: Long) -> list[UserInfoUnion]:
async def get_participants(self, channel_props: ChannelProps) -> list[UserInfoUnion]:
return await self._api_user_request_result(
ResultListType(UserInfoUnion),
"get_participants",
channel_id=channel_id.serialize()
channel_props=channel_props.serialize()
)
async def get_chats(self, channel_id: Long, sync_from: Long | None, limit: int | None) -> list[Chatlog]:
async def get_chats(self, channel_props: ChannelProps, sync_from: Long | None, limit: int | None) -> list[Chatlog]:
return await self._api_user_request_result(
ResultListType(Chatlog),
"get_chats",
channel_id=channel_id.serialize(),
channel_props=channel_props.serialize(),
sync_from=sync_from.serialize() if sync_from else None,
limit=limit
)
async def send_message(self, channel_id: Long, text: str) -> Chatlog:
async def send_message(self, channel_props: ChannelProps, text: str) -> Chatlog:
return await self._api_user_request_result(
Chatlog,
"send_message",
channel_id=channel_id.serialize(),
channel_props=channel_props.serialize(),
text=text
)

View File

@ -21,7 +21,9 @@ from attr import dataclass
from mautrix.types import SerializableAttrs, JSON, deserializer
from ..types.bson import Long
from ..types.channel.channel_info import NormalChannelInfo
from ..types.channel.channel_type import ChannelType
from ..types.openlink.open_channel_info import OpenChannelInfo
from ..types.user.channel_user_info import NormalChannelUserInfo, OpenChannelUserInfo
@ -56,3 +58,9 @@ class PortalChannelInfo(SerializableAttrs):
participants: list[UserInfoUnion]
# TODO Image
channel_info: Optional[ChannelInfoUnion] = None # Should be set manually by caller
@dataclass
class ChannelProps(SerializableAttrs):
id: Long
type: ChannelType

View File

@ -22,8 +22,7 @@ from ..oauth import OAuthCredential
@dataclass
class AuthLoginData(OAuthCredential):
"""aka LoginData"""
class LoginData(OAuthCredential):
countryIso: str
countryCode: str
accountId: int
@ -94,7 +93,7 @@ class KnownAuthStatusCode(IntEnum):
__all__ = [
"AuthLoginData",
"LoginData",
"LoginForm",
"TokenLoginForm",
"KnownAuthStatusCode",

View File

@ -67,8 +67,7 @@ class ChannelData(SerializableAttrs, Generic[T]):
@dataclass
class ChannelLoginData(SerializableAttrs, Generic[T]):
"""aka non-auth LoginData"""
class LoginData(SerializableAttrs, Generic[T]):
lastUpdate: int
channel: T
@ -88,6 +87,6 @@ __all__ = [
"ChannelInfo",
"NormalChannelInfo",
"ChannelData",
"ChannelLoginData",
"LoginData",
"NormalChannelData",
]

View File

@ -20,28 +20,28 @@ from attr import dataclass
from mautrix.types import SerializableAttrs, JSON, deserializer
from ..bson import Long
from ..channel.channel_info import ChannelLoginData, NormalChannelData
from ..channel.channel_info import LoginData, NormalChannelData
from ..openlink.open_channel_info import OpenChannelData
ChannelLoginDataItem = NewType("ChannelLoginDataItem", ChannelLoginData[Union[NormalChannelData, OpenChannelData]])
LoginDataItem = NewType("LoginDataItem", LoginData[Union[NormalChannelData, OpenChannelData]])
@deserializer(ChannelLoginDataItem)
def deserialize_channel_login_data_item(data: JSON) -> ChannelLoginDataItem:
@deserializer(LoginDataItem)
def deserialize_channel_login_data_item(data: JSON) -> LoginDataItem:
channel_data = data["channel"]
if "linkId" in channel_data:
data["channel"] = OpenChannelData.deserialize(channel_data)
else:
data["channel"] = NormalChannelData.deserialize(channel_data)
return ChannelLoginData.deserialize(data)
return LoginData.deserialize(data)
setattr(ChannelLoginDataItem, "deserialize", deserialize_channel_login_data_item)
setattr(LoginDataItem, "deserialize", deserialize_channel_login_data_item)
@dataclass
class LoginResult(SerializableAttrs):
"""Return value of TalkClient.login"""
channelList: list[ChannelLoginDataItem]
channelList: list[LoginDataItem]
userId: Long
lastChannelId: Long
lastTokenId: Long

View File

@ -52,7 +52,7 @@ from .kt.types.channel.channel_info import ChannelInfo
from .kt.types.channel.channel_type import KnownChannelType, ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.client.types import UserInfoUnion, PortalChannelInfo
from .kt.client.types import UserInfoUnion, PortalChannelInfo, ChannelProps
from .kt.client.errors import CommandException
if TYPE_CHECKING:
@ -194,6 +194,13 @@ class Portal(DBPortal, BasePortal):
raise ValueError(f"Non-direct chat portal should have no sender, but has sender {self._kt_sender}")
return self._kt_sender
@property
def channel_props(self) -> ChannelProps:
return ChannelProps(
id=self.ktid,
type=self.kt_type
)
@property
def main_intent(self) -> IntentAPI:
if not self._main_intent:
@ -245,7 +252,7 @@ class Portal(DBPortal, BasePortal):
) -> PortalChannelInfo:
if not info:
self.log.debug("Called update_info with no info, fetching channel info...")
info = await source.client.get_portal_channel_info(self.ktid)
info = await source.client.get_portal_channel_info(self.channel_props)
changed = False
if not self.is_direct:
changed = any(
@ -400,7 +407,7 @@ class Portal(DBPortal, BasePortal):
async def _update_participants(self, source: u.User, participants: list[UserInfoUnion] | None = None) -> bool:
if participants is None:
self.log.debug("Called _update_participants with no participants, fetching them now...")
participants = await source.client.get_participants(self.ktid)
participants = await source.client.get_participants(self.channel_props)
changed = False
if not self._main_intent:
assert self.is_direct, "_main_intent for non-direct chat portal should have been set already"
@ -736,7 +743,7 @@ class Portal(DBPortal, BasePortal):
converted = await matrix_to_kakaotalk(message, self.mxid, self.log)
try:
chatlog = await sender.client.send_message(
self.ktid,
self.channel_props,
text=converted.text,
# TODO
#mentions=converted.mentions,
@ -959,7 +966,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug(f"Backfilling history through {source.mxid}")
self.log.debug(f"Fetching {f'up to {limit}' if limit else 'all'} messages through {source.ktid}")
messages = await source.client.get_chats(
channel_info.channelId,
self.channel_props,
after_log_id,
limit
)
@ -989,7 +996,6 @@ class Portal(DBPortal, BasePortal):
# TODO Save kt_sender in DB instead? Depends on if DM channels are shared...
user = await u.User.get_by_ktid(self.kt_receiver)
assert user, f"Found no user for this portal's receiver of {self.kt_receiver}"
# TODO Should this backfill? Useful for forgotten channels
await self._update_participants(user)
else:
self.log.debug("Not setting _main_intent of new direct chat until after checking participant list")

View File

@ -83,7 +83,7 @@ class RPCClient:
await self.request("register", peer_id=self.config["appservice.address"])
async def disconnect(self) -> None:
assert self._writer is not None
if self._writer is not None:
self._writer.write_eof()
await self._writer.drain()
self._writer = None

View File

@ -42,7 +42,7 @@ from .kt.types.bson import Long
from .kt.types.channel.channel_info import ChannelInfo, NormalChannelInfo, NormalChannelData
from .kt.types.channel.channel_type import ChannelType
from .kt.types.chat.chat import Chatlog
from .kt.types.client.client_session import ChannelLoginDataItem, LoginResult
from .kt.types.client.client_session import LoginDataItem, LoginResult
from .kt.types.oauth import OAuthCredential
from .kt.types.openlink.open_channel_info import OpenChannelData, OpenChannelInfo
@ -449,19 +449,19 @@ class User(DBUser, BaseUser):
sync_count = num_channels if sync_count < 0 else min(sync_count, num_channels)
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
self.log.debug(f"Syncing {sync_count} of {num_channels} channels...")
for channel_item in login_result.channelList[:sync_count]:
for login_data in login_result.channelList[:sync_count]:
try:
await self._sync_channel(channel_item)
await self._sync_channel(login_data)
except AuthenticationRequired:
raise
except Exception:
self.log.exception(f"Failed to sync channel {channel_item.channel.channelId}")
self.log.exception(f"Failed to sync channel {login_data.channel.channelId}")
await self.update_direct_chats()
async def _sync_channel(self, channel_item: ChannelLoginDataItem) -> None:
channel_data = channel_item.channel
self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {channel_item.lastUpdate})")
async def _sync_channel(self, login_data: LoginDataItem) -> None:
channel_data = login_data.channel
self.log.debug(f"Syncing channel {channel_data.channelId} (last updated at {login_data.lastUpdate})")
channel_info = channel_data.info
if isinstance(channel_data, NormalChannelData):
channel_data: NormalChannelData
@ -499,7 +499,7 @@ class User(DBUser, BaseUser):
kt_receiver=self.ktid,
kt_type=channel_info.type
)
portal_info = await self.client.get_portal_channel_info(channel_info.channelId)
portal_info = await self.client.get_portal_channel_info(portal.channel_props)
portal_info.channel_info = channel_info
if not portal.mxid:
await portal.create_matrix_room(self, portal_info)

View File

@ -14,6 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import { Long } from "bson"
import {
AuthApiClient,
OAuthApiClient,
@ -22,10 +23,12 @@ import {
KnownAuthStatusCode,
util,
} from "node-kakao"
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("node-kakao/dist/talk").TalkChannelList} TalkChannelList */
/** @typedef {import("node-kakao").ChannelType} ChannelType */
import chat from "node-kakao/chat"
const { KnownChatType } = chat
/** @typedef {import("node-kakao").OAuthCredential} OAuthCredential */
/** @typedef {import("./clientmanager.js").default} ClientManager} */
import { emitLines, promisify } from "./util.js"
@ -67,6 +70,30 @@ class UserClient {
return userClient
}
/**
* @param {Object} channel_props
* @param {Long} channel_props.id
* @param {ChannelType} channel_props.type
*/
async getChannel(channel_props) {
let channel = this.#talkClient.channelList.get(channel_props.id)
if (channel) {
return channel
} else {
const channelList = getChannelListForType(
this.#talkClient.channelList,
channel_props.type
)
const res = await channelList.addChannel({
channelId: channel_props.id,
})
if (!res.success) {
throw new Error(`Unable to add ${channel_props.type} channel ${channel_props.id}`)
}
return res.result
}
}
close() {
this.#talkClient.close()
}
@ -82,20 +109,21 @@ class UserClient {
export default class PeerClient {
/**
* @param {ClientManager} manager
* @param {import("./clientmanager.js").default} manager
* @param {import("net").Socket} socket
* @param {number} connID
* @param {Map<string, UserClient>} userClients
*/
constructor(manager, socket, connID) {
this.manager = manager
this.socket = socket
this.connID = connID
this.stopped = false
this.notificationID = 0
this.maxCommandID = 0
this.peerID = null
/** @type {Map<string, UserClient>} */
this.userClients = new Map()
}
@ -335,11 +363,11 @@ export default class PeerClient {
/**
* @param {Object} req
* @param {string} req.mxid
* @param {Long} req.channel_id
* @param {Object} req.channel_props
*/
getPortalChannelInfo = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
const talkChannel = await userClient.getChannel(req.channel_props)
const res = await talkChannel.updateAll()
if (!res.success) return res
@ -354,24 +382,24 @@ export default class PeerClient {
/**
* @param {Object} req
* @param {string} req.mxid
* @param {Long} req.channel_id
* @param {Object} req.channel_props
*/
getParticipants = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.getChannel(req.channel_id)
const talkChannel = await userClient.getChannel(req.channel_props)
return await talkChannel.getAllLatestUserInfo()
}
/**
* @param {Object} req
* @param {string} req.mxid
* @param {Long} req.channel_id
* @param {Object} req.channel_props
* @param {Long?} req.sync_from
* @param {Number?} req.limit
*/
getChats = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
const talkChannel = await userClient.getChannel(req.channel_props)
const res = await talkChannel.getChatListFrom(req.sync_from)
if (res.success && 0 < req.limit && req.limit < res.result.length) {
@ -383,12 +411,12 @@ export default class PeerClient {
/**
* @param {Object} req
* @param {string} req.mxid
* @param {Long} req.channel_id
* @param {Object} req.channel_props
* @param {string} req.text
*/
sendMessage = async (req) => {
const userClient = this.#getUser(req.mxid)
const talkChannel = userClient.talkClient.channelList.get(req.channel_id)
const talkChannel = await userClient.getChannel(req.channel_props)
return await talkChannel.sendChat({
type: KnownChatType.TEXT,
@ -444,6 +472,7 @@ export default class PeerClient {
this.log("Ignoring old request", req.id)
return
}
this.log("Received request", req.id, "with command", req.command)
this.maxCommandID = req.id
let handler
if (!this.peerID) {
@ -514,3 +543,18 @@ export default class PeerClient {
return value
}
}
/**
* @param {TalkChannelList} channelList
* @param {ChannelType} channelType
*/
function getChannelListForType(channelList, channelType) {
switch (channelType) {
case "OM":
case "OD":
return channelList.open
default:
return channelList.normal
}
}