# Copyright (c) 2016-2023 by Ron Frederick and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""X11 forwarding support"""

import asyncio
import os
from pathlib import Path
import socket
import time
from typing import TYPE_CHECKING, Callable, Dict, Iterable
from typing import NamedTuple, Optional, Sequence, Set, Tuple

from .constants import OPEN_CONNECT_FAILED
from .forward import SSHForwarder, SSHForwarderCoro
from .listener import SSHListener, create_tcp_forward_listener
from .logging import logger
from .misc import ChannelOpenError
from .session import DataType

if TYPE_CHECKING:
    # pylint: disable=cyclic-import
    from .channel import SSHChannel
    from .connection import SSHServerConnection


_RecvHandler = Optional[Callable[[bytes], None]]


# Xauth address families
XAUTH_FAMILY_IPV4 = 0
XAUTH_FAMILY_DECNET = 1
XAUTH_FAMILY_IPV6 = 6
XAUTH_FAMILY_HOSTNAME = 256
XAUTH_FAMILY_WILD = 65535

# Xauth protocol values
XAUTH_PROTO_COOKIE = b'MIT-MAGIC-COOKIE-1'
XAUTH_COOKIE_LEN = 16

# Xauth lock information
XAUTH_LOCK_SUFFIX = '-c'
XAUTH_LOCK_TRIES = 5
XAUTH_LOCK_DELAY = 0.2
XAUTH_LOCK_DEAD = 5

# X11 display and port numbers
X11_BASE_PORT = 6000
X11_DISPLAY_START = 10
X11_MAX_DISPLAYS = 64

# Host to listen on when doing X11 forwarding
X11_LISTEN_HOST = 'localhost'


def _parse_display(display: str) -> Tuple[str, str, int]:
    """Parse an X11 display value

       Split a display string of the form `host:dpynum[.screen]` into
       its parts. Square brackets around the host are stripped, and
       the screen defaults to 0 when it is not present.

       :returns: A tuple of host, display number (as a string), and
                 screen number
       :raises: :exc:`ValueError` if the display has no `:` separator
                or the screen is not an integer

    """

    try:
        host, dpynum = display.rsplit(':', 1)

        if host.startswith('[') and host.endswith(']'):
            host = host[1:-1]

        idx = dpynum.find('.')
        if idx >= 0:
            screen = int(dpynum[idx+1:])
            dpynum = dpynum[:idx]
        else:
            screen = 0
    except (ValueError, UnicodeEncodeError):
        raise ValueError('Invalid X11 display') from None

    return host, dpynum, screen


async def _lookup_host(loop: asyncio.AbstractEventLoop, host: str,
                       family: int) -> Sequence[str]:
    """Look up IPv4 or IPv6 addresses of a host name

       Return the address strings reported by the event loop's
       getaddrinfo() for the requested address family, or an empty
       list if the lookup fails.

    """

    try:
        addrinfo = await loop.getaddrinfo(host, 0, family=family,
                                          type=socket.SOCK_STREAM)
    except socket.gaierror:
        return []

    return [ai[4][0] for ai in addrinfo]


class SSHXAuthorityEntry(NamedTuple):
    """An entry in an Xauthority file"""

    family: int
    addr: bytes
    dpynum: bytes
    proto: bytes
    data: bytes

    def __bytes__(self) -> bytes:
        """Construct an Xauthority entry

           The entry is encoded as a big-endian 16-bit family followed
           by the address, display number, protocol, and data, each
           prefixed with a big-endian 16-bit length.

        """

        def _uint16(value: int) -> bytes:
            """Construct a big-endian 16-bit unsigned integer"""

            return value.to_bytes(2, 'big')

        def _string(data: bytes) -> bytes:
            """Construct a binary string with a 16-bit length"""

            return _uint16(len(data)) + data

        return b''.join((_uint16(self.family), _string(self.addr),
                         _string(self.dpynum), _string(self.proto),
                         _string(self.data)))


