241 lines
6.5 KiB
Python
241 lines
6.5 KiB
Python
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
|
|
#
|
|
# This program and the accompanying materials are made available under
|
|
# the terms of the Eclipse Public License v2.0 which accompanies this
|
|
# distribution and is available at:
|
|
#
|
|
# http://www.eclipse.org/legal/epl-2.0/
|
|
#
|
|
# This program may also be made available under the following secondary
|
|
# licenses when the conditions for such availability set forth in the
|
|
# Eclipse Public License v2.0 are satisfied:
|
|
#
|
|
# GNU General Public License, Version 2.0, or any later versions of
|
|
# that license
|
|
#
|
|
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
|
|
#
|
|
# Contributors:
|
|
# Ron Frederick - initial implementation, API, and documentation
|
|
|
|
"""SSH packet encoding and decoding functions"""
|
|
|
|
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union
|
|
|
|
from .logging import SSHLogger
|
|
from .misc import plural
|
|
|
|
|
|
# A packet as accepted by the SSHPacketLogger logging methods: either the
# raw payload bytes, or an SSHPacket whose full payload gets logged instead
_LoggedPacket = Union[bytes, 'SSHPacket']

# An entry in SSHPacketHandler._packet_handlers; it is called as
# handler(self, pkttype, pktid, packet) and its return value is ignored
_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], None]
|
|
|
|
|
|
class PacketDecodeError(ValueError):
    """Raised when a packet is malformed or too short to decode

       Being a ValueError subclass, it can also be caught as ValueError.
    """
|
|
|
|
|
|
def Byte(value: int) -> bytes:
    """Return a one-byte string holding *value*

       Raises ValueError if value is outside the range 0-255.
    """

    return bytes([value])
|
|
|
|
|
|
def Boolean(value: bool) -> bytes:
    """Encode the truth of *value* as one byte: 1 if truthy, else 0"""

    return b'\x01' if value else b'\x00'
|
|
|
|
|
|
def UInt16(value: int) -> bytes:
    """Pack *value* into 2 bytes, most significant byte first

       Raises OverflowError if value is negative or needs more than
       16 bits.
    """

    return value.to_bytes(length=2, byteorder='big')
|
|
|
|
|
|
def UInt32(value: int) -> bytes:
    """Pack *value* into 4 bytes, most significant byte first

       Raises OverflowError if value is negative or needs more than
       32 bits.
    """

    return value.to_bytes(length=4, byteorder='big')
|
|
|
|
|
|
def UInt64(value: int) -> bytes:
    """Pack *value* into 8 bytes, most significant byte first

       Raises OverflowError if value is negative or needs more than
       64 bits.
    """

    return value.to_bytes(length=8, byteorder='big')
|
|
|
|
|
|
def String(value: Union[bytes, str]) -> bytes:
    """Encode *value* as a 32-bit big-endian length followed by its bytes

       A str argument is first converted to UTF-8; unencodable text
       (such as lone surrogates) raises UnicodeEncodeError.
    """

    data = (value.encode('utf-8', errors='strict')
            if isinstance(value, str) else value)

    return len(data).to_bytes(4, 'big') + data
|
|
|
|
|
|
def MPInt(value: int) -> bytes:
    """Encode a multiple precision integer value

       The result is a 32-bit big-endian length followed by the
       big-endian two's complement form of *value*, using the fewest
       bytes that still preserve the sign.  Zero has no data bytes.
    """

    if not value:
        return b'\x00\x00\x00\x00'

    # Bits that must fit below the sign bit.  For a negative value,
    # ~value (== -value - 1) is the magnitude two's complement stores
    # under the sign bit, so -128 fits in one byte but -129 does not.
    body_bits = (value if value >= 0 else ~value).bit_length()

    # Add one bit for the sign, then round up to whole bytes
    size = body_bits // 8 + 1

    return size.to_bytes(4, 'big') + value.to_bytes(size, 'big', signed=True)
|
|
|
|
|
|
def NameList(value: Iterable[bytes]) -> bytes:
    """Encode byte strings as a comma-separated, length-prefixed list

       The entries are joined with b',' and written behind a 32-bit
       big-endian byte count.  An empty iterable yields a zero length.
    """

    joined = b','.join(value)
    return len(joined).to_bytes(4, 'big') + joined
