Files
vconnect-api/venv/lib/python3.12/site-packages/asyncssh/packet.py
2025-12-08 21:35:55 +09:00

241 lines
6.5 KiB
Python

# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH packet encoding and decoding functions"""
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union
from .logging import SSHLogger
from .misc import plural
_LoggedPacket = Union[bytes, 'SSHPacket']
_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], None]
class PacketDecodeError(ValueError):
    """Packet decoding error

       Raised by the SSHPacket decoder when a packet is truncated or has
       unexpected trailing data. It derives from ValueError, so callers
       that catch ValueError will also catch it.

    """
def Byte(value: int) -> bytes:
    """Encode a single byte

       value must be an int in range(256); anything else raises the same
       ValueError/TypeError that the bytes() constructor raises.

    """
    return bytes([value])
def Boolean(value: bool) -> bytes:
    """Encode a boolean value

       Any truthy value encodes as the single byte 0x01 and any falsy
       value as 0x00.

    """
    return b'\x01' if value else b'\x00'
def UInt16(value: int) -> bytes:
    """Encode a 16-bit integer value

       The result is 2 bytes, big-endian. Negative values or values of
       2**16 and above raise OverflowError (from int.to_bytes).

    """
    return value.to_bytes(2, byteorder='big')
def UInt32(value: int) -> bytes:
    """Encode a 32-bit integer value

       The result is 4 bytes, big-endian. Negative values or values of
       2**32 and above raise OverflowError (from int.to_bytes).

    """
    return value.to_bytes(4, byteorder='big')
def UInt64(value: int) -> bytes:
    """Encode a 64-bit integer value

       The result is 8 bytes, big-endian. Negative values or values of
       2**64 and above raise OverflowError (from int.to_bytes).

    """
    return value.to_bytes(8, byteorder='big')
def String(value: Union[bytes, str]) -> bytes:
    """Encode a byte string or UTF-8 string value

       A str is first encoded as UTF-8 with errors='strict', so text that
       cannot be encoded raises UnicodeEncodeError. Byte input is used
       as-is. The data is prefixed with its length as a 4-byte big-endian
       integer.

    """
    data = (value.encode('utf-8', errors='strict')
            if isinstance(value, str) else value)

    return len(data).to_bytes(4, 'big') + data
def MPInt(value: int) -> bytes:
    """Encode a multiple precision integer value

       The result is a 4-byte big-endian length followed by the minimal
       big-endian two's-complement representation of value. Zero encodes
       as a zero-length string.

    """
    if value == 0:
        size = 0
    else:
        # A signed n-byte field holds v exactly when the magnitude bits
        # (v for v >= 0, ~v for v < 0) fit in 8*n - 1 bits, leaving room
        # for the sign bit.
        magnitude = value if value >= 0 else ~value
        size = (magnitude.bit_length() + 8) // 8

    return size.to_bytes(4, 'big') + value.to_bytes(size, 'big', signed=True)
def NameList(value: Iterable[bytes]) -> bytes:
    """Encode a comma-separated list of byte strings

       The names are joined with b',' and the result is encoded with
       String, so an empty iterable produces a zero-length string.

    """
    joined = b','.join(value)
    return String(joined)