class SSHX11ClientForwarder(SSHForwarder):
    """X11 forwarding connection handler

       This forwarder buffers the start of the data stream received
       from the X11 client, parsing the 12-byte prefix followed by the
       padded auth protocol and auth data. The auth data is validated
       against the listener and replaced with the value the listener
       returns before anything is passed on to the peer.

    """

    def __init__(self, listener: 'SSHX11ClientListener', peer: SSHForwarder):
        super().__init__(peer)

        self._listener = listener

        # Parser state: buffered input, the number of bytes the current
        # handler needs, and the handler itself (None once parsing is done)
        self._inpbuf = b''
        self._bytes_needed = 12
        self._recv_handler: _RecvHandler = self._recv_prefix

        self._endian = b''
        self._prefix = b''

        self._auth_proto_len = 0
        self._auth_data_len = 0

        self._auth_proto = b''
        self._auth_proto_pad = b''

        self._auth_data = b''
        self._auth_data_pad = b''

    def _encode_uint16(self, value: int) -> bytes:
        """Encode a 16-bit unsigned integer

           Big-endian is used when the client's byte order marker was
           `B`, and little-endian otherwise.

        """

        if self._endian == b'B':
            return bytes((value >> 8, value & 255))
        else:
            return bytes((value & 255, value >> 8))

    def _decode_uint16(self, value: bytes) -> int:
        """Decode a 16-bit unsigned integer

           Big-endian is used when the client's byte order marker was
           `B`, and little-endian otherwise.

        """

        if self._endian == b'B':
            return (value[0] << 8) + value[1]
        else:
            return (value[1] << 8) + value[0]

    @staticmethod
    def _padded_len(length: int) -> int:
        """Return length rounded up to the next multiple of 4 bytes"""

        return ((length + 3) // 4) * 4

    @staticmethod
    def _pad(data: bytes) -> bytes:
        """Pad a string to a multiple of 4 bytes"""

        length = len(data) % 4
        return data + ((4 - length) * b'\00' if length else b'')

    def _recv_prefix(self, data: bytes) -> None:
        """Parse X11 client prefix

           The first byte selects the byte order, and the auth protocol
           and auth data lengths are read from bytes 6-7 and 8-9.

        """

        self._endian = data[:1]
        self._prefix = data

        self._auth_proto_len = self._decode_uint16(data[6:8])
        self._auth_data_len = self._decode_uint16(data[8:10])

        self._recv_handler = self._recv_auth_proto
        self._bytes_needed = self._padded_len(self._auth_proto_len)

    def _recv_auth_proto(self, data: bytes) -> None:
        """Extract X11 auth protocol"""

        # Keep the padding separately so it can be replayed unchanged
        self._auth_proto = data[:self._auth_proto_len]
        self._auth_proto_pad = data[self._auth_proto_len:]

        self._recv_handler = self._recv_auth_data
        self._bytes_needed = self._padded_len(self._auth_data_len)

    def _recv_auth_data(self, data: bytes) -> None:
        """Extract X11 auth data

           If the listener accepts the auth data, the buffered prefix,
           auth protocol, and padding are queued for the peer with the
           auth data the listener returned in place of what the client
           sent. Otherwise, a failure response is written back and EOF
           is sent. In both cases, parsing ends here.

        """

        self._auth_data = data[:self._auth_data_len]
        self._auth_data_pad = data[self._auth_data_len:]

        try:
            self._auth_data = self._listener.validate_auth(self._auth_data)
        except KeyError:
            reason = b'Invalid authentication key\n'

            response = b''.join((bytes((0, len(reason))),
                                 self._encode_uint16(11),
                                 self._encode_uint16(0),
                                 self._encode_uint16((len(reason) + 3) // 4),
                                 self._pad(reason)))

            try:
                self.write(response)
                self.write_eof()
            except OSError: # pragma: no cover
                pass

            self._inpbuf = b''
        else:
            self._inpbuf = (self._prefix + self._auth_proto +
                            self._auth_proto_pad + self._auth_data +
                            self._auth_data_pad)

        self._recv_handler = None
        self._bytes_needed = 0

    def data_received(self, data: bytes, datatype: DataType = None) -> None:
        """Handle incoming data from the X11 client

           While a receive handler is set, data is buffered and fed to
           the handler in chunks of the size it asked for. Once no
           handler remains, any buffered data is passed to the parent
           class, and later data is passed through directly.

        """

        if self._recv_handler:
            self._inpbuf += data

            while self._recv_handler: # type: ignore[truthy-function]
                if len(self._inpbuf) >= self._bytes_needed:
                    data = self._inpbuf[:self._bytes_needed]
                    self._inpbuf = self._inpbuf[self._bytes_needed:]
                    self._recv_handler(data)
                else:
                    # Wait for more data before calling the handler
                    return

            data = self._inpbuf
            self._inpbuf = b''

        if data:
            super().data_received(data, datatype)


class SSHX11ClientListener:
    """Client listener used to accept forwarded X11 connections"""

    def __init__(self, loop: asyncio.AbstractEventLoop, host: str,
                 dpynum: str, auth_proto: bytes, auth_data: bytes):
        self._host = host
        self._dpynum = dpynum
        self._auth_proto = auth_proto
        self._local_auth = auth_data

        # Pick how to reach the X server: a UNIX domain socket when the
        # host is a path or is '' or 'unix', and TCP otherwise
        if host.startswith('/'):
            self._connect_coro: SSHForwarderCoro = loop.create_unix_connection
            self._connect_args: Sequence[object] = (host + ':' + dpynum,)
        elif host in ('', 'unix'):
            self._connect_coro = loop.create_unix_connection
            self._connect_args = ('/tmp/.X11-unix/X' + dpynum,)
        else:
            self._connect_coro = loop.create_connection
            self._connect_args = (host, X11_BASE_PORT + int(dpynum))

        # Two-way mapping between channels and their random auth values
        self._remote_auth: Dict['SSHChannel', bytes] = {}
        self._channel: Dict[bytes, Tuple['SSHChannel', bool]] = {}

    def attach(self, display: str, chan: 'SSHChannel',
               single_connection: bool) -> Tuple[bytes, bytes, int]:
        """Attach a channel to this listener

           A random auth value the same length as the local auth data
           is generated for the channel.

           :returns: A tuple of auth protocol, the generated auth
                     value, and the screen number parsed from display
           :raises: :exc:`ValueError` if the display is invalid or
                    refers to a different host or display number than
                    this listener was created with

        """

        host, dpynum, screen = _parse_display(display)

        if self._host != host or self._dpynum != dpynum:
            raise ValueError('Already forwarding to another X11 display')

        remote_auth = os.urandom(len(self._local_auth))

        self._remote_auth[chan] = remote_auth
        self._channel[remote_auth] = chan, single_connection

        return self._auth_proto, remote_auth, screen

    def detach(self, chan: 'SSHChannel') -> bool:
        """Detach a channel from this listener

           :returns: `True` if no channels remain attached

        """

        try:
            remote_auth = self._remote_auth.pop(chan)
            del self._channel[remote_auth]
        except KeyError:
            pass

        return not bool(self._remote_auth)

    async def forward_connection(self) -> SSHX11ClientForwarder:
        """Forward an incoming connection to the local X server

           :raises: :exc:`ChannelOpenError` with the code
                    `OPEN_CONNECT_FAILED` if the connection attempt
                    raises :exc:`OSError`

        """

        peer: SSHForwarder

        try:
            _, peer = await self._connect_coro(SSHForwarder,
                                               *self._connect_args)
        except OSError as exc:
            raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None

        return SSHX11ClientForwarder(self, peer)

    def validate_auth(self, remote_auth: bytes) -> bytes:
        """Validate client auth and enforce single connection flag

           :returns: The local auth data to use in place of the value
                     the client sent
           :raises: :exc:`KeyError` if the auth value is not known

        """

        chan, single_connection = self._channel[remote_auth]

        if single_connection:
            # Forget the auth value so it can't be used a second time
            del self._channel[remote_auth]
            del self._remote_auth[chan]

        return self._local_auth


class SSHX11ServerListener:
    """Server listener used to forward X11 connections"""

    def __init__(self, tcp_listener: SSHListener, display: str):
        self._tcp_listener = tcp_listener
        self._display = display
        self._channels: Set[object] = set()

    def attach(self, chan: 'SSHChannel', screen: int) -> str:
        """Attach a channel to this listener and return its display"""

        self._channels.add(chan)

        return '%s.%s' % (self._display, screen)

    def detach(self, chan: 'SSHChannel') -> bool:
        """Detach a channel from this listener

           The TCP listener is closed once no channels remain.

           :returns: `True` if the listener was closed

        """

        try:
            self._channels.remove(chan)
        except KeyError:
            pass

        if not self._channels:
            self._tcp_listener.close()
            return True
        else:
            return False


def get_xauth_path(auth_path: Optional[str]) -> str:
    """Compute the path to the Xauthority file

       An explicit path is used if one is provided. Otherwise, the
       `XAUTHORITY` environment variable is tried, followed by
       `.Xauthority` in the user's home directory.

    """

    if not auth_path:
        auth_path = os.environ.get('XAUTHORITY')

    if not auth_path:
        auth_path = str(Path('~', '.Xauthority').expanduser())

    return auth_path


def walk_xauth(auth_path: str) -> Iterable[SSHXAuthorityEntry]:
    """Walk the entries in an Xauthority file

       Entries are yielded in file order. An :exc:`OSError` raised
       while opening or reading the file ends the walk silently.

       :raises: :exc:`ValueError` if the file ends part way through
                an entry

    """

    def _read_bytes(n: int) -> bytes:
        """Read exactly n bytes"""

        data = auth_file.read(n)

        if len(data) != n:
            raise EOFError

        return data

    def _read_uint16() -> int:
        """Read a 16-bit unsigned integer"""

        return int.from_bytes(_read_bytes(2), 'big')

    def _read_string() -> bytes:
        """Read a string"""

        return _read_bytes(_read_uint16())

    try:
        with open(auth_path, 'rb') as auth_file:
            while True:
                # Running out of data at the start of an entry is a
                # normal end of file
                try:
                    family = _read_uint16()
                except EOFError:
                    break

                try:
                    yield SSHXAuthorityEntry(family, _read_string(),
                                             _read_string(), _read_string(),
                                             _read_string())
                except EOFError:
                    raise ValueError('Incomplete Xauthority entry') from None
    except OSError:
        pass


async def lookup_xauth(loop: asyncio.AbstractEventLoop,
                       auth_path: Optional[str], host: str,
                       dpynum: str) -> Tuple[bytes, bytes]:
    """Look up Xauthority data for the specified display

       The first Xauthority entry which matches both the display
       number and the host is used. An entry with an empty display
       number matches any display. If no entry matches, a random
       cookie is generated.

       :returns: A tuple of auth protocol and auth data
       :raises: :exc:`ValueError` if the Xauthority file ends part
                way through an entry

    """

    auth_path = get_xauth_path(auth_path)

    if host.startswith('/') or host in ('', 'unix', 'localhost'):
        host = socket.gethostname()

    dpynum_bytes = dpynum.encode('ascii')

    # Addresses are resolved lazily, at most once per address family.
    # None means "not looked up yet", so a lookup which returns no
    # addresses isn't repeated for every entry of that family.
    ipv4_addrs: Optional[Sequence[str]] = None
    ipv6_addrs: Optional[Sequence[str]] = None

    for entry in walk_xauth(auth_path):
        if entry.dpynum and entry.dpynum != dpynum_bytes:
            continue

        if entry.family == XAUTH_FAMILY_IPV4:
            if ipv4_addrs is None:
                ipv4_addrs = await _lookup_host(loop, host, socket.AF_INET)

            addr = socket.inet_ntop(socket.AF_INET, entry.addr)
            match = addr in ipv4_addrs
        elif entry.family == XAUTH_FAMILY_IPV6:
            if ipv6_addrs is None:
                ipv6_addrs = await _lookup_host(loop, host, socket.AF_INET6)

            addr = socket.inet_ntop(socket.AF_INET6, entry.addr)
            match = addr in ipv6_addrs
        elif entry.family == XAUTH_FAMILY_HOSTNAME:
            match = entry.addr == host.encode('idna')
        elif entry.family == XAUTH_FAMILY_WILD:
            match = True
        else:
            match = False

        if match:
            return entry.proto, entry.data

    logger.debug1('No xauth entry found for display: using random auth')

    return XAUTH_PROTO_COOKIE, os.urandom(XAUTH_COOKIE_LEN)


async def update_xauth(auth_path: Optional[str], host: str, dpynum: str,
                       auth_proto: bytes, auth_data: bytes) -> None:
    """Update Xauthority data for the specified display

       A new Xauthority file is written under a lock file name and
       then moved into place. The new hostname entry is written first,
       followed by all existing entries which don't have the same
       family, address, and display number.

       :raises: :exc:`ValueError` if the lock can't be acquired or the
                existing Xauthority file ends part way through an entry

    """

    if host.startswith('/') or host in ('', 'unix', 'localhost'):
        host = socket.gethostname()

    host_bytes = host.encode('idna')
    dpynum_bytes = str(dpynum).encode('ascii')

    auth_path = get_xauth_path(auth_path)
    new_auth_path = auth_path + XAUTH_LOCK_SUFFIX
    new_file = None

    # Remove a lock file old enough to be considered abandoned
    try:
        if time.time() - os.stat(new_auth_path).st_ctime > XAUTH_LOCK_DEAD:
            os.unlink(new_auth_path)
    except FileNotFoundError:
        pass

    # Exclusive creation of the lock file is what acquires the lock
    for _ in range(XAUTH_LOCK_TRIES):
        try:
            new_file = open(new_auth_path, 'xb')
        except FileExistsError:
            await asyncio.sleep(XAUTH_LOCK_DELAY)
        else:
            break

    if not new_file:
        raise ValueError('Unable to acquire Xauthority lock')

    try:
        with new_file:
            new_entry = SSHXAuthorityEntry(XAUTH_FAMILY_HOSTNAME, host_bytes,
                                           dpynum_bytes, auth_proto,
                                           auth_data)

            new_file.write(bytes(new_entry))

            for entry in walk_xauth(auth_path):
                if (entry.family != new_entry.family or
                        entry.addr != new_entry.addr or
                        entry.dpynum != new_entry.dpynum):
                    new_file.write(bytes(entry))

        os.replace(new_auth_path, auth_path)
    except BaseException:
        # Don't leave the lock file behind on failure, as that would
        # block other updates until the lock was considered dead
        try:
            os.unlink(new_auth_path)
        except OSError:
            pass

        raise


async def create_x11_client_listener(loop: asyncio.AbstractEventLoop,
                                     display: str,
                                     auth_path: Optional[str]) -> \
        SSHX11ClientListener:
    """Create a listener to accept X11 connections forwarded over SSH

       :raises: :exc:`ValueError` if the display is invalid

    """

    host, dpynum, _ = _parse_display(display)

    auth_proto, auth_data = await lookup_xauth(loop, auth_path, host, dpynum)

    return SSHX11ClientListener(loop, host, dpynum, auth_proto, auth_data)


async def create_x11_server_listener(conn: 'SSHServerConnection',
                                     loop: asyncio.AbstractEventLoop,
                                     auth_path: Optional[str],
                                     auth_proto: bytes,
                                     auth_data: bytes) -> \
        Optional[SSHX11ServerListener]:
    """Create a listener to forward X11 connections over SSH

       Display numbers from `X11_DISPLAY_START` up to (but not
       including) `X11_MAX_DISPLAYS` are tried in order until a TCP
       listener can be opened, after which the Xauthority file is
       updated with the auth data for that display.

       :returns: The new listener, or `None` if no display number was
                 available or the Xauthority update raised
                 :exc:`ValueError`

    """

    for dpynum in range(X11_DISPLAY_START, X11_MAX_DISPLAYS):
        try:
            tcp_listener = await create_tcp_forward_listener(
                conn, loop, conn.create_x11_connection,
                X11_LISTEN_HOST, X11_BASE_PORT + dpynum)
        except OSError:
            continue

        display = '%s:%d' % (X11_LISTEN_HOST, dpynum)

        try:
            await update_xauth(auth_path, X11_LISTEN_HOST, str(dpynum),
                               auth_proto, auth_data)
        except ValueError:
            tcp_listener.close()
            break
        except BaseException:
            # Don't leak the listening socket if the update fails in
            # some other way (such as an OSError or cancellation)
            tcp_listener.close()
            raise

        return SSHX11ServerListener(tcp_listener, display)

    return None