|
|
|
|
|
|
class SSHPacket:
    """Decoder class for SSH packets

       The packet is read front to back through an internal cursor.
       Each ``get_*`` method consumes the field it extracts and raises
       PacketDecodeError if the field runs past the end of the packet.
    """

    def __init__(self, packet: bytes):
        self._packet = packet
        self._idx = 0               # offset of the next unread byte
        self._len = len(packet)

    def __bool__(self) -> bool:
        """Return whether any unread data remains in the packet"""

        return self._idx != self._len

    def check_end(self) -> None:
        """Confirm that all of the data in the packet has been consumed

           :raises: PacketDecodeError if any unread data remains
        """

        if self:
            raise PacketDecodeError('Unexpected data at end of packet')

    def get_consumed_payload(self) -> bytes:
        """Return the portion of the packet consumed so far"""

        return self._packet[:self._idx]

    def get_remaining_payload(self) -> bytes:
        """Return the portion of the packet not yet consumed"""

        return self._packet[self._idx:]

    def get_full_payload(self) -> bytes:
        """Return the full packet, independent of the read position"""

        return self._packet

    def get_bytes(self, size: int) -> bytes:
        """Extract the requested number of bytes from the packet

           :param size: number of bytes to consume
           :raises: PacketDecodeError if size is negative or larger
                    than the data remaining in the packet
        """

        if size < 0:
            # Without this check a negative size passes the bounds test
            # below, returns a bogus slice and moves the cursor backwards
            raise PacketDecodeError('Invalid field length')

        end = self._idx + size

        if end > self._len:
            raise PacketDecodeError('Incomplete packet')

        value = self._packet[self._idx:end]
        self._idx = end
        return value

    def get_byte(self) -> int:
        """Extract a single byte from the packet as an int (0-255)"""

        return self.get_bytes(1)[0]

    def get_boolean(self) -> bool:
        """Extract a boolean from the packet (any nonzero byte is True)"""

        return bool(self.get_byte())

    def get_uint16(self) -> int:
        """Extract a big-endian unsigned 16-bit integer from the packet"""

        return int.from_bytes(self.get_bytes(2), 'big')

    def get_uint32(self) -> int:
        """Extract a big-endian unsigned 32-bit integer from the packet"""

        return int.from_bytes(self.get_bytes(4), 'big')

    def get_uint64(self) -> int:
        """Extract a big-endian unsigned 64-bit integer from the packet"""

        return int.from_bytes(self.get_bytes(8), 'big')

    def get_string(self) -> bytes:
        """Extract a length-prefixed byte string from the packet

           The 32-bit length is followed by that many bytes, which are
           returned undecoded; callers decode text themselves if needed.
        """

        return self.get_bytes(self.get_uint32())

    def get_mpint(self) -> int:
        """Extract a multiple precision integer from the packet

           The value is a length-prefixed big-endian two's complement
           byte string; an empty string decodes as zero.
        """

        return int.from_bytes(self.get_string(), 'big', signed=True)

    def get_namelist(self) -> Sequence[bytes]:
        """Extract a comma-separated list of byte strings from the packet

           An empty string yields an empty list.
        """

        namelist = self.get_string()
        return namelist.split(b',') if namelist else []
|
|
|
|
|
|
class SSHPacketLogger:
    """Parent class for SSH packet loggers

       Subclasses provide the ``logger`` property and may populate
       ``_handler_names`` so packet types show up by name in the log.
    """

    # Packet type number -> name used in log messages (empty by default)
    _handler_names: Mapping[int, str] = {}

    @property
    def logger(self) -> SSHLogger:
        """The logger to use for packet logging; subclasses must override"""

        raise NotImplementedError

    def _log_packet(self, msg: str, pkttype: int, pktid: Optional[int],
                    packet: _LoggedPacket, note: str) -> None:
        """Emit one log entry describing a sent or received packet

           :param msg: direction label placed at the start of the entry
           :param pkttype: numeric packet type
           :param pktid: packet identifier handed through to the logger
           :param packet: raw payload, or an SSHPacket whose full payload
                          is logged
           :param note: extra text appended in parentheses when non-empty
        """

        payload = (packet.get_full_payload()
                   if isinstance(packet, SSHPacket) else packet)

        try:
            label = self._handler_names[pkttype]
        except KeyError:
            # Unknown types are still logged, just by number
            name = 'packet type %d' % pkttype
        else:
            name = '%s (%d)' % (label, pkttype)

        count = plural(len(payload), 'byte')
        suffix = (' (%s)' % note) if note else note

        self.logger.packet(pktid, payload, '%s %s, %s%s',
                           msg, name, count, suffix)

    def log_sent_packet(self, pkttype: int, pktid: Optional[int],
                        packet: _LoggedPacket, note: str = '') -> None:
        """Record an outgoing packet in the log"""

        self._log_packet('Sent', pkttype, pktid, packet, note)

    def log_received_packet(self, pkttype: int, pktid: Optional[int],
                            packet: _LoggedPacket, note: str = '') -> None:
        """Record an incoming packet in the log"""

        self._log_packet('Received', pkttype, pktid, packet, note)
|
|
|
|
|
|
class SSHPacketHandler(SSHPacketLogger):
    """Parent class for SSH packet handlers

       The dispatch table below is empty here, so subclasses are
       expected to supply their own mapping of packet types to handlers.
    """

    # Packet type -> handler, invoked as handler(self, pkttype, pktid, packet)
    _packet_handlers: Mapping[int, _PacketHandler] = {}

    @property
    def logger(self) -> SSHLogger:
        """The logger associated with this packet handler"""

        raise NotImplementedError

    def process_packet(self, pkttype: int, pktid: int,
                       packet: SSHPacket) -> bool:
        """Dispatch a received packet to its registered handler

           Only the lookup and call happen here; no logging is done.

           :returns: True if a handler was found and called, False if
                     no handler is registered for this packet type
        """

        handlers = self._packet_handlers

        if pkttype not in handlers:
            return False

        handlers[pkttype](self, pkttype, pktid, packet)
        return True
|