class SSHPacket:
    """Decoder class for SSH packets

       Wraps a packet payload and extracts fields from it in order,
       tracking how much of the payload has been consumed so far.

    """

    def __init__(self, packet: bytes):
        self._packet = packet
        self._idx = 0               # offset of the next unread byte
        self._len = len(packet)

    def __bool__(self) -> bool:
        """Return whether any unconsumed data remains in the packet"""

        return self._idx != self._len

    def check_end(self) -> None:
        """Confirm that all of the data in the packet has been consumed

           Raises PacketDecodeError if any unread bytes remain.

        """

        if self:
            raise PacketDecodeError('Unexpected data at end of packet')

    def get_consumed_payload(self) -> bytes:
        """Return the portion of the packet consumed so far"""

        return self._packet[:self._idx]

    def get_remaining_payload(self) -> bytes:
        """Return the portion of the packet not yet consumed"""

        return self._packet[self._idx:]

    def get_full_payload(self) -> bytes:
        """Return the full packet, regardless of how much was consumed"""

        return self._packet

    def get_bytes(self, size: int) -> bytes:
        """Extract the requested number of bytes from the packet

           Raises PacketDecodeError if size is negative or if fewer than
           size bytes remain. In either case the read position is left
           unchanged.

        """

        # Without this check a negative size would return an empty or
        # short slice and move the read position backwards, silently
        # re-exposing data that was already consumed.
        if size < 0:
            raise PacketDecodeError('Invalid length in packet')

        if self._idx + size > self._len:
            raise PacketDecodeError('Incomplete packet')

        value = self._packet[self._idx:self._idx+size]
        self._idx += size
        return value

    def get_byte(self) -> int:
        """Extract a single byte from the packet, returned as an int"""

        return self.get_bytes(1)[0]

    def get_boolean(self) -> bool:
        """Extract a boolean from the packet (any nonzero byte is True)"""

        return bool(self.get_byte())

    def get_uint16(self) -> int:
        """Extract a big-endian 16-bit unsigned integer from the packet"""

        return int.from_bytes(self.get_bytes(2), 'big')

    def get_uint32(self) -> int:
        """Extract a big-endian 32-bit unsigned integer from the packet"""

        return int.from_bytes(self.get_bytes(4), 'big')

    def get_uint64(self) -> int:
        """Extract a big-endian 64-bit unsigned integer from the packet"""

        return int.from_bytes(self.get_bytes(8), 'big')

    def get_string(self) -> bytes:
        """Extract a length-prefixed string from the packet

           The 4-byte big-endian length is read first, then that many
           bytes are returned as raw bytes. No UTF-8 decoding is done;
           callers decode if they need text.

        """

        return self.get_bytes(self.get_uint32())

    def get_mpint(self) -> int:
        """Extract a multiple precision integer from the packet

           The value is a length-prefixed string holding a big-endian
           two's-complement integer. An empty string decodes as 0.

        """

        return int.from_bytes(self.get_string(), 'big', signed=True)

    def get_namelist(self) -> Sequence[bytes]:
        """Extract a comma-separated list of byte strings from the packet

           An empty string decodes as an empty list.

        """

        namelist = self.get_string()
        return namelist.split(b',') if namelist else []
class SSHPacketLogger:
    """Parent class for SSH packet loggers

       Subclasses must provide the logger property. They may also
       populate _handler_names to map packet type numbers to display
       names used in log lines.

    """

    _handler_names: Mapping[int, str] = {}

    @property
    def logger(self) -> SSHLogger:
        """The logger to use for packet logging"""

        raise NotImplementedError

    def _log_packet(self, msg: str, pkttype: int, pktid: Optional[int],
                    packet: _LoggedPacket, note: str) -> None:
        """Log a sent/received packet

           The log line has the form "<msg> <name>, <N bytes>[ (<note>)]".
           An SSHPacket is unwrapped to its full payload first. The line
           is emitted via self.logger.packet() along with pktid and the
           payload.

        """

        payload = (packet.get_full_payload()
                   if isinstance(packet, SSHPacket) else packet)

        # Fall back to the bare packet type number when no display
        # name is registered for it
        try:
            name = '%s (%d)' % (self._handler_names[pkttype], pkttype)
        except KeyError:
            name = 'packet type %d' % pkttype

        byte_count = plural(len(payload), 'byte')
        suffix = ' (%s)' % note if note else note

        self.logger.packet(pktid, payload, '%s %s, %s%s',
                           msg, name, byte_count, suffix)

    def log_sent_packet(self, pkttype: int, pktid: Optional[int],
                        packet: _LoggedPacket, note: str = '') -> None:
        """Log a packet that was sent, with an optional note"""

        self._log_packet('Sent', pkttype, pktid, packet, note)

    def log_received_packet(self, pkttype: int, pktid: Optional[int],
                            packet: _LoggedPacket, note: str = '') -> None:
        """Log a packet that was received, with an optional note"""

        self._log_packet('Received', pkttype, pktid, packet, note)
class SSHPacketHandler(SSHPacketLogger):
    """Parent class for SSH packet handlers

       Subclasses must provide the logger property. They populate
       _packet_handlers to map packet type numbers to handler
       functions, which are called with the handler object as their
       first argument.

    """

    _packet_handlers: Mapping[int, _PacketHandler] = {}

    @property
    def logger(self) -> SSHLogger:
        """The logger associated with this packet handler"""

        raise NotImplementedError

    def process_packet(self, pkttype: int, pktid: int,
                       packet: SSHPacket) -> bool:
        """Dispatch a received packet to its registered handler

           Returns True if a handler is registered for pkttype and was
           called, or False if there is none. This method does no
           logging itself.

        """

        if pkttype not in self._packet_handlers:
            return False

        self._packet_handlers[pkttype](self, pkttype, pktid, packet)
        return True