first commit

This commit is contained in:
unknown
2025-12-08 21:35:55 +09:00
commit f343f405f7
5357 changed files with 923703 additions and 0 deletions

View File

@@ -0,0 +1,171 @@
# Copyright (c) 2013-2022 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""An asynchronous SSH2 library for Python"""
from .version import __author__, __author_email__, __url__, __version__
# pylint: disable=wildcard-import
from .constants import *
# pylint: enable=wildcard-import
from .agent import SSHAgentClient, SSHAgentKeyPair, connect_agent
from .auth_keys import SSHAuthorizedKeys
from .auth_keys import import_authorized_keys, read_authorized_keys
from .channel import SSHClientChannel, SSHServerChannel
from .channel import SSHTCPChannel, SSHUNIXChannel
from .client import SSHClient
from .config import ConfigParseError
from .forward import SSHForwarder
from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection
from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions
from .connection import SSHAcceptHandler
from .connection import create_connection, create_server, connect, listen
from .connection import connect_reverse, listen_reverse, get_server_host_key
from .connection import get_server_auth_methods, run_client, run_server
from .editor import SSHLineEditorChannel
from .known_hosts import SSHKnownHosts
from .known_hosts import import_known_hosts, read_known_hosts
from .known_hosts import match_known_hosts
from .listener import SSHListener
from .logging import logger, set_log_level, set_sftp_log_level, set_debug_level
from .misc import BytesOrStr
from .misc import Error, DisconnectError, ChannelOpenError, ChannelListenError
from .misc import ConnectionLost, CompressionError, HostKeyNotVerifiable
from .misc import KeyExchangeFailed, IllegalUserName, MACError
from .misc import PermissionDenied, ProtocolError, ProtocolNotSupported
from .misc import ServiceNotAvailable, PasswordChangeRequired
from .misc import BreakReceived, SignalReceived, TerminalSizeChanged
from .pbe import KeyEncryptionError
from .pkcs11 import load_pkcs11_keys
from .process import SSHServerProcessFactory
from .process import SSHClientProcess, SSHServerProcess
from .process import SSHCompletedProcess, ProcessError
from .process import TimeoutError # pylint: disable=redefined-builtin
from .process import DEVNULL, PIPE, STDOUT
from .public_key import SSHKey, SSHKeyPair, SSHCertificate
from .public_key import KeyGenerationError, KeyImportError, KeyExportError
from .public_key import generate_private_key, import_private_key
from .public_key import import_public_key, import_certificate
from .public_key import read_private_key, read_public_key, read_certificate
from .public_key import read_private_key_list, read_public_key_list
from .public_key import read_certificate_list
from .public_key import load_keypairs, load_public_keys, load_certificates
from .public_key import load_resident_keys
from .rsa import set_default_skip_rsa_key_validation
from .scp import scp
from .session import DataType, SSHClientSession, SSHServerSession
from .session import SSHTCPSession, SSHUNIXSession
from .server import SSHServer
from .sftp import SFTPClient, SFTPClientFile, SFTPServer, SFTPError
from .sftp import SFTPEOFError, SFTPNoSuchFile, SFTPPermissionDenied
from .sftp import SFTPFailure, SFTPBadMessage, SFTPNoConnection
from .sftp import SFTPInvalidHandle, SFTPNoSuchPath, SFTPFileAlreadyExists
from .sftp import SFTPWriteProtect, SFTPNoMedia, SFTPNoSpaceOnFilesystem
from .sftp import SFTPQuotaExceeded, SFTPUnknownPrincipal, SFTPLockConflict
from .sftp import SFTPDirNotEmpty, SFTPNotADirectory, SFTPInvalidFilename
from .sftp import SFTPLinkLoop, SFTPCannotDelete, SFTPInvalidParameter
from .sftp import SFTPFileIsADirectory, SFTPByteRangeLockConflict
from .sftp import SFTPByteRangeLockRefused, SFTPDeletePending
from .sftp import SFTPFileCorrupt, SFTPOwnerInvalid, SFTPGroupInvalid
from .sftp import SFTPNoMatchingByteRangeLock
from .sftp import SFTPConnectionLost, SFTPOpUnsupported
from .sftp import SFTPAttrs, SFTPVFSAttrs, SFTPName
from .sftp import SEEK_SET, SEEK_CUR, SEEK_END
from .stream import SSHSocketSessionFactory, SSHServerSessionFactory
from .stream import SFTPServerFactory, SSHReader, SSHWriter
from .subprocess import SSHSubprocessReadPipe, SSHSubprocessWritePipe
from .subprocess import SSHSubprocessProtocol, SSHSubprocessTransport
# Import these explicitly to trigger register calls in them
from . import sk_eddsa, sk_ecdsa, eddsa, ecdsa, rsa, dsa, kex_dh, kex_rsa
__all__ = [
'BreakReceived', 'BytesOrStr', 'ChannelListenError',
'ChannelOpenError', 'CompressionError', 'ConfigParseError',
'ConnectionLost', 'DEVNULL', 'DataType', 'DisconnectError', 'Error',
'HostKeyNotVerifiable', 'IllegalUserName', 'KeyEncryptionError',
'KeyExchangeFailed', 'KeyExportError', 'KeyGenerationError',
'KeyImportError', 'MACError', 'PIPE', 'PasswordChangeRequired',
'PermissionDenied', 'ProcessError', 'ProtocolError',
'ProtocolNotSupported', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET',
'SFTPAttrs', 'SFTPBadMessage', 'SFTPByteRangeLockConflict',
'SFTPByteRangeLockRefused', 'SFTPCannotDelete', 'SFTPClient',
'SFTPClientFile', 'SFTPConnectionLost', 'SFTPDeletePending',
'SFTPDirNotEmpty', 'SFTPEOFError', 'SFTPError', 'SFTPFailure',
'SFTPFileAlreadyExists', 'SFTPFileCorrupt', 'SFTPFileIsADirectory',
'SFTPGroupInvalid', 'SFTPInvalidFilename', 'SFTPInvalidHandle',
'SFTPInvalidParameter', 'SFTPLinkLoop', 'SFTPLockConflict', 'SFTPName',
'SFTPNoConnection', 'SFTPNoMatchingByteRangeLock', 'SFTPNoMedia',
'SFTPNoSpaceOnFilesystem', 'SFTPNoSuchFile', 'SFTPNoSuchPath',
'SFTPNotADirectory', 'SFTPOpUnsupported', 'SFTPOwnerInvalid',
'SFTPPermissionDenied', 'SFTPQuotaExceeded', 'SFTPServer',
'SFTPServerFactory', 'SFTPUnknownPrincipal', 'SFTPVFSAttrs',
'SFTPWriteProtect', 'SSHAcceptor', 'SSHAgentClient', 'SSHAgentKeyPair',
'SSHAuthorizedKeys', 'SSHCertificate', 'SSHClient', 'SSHClientChannel',
'SSHClientConnection', 'SSHClientConnectionOptions',
'SSHClientProcess', 'SSHClientSession', 'SSHCompletedProcess',
'SSHForwarder', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts',
'SSHLineEditorChannel', 'SSHListener', 'SSHReader', 'SSHServer',
'SSHServerChannel', 'SSHServerConnection',
'SSHServerConnectionOptions', 'SSHServerProcess',
'SSHServerProcessFactory', 'SSHServerSession',
'SSHServerSessionFactory', 'SSHSocketSessionFactory',
'SSHSubprocessProtocol', 'SSHSubprocessReadPipe',
'SSHSubprocessTransport', 'SSHSubprocessWritePipe', 'SSHTCPChannel',
'SSHTCPSession', 'SSHUNIXChannel', 'SSHUNIXSession', 'SSHWriter',
'STDOUT', 'ServiceNotAvailable', 'SignalReceived',
'TerminalSizeChanged', 'TimeoutError', 'connect', 'connect_agent',
'connect_reverse', 'create_connection', 'create_server',
'generate_private_key', 'get_server_auth_methods',
'get_server_host_key', 'import_authorized_keys', 'import_certificate',
'import_known_hosts', 'import_private_key', 'import_public_key',
'listen', 'listen_reverse', 'load_certificates', 'load_keypairs',
'load_pkcs11_keys', 'load_public_keys', 'load_resident_keys', 'logger',
'match_known_hosts', 'read_authorized_keys', 'read_certificate',
'read_certificate_list', 'read_known_hosts', 'read_private_key',
'read_private_key_list', 'read_public_key', 'read_public_key_list',
'run_client', 'run_server', 'scp', 'set_debug_level', 'set_log_level',
'set_sftp_log_level', 'set_default_skip_rsa_key_validation',
]

View File

@@ -0,0 +1,676 @@
# Copyright (c) 2016-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH agent client"""
import asyncio
import errno
import os
import sys
from types import TracebackType
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Protocol
from .listener import SSHForwardListener
from .misc import async_context_manager, maybe_wait_closed
from .packet import Byte, String, UInt32, PacketDecodeError, SSHPacket
from .public_key import KeyPairListArg, SSHCertificate, SSHKeyPair
from .public_key import load_default_keypairs, load_keypairs
if TYPE_CHECKING:
from tempfile import TemporaryDirectory
class AgentReader(Protocol):
"""Protocol for reading from an SSH agent"""
async def readexactly(self, n: int) -> bytes:
"""Read exactly n bytes from the SSH agent"""
class AgentWriter(Protocol):
"""Protocol for writing to an SSH agent"""
def write(self, data: bytes) -> None:
"""Write bytes to the SSH agent"""
def close(self) -> None:
"""Close connection to the SSH agent"""
async def wait_closed(self) -> None:
"""Wait for the connection to the SSH agent to close"""
try:
if sys.platform == 'win32': # pragma: no cover
from .agent_win32 import open_agent
else:
from .agent_unix import open_agent
except ImportError as _exc: # pragma: no cover
async def open_agent(agent_path: str) -> \
Tuple[AgentReader, AgentWriter]:
"""Dummy function if we're unable to import agent support"""
raise OSError(errno.ENOENT, 'Agent support unavailable: %s' % str(_exc))
class _SupportsOpenAgentConnection(Protocol):
"""A class that supports open_agent_connection"""
async def open_agent_connection(self) -> Tuple[AgentReader, AgentWriter]:
"""Open a forwarded ssh-agent connection back to the client"""
_AgentPath = Union[str, _SupportsOpenAgentConnection]
# Client request message numbers
SSH_AGENTC_REQUEST_IDENTITIES = 11
SSH_AGENTC_SIGN_REQUEST = 13
SSH_AGENTC_ADD_IDENTITY = 17
SSH_AGENTC_REMOVE_IDENTITY = 18
SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19
SSH_AGENTC_ADD_SMARTCARD_KEY = 20
SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21
SSH_AGENTC_LOCK = 22
SSH_AGENTC_UNLOCK = 23
SSH_AGENTC_ADD_ID_CONSTRAINED = 25
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26
SSH_AGENTC_EXTENSION = 27
# Agent response message numbers
SSH_AGENT_FAILURE = 5
SSH_AGENT_SUCCESS = 6
SSH_AGENT_IDENTITIES_ANSWER = 12
SSH_AGENT_SIGN_RESPONSE = 14
SSH_AGENT_EXTENSION_FAILURE = 28
# SSH agent constraint numbers
SSH_AGENT_CONSTRAIN_LIFETIME = 1
SSH_AGENT_CONSTRAIN_CONFIRM = 2
SSH_AGENT_CONSTRAIN_EXTENSION = 255
# SSH agent signature flags
SSH_AGENT_RSA_SHA2_256 = 2
SSH_AGENT_RSA_SHA2_512 = 4
class SSHAgentKeyPair(SSHKeyPair):
"""Surrogate for a key managed by the SSH agent"""
_key_type = 'agent'
def __init__(self, agent: 'SSHAgentClient', algorithm: bytes,
public_data: bytes, comment: bytes):
is_cert = algorithm.endswith(b'-cert-v01@openssh.com')
if is_cert:
if algorithm.startswith(b'sk-'):
sig_algorithm = algorithm[:-21] + b'@openssh.com'
else:
sig_algorithm = algorithm[:-21]
else:
sig_algorithm = algorithm
# Neither Pageant nor the Win10 OpenSSH agent seems to support the
# ssh-agent protocol flags used to request RSA SHA2 signatures yet
if sig_algorithm == b'ssh-rsa' and sys.platform != 'win32':
sig_algorithms: Sequence[bytes] = \
(b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa')
else:
sig_algorithms = (sig_algorithm,)
if is_cert:
host_key_algorithms: Sequence[bytes] = (algorithm,)
else:
host_key_algorithms = sig_algorithms
super().__init__(algorithm, sig_algorithm, sig_algorithms,
host_key_algorithms, public_data, comment)
self._agent = agent
self._is_cert = is_cert
self._flags = 0
@property
def has_cert(self) -> bool:
""" Return if this key pair has an associated cert"""
return self._is_cert
@property
def has_x509_chain(self) -> bool:
""" Return if this key pair has an associated X.509 cert chain"""
return False
def set_certificate(self, cert: SSHCertificate) -> None:
"""Set certificate to use with this key"""
super().set_certificate(cert)
self._is_cert = True
def set_sig_algorithm(self, sig_algorithm: bytes) -> None:
"""Set the signature algorithm to use when signing data"""
super().set_sig_algorithm(sig_algorithm)
if sig_algorithm in (b'rsa-sha2-256', b'x509v3-rsa2048-sha256'):
self._flags |= SSH_AGENT_RSA_SHA2_256
elif sig_algorithm == b'rsa-sha2-512':
self._flags |= SSH_AGENT_RSA_SHA2_512
async def sign_async(self, data: bytes) -> bytes:
"""Asynchronously sign a block of data with this private key"""
return await self._agent.sign(self.key_public_data, data, self._flags)
async def remove(self) -> None:
"""Remove this key pair from the agent"""
await self._agent.remove_keys([self])
class SSHAgentClient:
"""SSH agent client"""
def __init__(self, agent_path: _AgentPath):
self._agent_path = agent_path
self._reader: Optional[AgentReader] = None
self._writer: Optional[AgentWriter] = None
self._lock = asyncio.Lock()
async def __aenter__(self) -> 'SSHAgentClient':
"""Allow SSHAgentClient to be used as an async context manager"""
return self
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> bool:
"""Wait for connection close when used as an async context manager"""
await self._cleanup()
return False
async def _cleanup(self) -> None:
"""Clean up this SSH agent client"""
self.close()
await self.wait_closed()
@staticmethod
def encode_constraints(lifetime: Optional[int], confirm: bool) -> bytes:
"""Encode key constraints"""
result = b''
if lifetime:
result += Byte(SSH_AGENT_CONSTRAIN_LIFETIME) + UInt32(lifetime)
if confirm:
result += Byte(SSH_AGENT_CONSTRAIN_CONFIRM)
return result
async def connect(self) -> None:
"""Connect to the SSH agent"""
if isinstance(self._agent_path, str):
self._reader, self._writer = await open_agent(self._agent_path)
else:
self._reader, self._writer = \
await self._agent_path.open_agent_connection()
async def _make_request(self, msgtype: int, *args: bytes) -> \
Tuple[int, SSHPacket]:
"""Send an SSH agent request"""
async with self._lock:
try:
if not self._writer:
await self.connect()
reader = self._reader
writer = self._writer
assert reader is not None
assert writer is not None
payload = Byte(msgtype) + b''.join(args)
writer.write(UInt32(len(payload)) + payload)
resplen = int.from_bytes((await reader.readexactly(4)), 'big')
resp = SSHPacket((await reader.readexactly(resplen)))
resptype = resp.get_byte()
return resptype, resp
except (OSError, EOFError, PacketDecodeError) as exc:
await self._cleanup()
raise ValueError(str(exc)) from None
async def get_keys(self, identities: Optional[Sequence[bytes]] = None) -> \
Sequence[SSHKeyPair]:
"""Request the available client keys
This method is a coroutine which returns a list of client keys
available in the ssh-agent.
:param identities: (optional)
A list of allowed byte string identities to return. If empty,
all identities on the SSH agent will be returned.
:returns: A list of :class:`SSHKeyPair` objects
"""
resptype, resp = \
await self._make_request(SSH_AGENTC_REQUEST_IDENTITIES)
if resptype == SSH_AGENT_IDENTITIES_ANSWER:
result: List[SSHKeyPair] = []
num_keys = resp.get_uint32()
for _ in range(num_keys):
key_blob = resp.get_string()
comment = resp.get_string()
if identities and key_blob not in identities:
continue
packet = SSHPacket(key_blob)
algorithm = packet.get_string()
result.append(SSHAgentKeyPair(self, algorithm,
key_blob, comment))
resp.check_end()
return result
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def sign(self, key_blob: bytes, data: bytes,
flags: int = 0) -> bytes:
"""Sign a block of data with the requested key"""
resptype, resp = await self._make_request(SSH_AGENTC_SIGN_REQUEST,
String(key_blob),
String(data), UInt32(flags))
if resptype == SSH_AGENT_SIGN_RESPONSE:
sig = resp.get_string()
resp.check_end()
return sig
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Unable to sign with requested key')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def add_keys(self, keylist: KeyPairListArg = (),
passphrase: Optional[str] = None,
lifetime: Optional[int] = None,
confirm: bool = False) -> None:
"""Add keys to the agent
This method adds a list of local private keys and optional
matching certificates to the agent.
:param keylist: (optional)
The list of keys to add. If not specified, an attempt will
be made to load keys from the files
:file:`.ssh/id_ed25519_sk`, :file:`.ssh/id_ecdsa_sk`,
:file:`.ssh/id_ed448`, :file:`.ssh/id_ed25519`,
:file:`.ssh/id_ecdsa`, :file:`.ssh/id_rsa` and
:file:`.ssh/id_dsa` in the user's home directory with
optional matching certificates loaded from the files
:file:`.ssh/id_ed25519_sk-cert.pub`,
:file:`.ssh/id_ecdsa_sk-cert.pub`,
:file:`.ssh/id_ed448-cert.pub`,
:file:`.ssh/id_ed25519-cert.pub`,
:file:`.ssh/id_ecdsa-cert.pub`, :file:`.ssh/id_rsa-cert.pub`,
and :file:`.ssh/id_dsa-cert.pub`. Failures when adding keys
are ignored in this case, as the agent may not recognize
some of these key types.
:param passphrase: (optional)
The passphrase to use to decrypt the keys.
:param lifetime: (optional)
The time in seconds after which the keys should be
automatically deleted, or `None` to store these keys
indefinitely (the default).
:param confirm: (optional)
Whether or not to require confirmation for each private
key operation which uses these keys, defaulting to `False`.
:type keylist: *see* :ref:`SpecifyingPrivateKeys`
:type passphrase: `str`
:type lifetime: `int` or `None`
:type confirm: `bool`
:raises: :exc:`ValueError` if the keys cannot be added
"""
if keylist:
keypairs = load_keypairs(keylist, passphrase)
ignore_failures = False
else:
keypairs = load_default_keypairs(passphrase)
ignore_failures = True
base_constraints = self.encode_constraints(lifetime, confirm)
provider = os.environ.get('SSH_SK_PROVIDER') or 'internal'
sk_constraints = Byte(SSH_AGENT_CONSTRAIN_EXTENSION) + \
String('sk-provider@openssh.com') + \
String(provider)
for keypair in keypairs:
constraints = base_constraints
if keypair.algorithm.startswith(b'sk-'):
constraints += sk_constraints
msgtype = SSH_AGENTC_ADD_ID_CONSTRAINED if constraints else \
SSH_AGENTC_ADD_IDENTITY
comment = keypair.get_comment_bytes()
resptype, resp = \
await self._make_request(msgtype,
keypair.get_agent_private_key(),
String(comment or b''), constraints)
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
if not ignore_failures:
raise ValueError('Unable to add key')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def add_smartcard_keys(self, provider: str,
pin: Optional[str] = None,
lifetime: Optional[int] = None,
confirm: bool = False) -> None:
"""Store keys associated with a smart card in the agent
:param provider:
The name of the smart card provider
:param pin: (optional)
The PIN to use to unlock the smart card
:param lifetime: (optional)
The time in seconds after which the keys should be
automatically deleted, or `None` to store these keys
indefinitely (the default).
:param confirm: (optional)
Whether or not to require confirmation for each private
key operation which uses these keys, defaulting to `False`.
:type provider: `str`
:type pin: `str` or `None`
:type lifetime: `int` or `None`
:type confirm: `bool`
:raises: :exc:`ValueError` if the keys cannot be added
"""
constraints = self.encode_constraints(lifetime, confirm)
msgtype = SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED \
if constraints else SSH_AGENTC_ADD_SMARTCARD_KEY
resptype, resp = await self._make_request(msgtype, String(provider),
String(pin or ''),
constraints)
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Unable to add keys')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def remove_keys(self, keylist: Sequence[SSHKeyPair]) -> None:
"""Remove a key stored in the agent
:param keylist:
The list of keys to remove.
:type keylist: `list` of :class:`SSHKeyPair`
:raises: :exc:`ValueError` if any keys are not found
"""
for keypair in keylist:
resptype, resp = \
await self._make_request(SSH_AGENTC_REMOVE_IDENTITY,
String(keypair.public_data))
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Key not found')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def remove_smartcard_keys(self, provider: str,
pin: Optional[str] = None) -> None:
"""Remove keys associated with a smart card stored in the agent
:param provider:
The name of the smart card provider
:param pin: (optional)
The PIN to use to unlock the smart card
:type provider: `str`
:type pin: `str` or `None`
:raises: :exc:`ValueError` if the keys are not found
"""
resptype, resp = \
await self._make_request(SSH_AGENTC_REMOVE_SMARTCARD_KEY,
String(provider), String(pin or ''))
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Keys not found')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def remove_all(self) -> None:
"""Remove all keys stored in the agent
:raises: :exc:`ValueError` if the keys can't be removed
"""
resptype, resp = \
await self._make_request(SSH_AGENTC_REMOVE_ALL_IDENTITIES)
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Unable to remove all keys')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def lock(self, passphrase: str) -> None:
"""Lock the agent using the specified passphrase
.. note:: The lock and unlock actions don't appear to be
supported on the Windows 10 OpenSSH agent.
:param passphrase:
The passphrase required to later unlock the agent
:type passphrase: `str`
:raises: :exc:`ValueError` if the agent can't be locked
"""
resptype, resp = await self._make_request(SSH_AGENTC_LOCK,
String(passphrase))
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Unable to lock SSH agent')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def unlock(self, passphrase: str) -> None:
"""Unlock the agent using the specified passphrase
.. note:: The lock and unlock actions don't appear to be
supported on the Windows 10 OpenSSH agent.
:param passphrase:
The passphrase to use to unlock the agent
:type passphrase: `str`
:raises: :exc:`ValueError` if the agent can't be unlocked
"""
resptype, resp = await self._make_request(SSH_AGENTC_UNLOCK,
String(passphrase))
if resptype == SSH_AGENT_SUCCESS:
resp.check_end()
elif resptype == SSH_AGENT_FAILURE:
raise ValueError('Unable to unlock SSH agent')
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
async def query_extensions(self) -> Sequence[str]:
"""Return a list of extensions supported by the agent
:returns: A list of strings of supported extension names
"""
resptype, resp = await self._make_request(SSH_AGENTC_EXTENSION,
String('query'))
if resptype == SSH_AGENT_SUCCESS:
result = []
while resp:
exttype = resp.get_string()
try:
exttype_str = exttype.decode('utf-8')
except UnicodeDecodeError:
raise ValueError('Invalid extension type name') from None
result.append(exttype_str)
return result
elif resptype == SSH_AGENT_FAILURE:
return []
else:
raise ValueError('Unknown SSH agent response: %d' % resptype)
def close(self) -> None:
"""Close the SSH agent connection
This method closes the connection to the ssh-agent. Any
attempts to use this :class:`SSHAgentClient` or the key
pairs it previously returned will result in an error.
"""
if self._writer:
self._writer.close()
async def wait_closed(self) -> None:
"""Wait for this agent connection to close
This method is a coroutine which can be called to block until
the connection to the agent has finished closing.
"""
if self._writer:
await maybe_wait_closed(self._writer)
self._reader = None
self._writer = None
class SSHAgentListener:
"""Listener used to forward agent connections"""
def __init__(self, tempdir: 'TemporaryDirectory[str]', path: str,
unix_listener: SSHForwardListener):
self._tempdir = tempdir
self._path = path
self._unix_listener = unix_listener
def get_path(self) -> str:
"""Return the path being listened on"""
return self._path
def close(self) -> None:
"""Close the agent listener"""
self._unix_listener.close()
self._tempdir.cleanup()
@async_context_manager
async def connect_agent(agent_path: _AgentPath = '') -> 'SSHAgentClient':
"""Make a connection to the SSH agent
This function attempts to connect to an ssh-agent process
listening on a UNIX domain socket at `agent_path`. If not
provided, it will attempt to get the path from the `SSH_AUTH_SOCK`
environment variable.
If the connection is successful, an :class:`SSHAgentClient` object
is returned that has methods on it you can use to query the
ssh-agent. If no path is specified and the environment variable
is not set or the connection to the agent fails, an error is
raised.
:param agent_path: (optional)
The path to use to contact the ssh-agent process, or the
:class:`SSHServerConnection` to forward the agent request
over.
:type agent_path: `str` or :class:`SSHServerConnection`
:returns: An :class:`SSHAgentClient`
:raises: :exc:`OSError` or :exc:`ChannelOpenError` if the
connection to the agent can't be opened
"""
if not agent_path:
agent_path = os.environ.get('SSH_AUTH_SOCK', '')
agent = SSHAgentClient(agent_path)
await agent.connect()
return agent

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2016-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH agent support code for UNIX"""
import asyncio
import errno
from typing import TYPE_CHECKING, Tuple
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .agent import AgentReader, AgentWriter
async def open_agent(agent_path: str) -> Tuple['AgentReader', 'AgentWriter']:
"""Open a connection to ssh-agent"""
if not agent_path:
raise OSError(errno.ENOENT, 'Agent not found')
return await asyncio.open_unix_connection(agent_path)

View File

@@ -0,0 +1,182 @@
# Copyright (c) 2016-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH agent support code for Windows"""
# Some of the imports below won't be found when running pylint on UNIX
# pylint: disable=import-error
import asyncio
import ctypes
import ctypes.wintypes
import errno
from typing import TYPE_CHECKING, Tuple, Union, cast
from .misc import open_file
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .agent import AgentReader, AgentWriter
try:
import mmapfile
import win32api
import win32con
import win32ui
_pywin32_available = True
except ImportError:
_pywin32_available = False
_AGENT_COPYDATA_ID = 0x804e50ba
_AGENT_MAX_MSGLEN = 8192
_AGENT_NAME = 'Pageant'
_DEFAULT_OPENSSH_PATH = r'\\.\pipe\openssh-ssh-agent'
def _find_agent_window() -> 'win32ui.PyCWnd':
"""Find and return the Pageant window"""
if _pywin32_available:
try:
return win32ui.FindWindow(_AGENT_NAME, _AGENT_NAME)
except win32ui.error:
raise OSError(errno.ENOENT, 'Agent not found') from None
else:
raise OSError(errno.ENOENT, 'PyWin32 not installed') from None
class _CopyDataStruct(ctypes.Structure):
"""Windows COPYDATASTRUCT argument for WM_COPYDATA message"""
_fields_ = (('dwData', ctypes.wintypes.LPARAM),
('cbData', ctypes.wintypes.DWORD),
('lpData', ctypes.c_char_p))
class _PageantTransport:
"""Transport to connect to Pageant agent on Windows"""
def __init__(self) -> None:
self._mapname = '%s%08x' % (_AGENT_NAME, win32api.GetCurrentThreadId())
try:
self._mapfile = mmapfile.mmapfile('', self._mapname,
_AGENT_MAX_MSGLEN, 0, 0)
except mmapfile.error as exc:
raise OSError(errno.EIO, str(exc)) from None
self._cds = _CopyDataStruct(_AGENT_COPYDATA_ID, len(self._mapname) + 1,
self._mapname.encode())
self._writing = False
def write(self, data: bytes) -> None:
"""Write request data to Pageant agent"""
if not self._writing:
self._mapfile.seek(0)
self._writing = True
try:
self._mapfile.write(data)
except ValueError as exc:
raise OSError(errno.EIO, str(exc)) from None
async def readexactly(self, n: int) -> bytes:
"""Read response data from Pageant agent"""
if self._writing:
cwnd = _find_agent_window()
if not cwnd.SendMessage(win32con.WM_COPYDATA, 0,
cast(int, self._cds)):
raise OSError(errno.EIO, 'Unable to send agent request')
self._writing = False
self._mapfile.seek(0)
result = self._mapfile.read(n)
if len(result) != n:
raise asyncio.IncompleteReadError(result, n)
return result
def close(self) -> None:
"""Close the connection to Pageant"""
if self._mapfile:
self._mapfile.close()
async def wait_closed(self) -> None:
"""Wait for the transport to close"""
class _W10OpenSSHTransport:
"""Transport to connect to OpenSSH agent on Windows 10"""
def __init__(self, agent_path: str):
self._agentfile = open_file(agent_path, 'r+b')
async def readexactly(self, n: int) -> bytes:
"""Read response data from OpenSSH agent"""
result = self._agentfile.read(n)
if len(result) != n:
raise asyncio.IncompleteReadError(result, n)
return result
def write(self, data: bytes) -> None:
"""Write request data to OpenSSH agent"""
self._agentfile.write(data)
def close(self) -> None:
"""Close the connection to OpenSSH"""
if self._agentfile:
self._agentfile.close()
async def wait_closed(self) -> None:
"""Wait for the transport to close"""
async def open_agent(agent_path: str) -> Tuple['AgentReader', 'AgentWriter']:
"""Open a connection to the Pageant or Windows 10 OpenSSH agent"""
transport: Union[None, _PageantTransport, _W10OpenSSHTransport] = None
if not agent_path:
try:
_find_agent_window()
transport = _PageantTransport()
except OSError:
agent_path = _DEFAULT_OPENSSH_PATH
if not transport:
transport = _W10OpenSSHTransport(agent_path)
return transport, transport

View File

@@ -0,0 +1,786 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Utilities for encoding and decoding ASN.1 DER data
The der_encode function takes a Python value and encodes it in DER
format, returning a byte string. In addition to supporting standard
Python types, BitString can be used to encode a DER bit string,
ObjectIdentifier can be used to encode OIDs, values can be wrapped
in a TaggedDERObject to set an alternate DER tag on them, and
non-standard types can be encoded by placing them in a RawDERObject.
The der_decode function takes a byte string in DER format and decodes
it into the corresponding Python values.
"""
from typing import Dict, FrozenSet, Sequence, Set, Tuple, Type, TypeVar, Union
from typing import cast
_DERClass = Type['DERType']
_DERClassVar = TypeVar('_DERClassVar', bound='_DERClass')
# ASN.1 object classes
UNIVERSAL = 0x00
APPLICATION = 0x01
CONTEXT_SPECIFIC = 0x02
PRIVATE = 0x03
# ASN.1 universal object tags
END_OF_CONTENT = 0x00
BOOLEAN = 0x01
INTEGER = 0x02
BIT_STRING = 0x03
OCTET_STRING = 0x04
NULL = 0x05
OBJECT_IDENTIFIER = 0x06
UTF8_STRING = 0x0c
SEQUENCE = 0x10
SET = 0x11
IA5_STRING = 0x16
_asn1_class = ('Universal', 'Application', 'Context-specific', 'Private')
_der_class_by_tag: Dict[int, _DERClass] = {}
_der_class_by_type: Dict[Union[object, _DERClass], _DERClass] = {}
def _encode_identifier(asn1_class: int, constructed: bool, tag: int) -> bytes:
"""Encode a DER object's identifier"""
if asn1_class not in (UNIVERSAL, APPLICATION, CONTEXT_SPECIFIC, PRIVATE):
raise ASN1EncodeError('Invalid ASN.1 class')
flags = (asn1_class << 6) | (0x20 if constructed else 0x00)
if tag < 0x20:
identifier = [flags | tag]
else:
identifier = [tag & 0x7f]
while tag >= 0x80:
tag >>= 7
identifier.append(0x80 | (tag & 0x7f))
identifier.append(flags | 0x1f)
return bytes(identifier[::-1])
class ASN1Error(ValueError):
"""ASN.1 coding error"""
class ASN1EncodeError(ASN1Error):
"""ASN.1 DER encoding error"""
class ASN1DecodeError(ASN1Error):
"""ASN.1 DER decoding error"""
class DERType:
"""Parent class for classes which use DERTag decorator"""
identifier: bytes = b''
@staticmethod
def encode(value: object) -> bytes:
"""Encode value as a DER byte string"""
raise NotImplementedError
@classmethod
def decode(cls, constructed: bool, content: bytes) -> object:
"""Decode a DER byte string into an object"""
raise NotImplementedError
class DERTag:
"""A decorator used by classes which convert values to/from DER
Classes which convert Python values to and from DER format
should use the DERTag decorator to indicate what DER tag value
they understand. When DER data is decoded, the tag is looked
up in the list to see which class to call to perform the
decoding.
Classes which convert existing Python types to and from DER
format can specify the list of types they understand in the
optional "types" argument. Otherwise, conversion is expected
to be to and from the new class being defined.
"""
def __init__(self, tag: int, types: Sequence[object] = (),
constructed: bool = False):
self._tag = tag
self._types = types
self._identifier = _encode_identifier(UNIVERSAL, constructed, tag)
def __call__(self, cls: _DERClassVar) -> _DERClassVar:
cls.identifier = self._identifier
_der_class_by_tag[self._tag] = cls
if self._types:
for t in self._types:
_der_class_by_type[t] = cls
else:
_der_class_by_type[cls] = cls
return cls
class RawDERObject:
"""A class which can encode a DER object of an arbitrary type
This object is initialized with an ASN.1 class, tag, and a
byte string representing the already encoded data. Such
objects will never have the constructed flag set, since
that is represented here as a TaggedDERObject.
"""
def __init__(self, tag: int, content: bytes, asn1_class: int):
self.asn1_class = asn1_class
self.tag = tag
self.content = content
def __repr__(self) -> str:
return ('RawDERObject(%s, %s, %r)' %
(_asn1_class[self.asn1_class], self.tag, self.content))
def __eq__(self, other: object) -> bool:
if not isinstance(other, RawDERObject): # pragma: no cover
return NotImplemented
return (self.asn1_class == other.asn1_class and
self.tag == other.tag and self.content == other.content)
def __hash__(self) -> int:
return hash((self.asn1_class, self.tag, self.content))
def encode_identifier(self) -> bytes:
"""Encode the DER identifier for this object as a byte string"""
return _encode_identifier(self.asn1_class, False, self.tag)
@staticmethod
def encode(value: object) -> bytes:
"""Encode the content for this object as a DER byte string"""
return cast('RawDERObject', value).content
class TaggedDERObject:
"""An explicitly tagged DER object
This object provides a way to wrap an ASN.1 object with an
explicit tag. The value (including the tag representing its
actual type) is then encoded as part of its value. By
default, the ASN.1 class for these objects is CONTEXT_SPECIFIC,
and the DER encoding always marks these values as constructed.
"""
def __init__(self, tag: int, value: object,
asn1_class: int = CONTEXT_SPECIFIC):
self.asn1_class = asn1_class
self.tag = tag
self.value = value
def __repr__(self) -> str:
if self.asn1_class == CONTEXT_SPECIFIC:
return 'TaggedDERObject(%s, %r)' % (self.tag, self.value)
else:
return ('TaggedDERObject(%s, %s, %r)' %
(_asn1_class[self.asn1_class], self.tag, self.value))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TaggedDERObject): # pragma: no cover
return NotImplemented
return (self.asn1_class == other.asn1_class and
self.tag == other.tag and self.value == other.value)
def __hash__(self) -> int:
return hash((self.asn1_class, self.tag, self.value))
def encode_identifier(self) -> bytes:
"""Encode the DER identifier for this object as a byte string"""
return _encode_identifier(self.asn1_class, True, self.tag)
@staticmethod
def encode(value: object) -> bytes:
"""Encode the content for this object as a DER byte string"""
return der_encode(cast('TaggedDERObject', value).value)
@DERTag(NULL, (type(None),))
class _Null(DERType):
"""A null value"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER null value"""
# pylint: disable=unused-argument
return b''
@classmethod
def decode(cls, constructed: bool, content: bytes) -> None:
"""Decode a DER null value"""
if constructed:
raise ASN1DecodeError('NULL should not be constructed')
if content:
raise ASN1DecodeError('NULL should not have associated content')
return None
@DERTag(BOOLEAN, (bool,))
class _Boolean(DERType):
"""A boolean value"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER boolean value"""
return b'\xff' if value else b'\0'
@classmethod
def decode(cls, constructed: bool, content: bytes) -> bool:
"""Decode a DER boolean value"""
if constructed:
raise ASN1DecodeError('BOOLEAN should not be constructed')
if content not in {b'\x00', b'\xff'}:
raise ASN1DecodeError('BOOLEAN content must be 0x00 or 0xff')
return bool(content[0])
@DERTag(INTEGER, (int,))
class _Integer(DERType):
"""An integer value"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER integer value"""
i = cast(int, value)
l = i.bit_length()
l = l // 8 + 1 if l % 8 == 0 else (l + 7) // 8
result = i.to_bytes(l, 'big', signed=True)
return result[1:] if result.startswith(b'\xff\x80') else result
@classmethod
def decode(cls, constructed: bool, content: bytes) -> int:
"""Decode a DER integer value"""
if constructed:
raise ASN1DecodeError('INTEGER should not be constructed')
return int.from_bytes(content, 'big', signed=True)
@DERTag(OCTET_STRING, (bytes, bytearray))
class _OctetString(DERType):
"""An octet string value"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER octet string"""
return cast(bytes, value)
@classmethod
def decode(cls, constructed: bool, content: bytes) -> bytes:
"""Decode a DER octet string"""
if constructed:
raise ASN1DecodeError('OCTET STRING should not be constructed')
return content
@DERTag(UTF8_STRING, (str,))
class _UTF8String(DERType):
"""A UTF-8 string value"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER UTF-8 string"""
return cast(str, value).encode('utf-8')
@classmethod
def decode(cls, constructed: bool, content: bytes) -> str:
"""Decode a DER UTF-8 string"""
if constructed:
raise ASN1DecodeError('UTF8 STRING should not be constructed')
return content.decode('utf-8')
@DERTag(SEQUENCE, (list, tuple), constructed=True)
class _Sequence(DERType):
"""A sequence of values"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a sequence of DER values"""
seq_value = cast(Sequence[object], value)
return b''.join(der_encode(item) for item in seq_value)
@classmethod
def decode(cls, constructed: bool, content: bytes) -> Sequence[object]:
"""Decode a sequence of DER values"""
if not constructed:
raise ASN1DecodeError('SEQUENCE should always be constructed')
offset = 0
length = len(content)
value = []
while offset < length:
item, consumed = der_decode_partial(content[offset:])
value.append(item)
offset += consumed
return tuple(value)
@DERTag(SET, (set, frozenset), constructed=True)
class _Set(DERType):
"""A set of DER values"""
@staticmethod
def encode(value: object) -> bytes:
"""Encode a set of DER values"""
set_value = cast(Union[FrozenSet[object], Set[object]], value)
return b''.join(sorted(der_encode(item) for item in set_value))
@classmethod
def decode(cls, constructed: bool, content: bytes) -> FrozenSet[object]:
"""Decode a set of DER values"""
if not constructed:
raise ASN1DecodeError('SET should always be constructed')
offset = 0
length = len(content)
value = set()
while offset < length:
item, consumed = der_decode_partial(content[offset:])
value.add(item)
offset += consumed
return frozenset(value)
@DERTag(BIT_STRING)
class BitString(DERType):
"""A string of bits
This object can be initialized either with a byte string and an
optional count of the number of least-significant bits in the last
byte which should not be included in the value, or with a string
consisting only of the digits '0' and '1'.
An optional 'named' flag can also be set, indicating that the
BitString was specified with named bits, indicating that the proper
DER encoding of it should strip any trailing zeroes.
"""
def __init__(self, value: object, unused: int = 0, named: bool = False):
if unused < 0 or unused > 7:
raise ASN1EncodeError('Unused bit count must be between 0 and 7')
if isinstance(value, bytes):
if unused:
if not value:
raise ASN1EncodeError('Can\'t have unused bits with empty '
'value')
elif value[-1] & ((1 << unused) - 1):
raise ASN1EncodeError('Unused bits in value should be '
'zero')
elif isinstance(value, str):
if unused:
raise ASN1EncodeError('Unused bit count should not be set '
'when providing a string')
used = len(value) % 8
unused = 8 - used if used else 0
value += unused * '0'
value = bytes(int(value[i:i+8], 2)
for i in range(0, len(value), 8))
else:
raise ASN1EncodeError('Unexpected type of bit string value')
if named:
while value and not value[-1] & (1 << unused):
unused += 1
if unused == 8:
value = value[:-1]
unused = 0
self.value = value
self.unused = unused
def __str__(self) -> str:
result = ''.join(bin(b)[2:].zfill(8) for b in self.value)
if self.unused:
result = result[:-self.unused]
return result
def __repr__(self) -> str:
return "BitString('%s')" % self
def __eq__(self, other: object) -> bool:
if not isinstance(other, BitString): # pragma: no cover
return NotImplemented
return self.value == other.value and self.unused == other.unused
def __hash__(self) -> int:
return hash((self.value, self.unused))
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER bit string"""
bitstr_value = cast('BitString', value)
return bytes((bitstr_value.unused,)) + bitstr_value.value
@classmethod
def decode(cls, constructed: bool, content: bytes) -> 'BitString':
"""Decode a DER bit string"""
if constructed:
raise ASN1DecodeError('BIT STRING should not be constructed')
if not content or content[0] > 7:
raise ASN1DecodeError('Invalid unused bit count')
return cls(content[1:], unused=content[0])
@DERTag(IA5_STRING)
class IA5String(DERType):
"""An ASCII string value"""
def __init__(self, value: Union[bytes, bytearray]):
self.value = value
def __str__(self) -> str:
return '%s' % self.value.decode('ascii')
def __repr__(self) -> str:
return 'IA5String(%r)' % self.value
def __eq__(self, other: object) -> bool: # pragma: no cover
if not isinstance(other, IA5String):
return NotImplemented
return self.value == other.value
def __hash__(self) -> int:
return hash(self.value)
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER IA5 string"""
# ASN.1 defines this type as only containing ASCII characters, but
# some tools expecting ASN.1 allow IA5Strings to contain other
# characters, so we leave it up to the caller to pass in a byte
# string which has already done the appropriate encoding of any
# non-ASCII characters.
return cast('IA5String', value).value
@classmethod
def decode(cls, constructed: bool, content: bytes) -> 'IA5String':
"""Decode a DER IA5 string"""
if constructed:
raise ASN1DecodeError('IA5 STRING should not be constructed')
# As noted in the encode method above, the decoded value for this
# type is a byte string, leaving the decoding of any non-ASCII
# characters up to the caller.
return cls(content)
@DERTag(OBJECT_IDENTIFIER)
class ObjectIdentifier(DERType):
"""An object identifier (OID) value
This object can be initialized from a string of dot-separated
integer values, representing a hierarchical namespace. All OIDs
show have at least two components, with the first being between
0 and 2 (indicating ITU-T, ISO, or joint assignment). In cases
where the first component is 0 or 1, the second component must
be in the range 0 to 39 due to the way these first two components
are encoded.
"""
def __init__(self, value: str):
self.value = value
def __str__(self) -> str:
return self.value
def __repr__(self) -> str:
return "ObjectIdentifier('%s')" % self.value
def __eq__(self, other: object) -> bool:
if not isinstance(other, ObjectIdentifier): # pragma: no cover
return NotImplemented
return self.value == other.value
def __hash__(self) -> int:
return hash(self.value)
@staticmethod
def encode(value: object) -> bytes:
"""Encode a DER object identifier"""
def _bytes(component: int) -> bytes:
"""Convert a single element of an OID to a DER byte string"""
if component < 0:
raise ASN1EncodeError('Components of object identifier must '
'be greater than or equal to 0')
result = [component & 0x7f]
while component >= 0x80:
component >>= 7
result.append(0x80 | (component & 0x7f))
return bytes(result[::-1])
oid_value = cast('ObjectIdentifier', value)
try:
components = [int(c) for c in oid_value.value.split('.')]
except ValueError:
raise ASN1EncodeError('Component values must be '
'integers') from None
if len(components) < 2:
raise ASN1EncodeError('Object identifiers must have at least two '
'components')
elif components[0] < 0 or components[0] > 2:
raise ASN1EncodeError('First component of object identifier must '
'be between 0 and 2')
elif components[0] < 2 and (components[1] < 0 or components[1] > 39):
raise ASN1EncodeError('Second component of object identifier must '
'be between 0 and 39')
components[0:2] = [components[0]*40 + components[1]]
return b''.join(_bytes(c) for c in components)
@classmethod
def decode(cls, constructed: bool, content: bytes) -> 'ObjectIdentifier':
"""Decode a DER object identifier"""
if constructed:
raise ASN1DecodeError('OBJECT IDENTIFIER should not be '
'constructed')
if not content:
raise ASN1DecodeError('Empty object identifier')
b = content[0]
components = list(divmod(b, 40)) if b < 80 else [2, b-80]
component = 0
for b in content[1:]:
if b == 0x80 and component == 0:
raise ASN1DecodeError('Invalid component')
elif b < 0x80:
components.append(component | b)
component = 0
else:
component |= b & 0x7f
component <<= 7
if component:
raise ASN1DecodeError('Incomplete component')
return cls('.'.join(str(c) for c in components))
def der_encode(value: object) -> bytes:
"""Encode a value in DER format
This function takes a Python value and encodes it in DER format.
The following mapping of types is used:
NoneType -> NULL
bool -> BOOLEAN
int -> INTEGER
bytes, bytearray -> OCTET STRING
str -> UTF8 STRING
list, tuple -> SEQUENCE
set, frozenset -> SET
BitString -> BIT STRING
ObjectIdentifier -> OBJECT IDENTIFIER
An explicitly tagged DER object can be encoded by passing in a
TaggedDERObject which specifies the ASN.1 class, tag, and value
to encode.
Other types can be encoded by passing in a RawDERObject which
specifies the ASN.1 class, tag, and raw content octets to encode.
"""
t = type(value)
if t in (RawDERObject, TaggedDERObject):
value = cast(Union[RawDERObject, TaggedDERObject], value)
identifier = value.encode_identifier()
content = value.encode(value)
elif t in _der_class_by_type:
cls = _der_class_by_type[t]
identifier = cls.identifier
content = cls.encode(value)
else:
raise ASN1EncodeError('Cannot DER encode type %s' % t.__name__)
length = len(content)
if length < 0x80:
len_bytes = bytes((length,))
else:
len_bytes = length.to_bytes((length.bit_length() + 7) // 8, 'big')
len_bytes = bytes((0x80 | len(len_bytes),)) + len_bytes
return identifier + len_bytes + content
def der_decode_partial(data: bytes) -> Tuple[object, int]:
"""Decode a value in DER format and return the number of bytes consumed"""
if len(data) < 2:
raise ASN1DecodeError('Incomplete data')
tag = data[0]
asn1_class, constructed, tag = tag >> 6, bool(tag & 0x20), tag & 0x1f
offset = 1
if tag == 0x1f:
tag = 0
for b in data[offset:]:
offset += 1
if b < 0x80:
tag |= b
break
else:
tag |= b & 0x7f
tag <<= 7
else:
raise ASN1DecodeError('Incomplete tag')
if offset >= len(data):
raise ASN1DecodeError('Incomplete data')
length = data[offset]
offset += 1
if length > 0x80:
len_size = length & 0x7f
length = int.from_bytes(data[offset:offset+len_size], 'big')
offset += len_size
elif length == 0x80:
raise ASN1DecodeError('Indefinite length not allowed')
end = offset + length
content = data[offset:end]
if end > len(data):
raise ASN1DecodeError('Incomplete data')
if asn1_class == UNIVERSAL and tag in _der_class_by_tag:
cls = _der_class_by_tag[tag]
value = cls.decode(constructed, content)
elif constructed:
value = TaggedDERObject(tag, der_decode(content), asn1_class)
else:
value = RawDERObject(tag, content, asn1_class)
return value, end
def der_decode(data: bytes) -> object:
"""Decode a value in DER format
This function takes a byte string in DER format and converts it
to a corresponding set of Python objects. The following mapping
of ASN.1 tags to Python types is used:
NULL -> NoneType
BOOLEAN -> bool
INTEGER -> int
OCTET STRING -> bytes
UTF8 STRING -> str
SEQUENCE -> tuple
SET -> frozenset
BIT_STRING -> BitString
OBJECT IDENTIFIER -> ObjectIdentifier
Explicitly tagged objects are returned as type TaggedDERObject,
with fields holding the object class, tag, and tagged value.
Other object tags are returned as type RawDERObject, with fields
holding the object class, tag, and raw content octets.
If partial_ok is True, this function returns a tuple of the decoded
value and number of bytes consumed. Otherwise, all data bytes must
be consumed and only the decoded value is returned.
"""
value, end = der_decode_partial(data)
if end < len(data):
raise ASN1DecodeError('Data contains unexpected bytes at end')
return value

View File

@@ -0,0 +1,992 @@
# Copyright (c) 2013-2022 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH authentication handlers"""
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional
from typing import Sequence, Tuple, Type, Union, cast
from .constants import DEFAULT_LANG
from .gss import GSSBase, GSSError
from .logging import SSHLogger
from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names
from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler
from .public_key import SigningKey
from .saslprep import saslprep, SASLPrepError
if TYPE_CHECKING:
import asyncio
# pylint: disable=cyclic-import
from .connection import SSHConnection, SSHClientConnection
from .connection import SSHServerConnection
KbdIntPrompts = Sequence[Tuple[str, bool]]
KbdIntNewChallenge = Tuple[str, str, str, KbdIntPrompts]
KbdIntChallenge = Union[bool, KbdIntNewChallenge]
KbdIntResponse = Sequence[str]
PasswordChangeResponse = Tuple[str, str]
# SSH message values for GSS auth
MSG_USERAUTH_GSSAPI_RESPONSE = 60
MSG_USERAUTH_GSSAPI_TOKEN = 61
MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = 63
MSG_USERAUTH_GSSAPI_ERROR = 64
MSG_USERAUTH_GSSAPI_ERRTOK = 65
MSG_USERAUTH_GSSAPI_MIC = 66
# SSH message values for public key auth
MSG_USERAUTH_PK_OK = 60
# SSH message values for keyboard-interactive auth
MSG_USERAUTH_INFO_REQUEST = 60
MSG_USERAUTH_INFO_RESPONSE = 61
# SSH message values for password auth
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
_auth_methods: List[bytes] = []
_client_auth_handlers: Dict[bytes, Type['ClientAuth']] = {}
_server_auth_handlers: Dict[bytes, Type['ServerAuth']] = {}
class Auth(SSHPacketHandler):
"""Parent class for authentication"""
def __init__(self, conn: 'SSHConnection', coro: Awaitable[None]):
self._conn = conn
self._logger = conn.logger
self._coro: Optional['asyncio.Task[None]'] = conn.create_task(coro)
def send_packet(self, pkttype: int, *args: bytes,
trivial: bool = True) -> None:
"""Send an auth packet"""
self._conn.send_userauth_packet(pkttype, *args, handler=self,
trivial=trivial)
@property
def logger(self) -> SSHLogger:
"""A logger associated with this authentication handler"""
return self._logger
def create_task(self, coro: Awaitable[None]) -> None:
"""Create an asynchronous auth task"""
self.cancel()
self._coro = self._conn.create_task(coro)
def cancel(self) -> None:
"""Cancel any authentication in progress"""
if self._coro: # pragma: no branch
self._coro.cancel()
self._coro = None
class ClientAuth(Auth):
"""Parent class for client authentication"""
_conn: 'SSHClientConnection'
def __init__(self, conn: 'SSHClientConnection', method: bytes):
self._method = method
super().__init__(conn, self._start())
async def _start(self) -> None:
"""Abstract method for starting client authentication"""
# Provided by subclass
raise NotImplementedError
def auth_succeeded(self) -> None:
"""Callback when auth succeeds"""
def auth_failed(self) -> None:
"""Callback when auth fails"""
async def send_request(self, *args: bytes,
key: Optional[SigningKey] = None,
trivial: bool = True) -> None:
"""Send a user authentication request"""
await self._conn.send_userauth_request(self._method, *args, key=key,
trivial=trivial)
class _ClientNullAuth(ClientAuth):
"""Client side implementation of null auth"""
async def _start(self) -> None:
"""Start client null authentication"""
await self.send_request()
class _ClientGSSKexAuth(ClientAuth):
"""Client side implementation of GSS key exchange auth"""
async def _start(self) -> None:
"""Start client GSS key exchange authentication"""
if self._conn.gss_kex_auth_requested():
self.logger.debug1('Trying GSS key exchange auth')
await self.send_request(key=self._conn.get_gss_context(),
trivial=False)
else:
self._conn.try_next_auth()
class _ClientGSSMICAuth(ClientAuth):
"""Client side implementation of GSS MIC auth"""
_handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_GSSAPI_')
def __init__(self, conn: 'SSHClientConnection', method: bytes):
super().__init__(conn, method)
self._gss: Optional[GSSBase] = None
self._got_error = False
async def _start(self) -> None:
"""Start client GSS MIC authentication"""
if self._conn.gss_mic_auth_requested():
self.logger.debug1('Trying GSS MIC auth')
self._gss = self._conn.get_gss_context()
self._gss.reset()
mechs = b''.join((String(mech) for mech in self._gss.mechs))
await self.send_request(UInt32(len(self._gss.mechs)), mechs)
else:
self._conn.try_next_auth()
def _finish(self) -> None:
"""Finish client GSS MIC authentication"""
assert self._gss is not None
if self._gss.provides_integrity:
data = self._conn.get_userauth_request_data(self._method)
self.send_packet(MSG_USERAUTH_GSSAPI_MIC,
String(self._gss.sign(data)),
trivial=False)
else:
self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)
def _process_response(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS response from the server"""
mech = packet.get_string()
packet.check_end()
assert self._gss is not None
if mech not in self._gss.mechs:
raise ProtocolError('Mechanism mismatch')
try:
token = self._gss.step()
assert token is not None
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
if self._gss.complete:
self._finish()
except GSSError as exc:
if exc.token:
self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token))
self._conn.try_next_auth()
def _process_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS token from the server"""
token: Optional[bytes] = packet.get_string()
packet.check_end()
assert self._gss is not None
try:
token = self._gss.step(token)
if token:
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
if self._gss.complete:
self._finish()
except GSSError as exc:
if exc.token:
self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token))
self._conn.try_next_auth()
def _process_error(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS error from the server"""
_ = packet.get_uint32() # major_status
_ = packet.get_uint32() # minor_status
msg = packet.get_string()
_ = packet.get_string() # lang
packet.check_end()
self.logger.debug1('GSS error from server: %s', msg)
self._got_error = True
def _process_error_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS error token from the server"""
token = packet.get_string()
packet.check_end()
assert self._gss is not None
try:
self._gss.step(token)
except GSSError as exc:
if not self._got_error: # pragma: no cover
self.logger.debug1('GSS error from server: %s', str(exc))
_packet_handlers = {
MSG_USERAUTH_GSSAPI_RESPONSE: _process_response,
MSG_USERAUTH_GSSAPI_TOKEN: _process_token,
MSG_USERAUTH_GSSAPI_ERROR: _process_error,
MSG_USERAUTH_GSSAPI_ERRTOK: _process_error_token
}
class _ClientHostBasedAuth(ClientAuth):
"""Client side implementation of host based auth"""
async def _start(self) -> None:
"""Start client host based authentication"""
keypair, client_host, client_username = \
await self._conn.host_based_auth_requested()
if keypair is None:
self._conn.try_next_auth()
return
self.logger.debug1('Trying host based auth of user %s on host %s '
'with %s host key', client_username, client_host,
keypair.algorithm)
try:
await self.send_request(String(keypair.algorithm),
String(keypair.public_data),
String(client_host),
String(client_username), key=keypair)
except ValueError as exc:
self.logger.debug1('Host based auth failed: %s', str(exc))
self._conn.try_next_auth()
class _ClientPublicKeyAuth(ClientAuth):
"""Client side implementation of public key auth"""
_handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_PK_')
async def _start(self) -> None:
"""Start client public key authentication"""
self._keypair = await self._conn.public_key_auth_requested()
if self._keypair is None:
self._conn.try_next_auth()
return
self.logger.debug1('Trying public key auth with %s key',
self._keypair.algorithm)
await self.send_request(Boolean(False),
String(self._keypair.algorithm),
String(self._keypair.public_data))
async def _send_signed_request(self) -> None:
"""Send signed public key request"""
assert self._keypair is not None
self.logger.debug1('Signing request with %s key',
self._keypair.algorithm)
await self.send_request(Boolean(True),
String(self._keypair.algorithm),
String(self._keypair.public_data),
key=self._keypair, trivial=False)
def _process_public_key_ok(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a public key ok response"""
algorithm = packet.get_string()
key_data = packet.get_string()
packet.check_end()
assert self._keypair is not None
if (algorithm != self._keypair.algorithm or
key_data != self._keypair.public_data):
raise ProtocolError('Key mismatch')
self.create_task(self._send_signed_request())
_packet_handlers = {
MSG_USERAUTH_PK_OK: _process_public_key_ok
}
class _ClientKbdIntAuth(ClientAuth):
"""Client side implementation of keyboard-interactive auth"""
_handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_INFO_')
async def _start(self) -> None:
"""Start client keyboard interactive authentication"""
submethods = await self._conn.kbdint_auth_requested()
if submethods is None:
self._conn.try_next_auth()
return
self.logger.debug1('Trying keyboard-interactive auth')
await self.send_request(String(''), String(submethods))
async def _receive_challenge(self, name: str, instruction: str, lang: str,
prompts: KbdIntPrompts) -> None:
"""Receive and respond to a keyboard interactive challenge"""
responses = \
await self._conn.kbdint_challenge_received(name, instruction,
lang, prompts)
if responses is None:
self._conn.try_next_auth()
return
self.send_packet(MSG_USERAUTH_INFO_RESPONSE, UInt32(len(responses)),
b''.join(String(r) for r in responses),
trivial=not responses)
def _process_info_request(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a keyboard interactive authentication request"""
name_bytes = packet.get_string()
instruction_bytes = packet.get_string()
lang_bytes = packet.get_string()
try:
name = name_bytes.decode('utf-8')
instruction = instruction_bytes.decode('utf-8')
lang = lang_bytes.decode('ascii')
except UnicodeDecodeError:
raise ProtocolError('Invalid keyboard interactive '
'info request') from None
num_prompts = packet.get_uint32()
prompts = []
for _ in range(num_prompts):
prompt_bytes = packet.get_string()
echo = packet.get_boolean()
try:
prompt = prompt_bytes.decode('utf-8')
except UnicodeDecodeError:
raise ProtocolError('Invalid keyboard interactive '
'info request') from None
prompts.append((prompt, echo))
self.create_task(self._receive_challenge(name, instruction,
lang, prompts))
_packet_handlers = {
MSG_USERAUTH_INFO_REQUEST: _process_info_request
}
class _ClientPasswordAuth(ClientAuth):
"""Client side implementation of password auth"""
_handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_PASSWD_')
def __init__(self, conn: 'SSHClientConnection', method: bytes):
super().__init__(conn, method)
self._password_change = False
async def _start(self) -> None:
"""Start client password authentication"""
password = await self._conn.password_auth_requested()
if password is None:
self._conn.try_next_auth()
return
self.logger.debug1('Trying password auth')
await self.send_request(Boolean(False), String(password),
trivial=False)
async def _change_password(self, prompt: str, lang: str) -> None:
"""Start password change"""
result = await self._conn.password_change_requested(prompt, lang)
if result == NotImplemented:
# Password change not supported - move on to the next auth method
self._conn.try_next_auth()
return
self.logger.debug1('Trying to chsnge password')
old_password, new_password = cast(PasswordChangeResponse, result)
self._password_change = True
await self.send_request(Boolean(True),
String(old_password.encode('utf-8')),
String(new_password.encode('utf-8')),
trivial=False)
def auth_succeeded(self) -> None:
if self._password_change:
self._password_change = False
self._conn.password_changed()
def auth_failed(self) -> None:
if self._password_change:
self._password_change = False
self._conn.password_change_failed()
def _process_password_change(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a password change request"""
prompt_bytes = packet.get_string()
lang_bytes = packet.get_string()
try:
prompt = prompt_bytes.decode('utf-8')
lang = lang_bytes.decode('ascii')
except UnicodeDecodeError:
raise ProtocolError('Invalid password change request') from None
self.auth_failed()
self.create_task(self._change_password(prompt, lang))
_packet_handlers = {
MSG_USERAUTH_PASSWD_CHANGEREQ: _process_password_change
}
class ServerAuth(Auth):
"""Parent class for server authentication"""
_conn: 'SSHServerConnection'
def __init__(self, conn: 'SSHServerConnection', username: str,
method: bytes, packet: SSHPacket):
self._username = username
self._method = method
super().__init__(conn, self._start(packet))
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether this authentication method is supported"""
raise NotImplementedError
async def _start(self, packet: SSHPacket) -> None:
"""Abstract method for starting server authentication"""
# Provided by subclass
raise NotImplementedError
def send_failure(self, partial_success: bool = False) -> None:
"""Send a user authentication failure response"""
self._conn.send_userauth_failure(partial_success)
def send_success(self) -> None:
"""Send a user authentication success response"""
self._conn.send_userauth_success()
class _ServerNullAuth(ServerAuth):
"""Server side implementation of null auth"""
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return that null authentication is never a supported auth mode"""
return False
async def _start(self, packet: SSHPacket) -> None:
"""Supported always returns false, so we never get here"""
class _ServerGSSKexAuth(ServerAuth):
"""Server side implementation of GSS key exchange auth"""
def __init__(self, conn: 'SSHServerConnection', username: str,
method: bytes, packet: SSHPacket):
super().__init__(conn, username, method, packet)
self._gss = conn.get_gss_context()
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether GSS key exchange authentication is supported"""
return conn.gss_kex_auth_supported()
async def _start(self, packet: SSHPacket) -> None:
"""Start server GSS key exchange authentication"""
mic = packet.get_string()
packet.check_end()
self.logger.debug1('Trying GSS key exchange auth')
data = self._conn.get_userauth_request_data(self._method)
if (self._gss.complete and self._gss.verify(data, mic) and
(await self._conn.validate_gss_principal(self._username,
self._gss.user,
self._gss.host))):
self.send_success()
else:
self.send_failure()
class _ServerGSSMICAuth(ServerAuth):
"""Server side implementation of GSS MIC auth"""
_handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_GSSAPI_')
def __init__(self, conn: 'SSHServerConnection', username: str,
method: bytes, packet: SSHPacket) -> None:
super().__init__(conn, username, method, packet)
self._gss = conn.get_gss_context()
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether GSS MIC authentication is supported"""
return conn.gss_mic_auth_supported()
async def _start(self, packet: SSHPacket) -> None:
"""Start server GSS MIC authentication"""
mechs = set()
n = packet.get_uint32()
for _ in range(n):
mechs.add(packet.get_string())
packet.check_end()
match = None
for mech in self._gss.mechs:
if mech in mechs:
match = mech
break
if not match:
self.send_failure()
return
self.logger.debug1('Trying GSS MIC auth')
self._gss.reset()
self.send_packet(MSG_USERAUTH_GSSAPI_RESPONSE, String(match))
async def _finish(self) -> None:
"""Finish server GSS MIC authentication"""
if (await self._conn.validate_gss_principal(self._username,
self._gss.user,
self._gss.host)):
self.send_success()
else:
self.send_failure()
def _process_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS token from the client"""
token: Optional[bytes] = packet.get_string()
packet.check_end()
try:
token = self._gss.step(token)
if token:
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
except GSSError as exc:
self.send_packet(MSG_USERAUTH_GSSAPI_ERROR, UInt32(exc.maj_code),
UInt32(exc.min_code), String(str(exc)),
String(DEFAULT_LANG))
if exc.token:
self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token))
self.send_failure()
def _process_exchange_complete(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS exchange complete message from the client"""
packet.check_end()
if self._gss.complete and not self._gss.provides_integrity:
self.create_task(self._finish())
else:
self.send_failure()
def _process_error_token(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS error token from the client"""
token = packet.get_string()
packet.check_end()
try:
self._gss.step(token)
except GSSError as exc:
self.logger.debug1('GSS error from client: %s', str(exc))
def _process_mic(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a GSS MIC from the client"""
mic = packet.get_string()
packet.check_end()
data = self._conn.get_userauth_request_data(self._method)
if (self._gss.complete and self._gss.provides_integrity and
self._gss.verify(data, mic)):
self.create_task(self._finish())
else:
self.send_failure()
_packet_handlers = {
MSG_USERAUTH_GSSAPI_TOKEN: _process_token,
MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: _process_exchange_complete,
MSG_USERAUTH_GSSAPI_ERRTOK: _process_error_token,
MSG_USERAUTH_GSSAPI_MIC: _process_mic
}
class _ServerHostBasedAuth(ServerAuth):
"""Server side implementation of host based auth"""
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether host based authentication is supported"""
return conn.host_based_auth_supported()
async def _start(self, packet: SSHPacket) -> None:
"""Start server host based authentication"""
algorithm = packet.get_string()
key_data = packet.get_string()
client_host_bytes = packet.get_string()
client_username_bytes = packet.get_string()
msg = packet.get_consumed_payload()
signature = packet.get_string()
packet.check_end()
try:
client_host = client_host_bytes.decode('utf-8')
client_username = saslprep(client_username_bytes.decode('utf-8'))
except (UnicodeDecodeError, SASLPrepError):
raise ProtocolError('Invalid host-based auth request') from None
self.logger.debug1('Verifying host based auth of user %s '
'on host %s with %s host key', client_username,
client_host, algorithm)
if (await self._conn.validate_host_based_auth(self._username,
key_data, client_host,
client_username,
msg, signature)):
self.send_success()
else:
self.send_failure()
class _ServerPublicKeyAuth(ServerAuth):
"""Server side implementation of public key auth"""
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether public key authentication is supported"""
return conn.public_key_auth_supported()
async def _start(self, packet: SSHPacket) -> None:
"""Start server public key authentication"""
sig_present = packet.get_boolean()
algorithm = packet.get_string()
key_data = packet.get_string()
if sig_present:
msg = packet.get_consumed_payload()
signature = packet.get_string()
else:
msg = b''
signature = b''
packet.check_end()
if sig_present:
self.logger.debug1('Verifying request with %s key', algorithm)
else:
self.logger.debug1('Trying public key auth with %s key', algorithm)
if (await self._conn.validate_public_key(self._username, key_data,
msg, signature)):
if sig_present:
self.send_success()
else:
self.send_packet(MSG_USERAUTH_PK_OK, String(algorithm),
String(key_data))
else:
self.send_failure()
class _ServerKbdIntAuth(ServerAuth):
"""Server side implementation of keyboard-interactive auth"""
_handler_names = get_symbol_names(globals(), 'MSG_USERAUTH_INFO_')
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether keyboard interactive authentication is supported"""
return conn.kbdint_auth_supported()
async def _start(self, packet: SSHPacket) -> None:
"""Start server keyboard interactive authentication"""
lang_bytes = packet.get_string()
submethods_bytes = packet.get_string()
packet.check_end()
try:
lang = lang_bytes.decode('ascii')
submethods = submethods_bytes.decode('utf-8')
except UnicodeDecodeError:
raise ProtocolError('Invalid keyboard interactive '
'auth request') from None
self.logger.debug1('Trying keyboard-interactive auth')
challenge = await self._conn.get_kbdint_challenge(self._username,
lang, submethods)
self._send_challenge(challenge)
def _send_challenge(self, challenge: KbdIntChallenge) -> None:
"""Send a keyboard interactive authentication request"""
if isinstance(challenge, (tuple, list)):
name, instruction, lang, prompts = challenge
num_prompts = len(prompts)
prompts_bytes = (String(prompt) + Boolean(echo)
for prompt, echo in prompts)
self.send_packet(MSG_USERAUTH_INFO_REQUEST, String(name),
String(instruction), String(lang),
UInt32(num_prompts), *prompts_bytes)
elif challenge:
self.send_success()
else:
self.send_failure()
async def _validate_response(self, responses: KbdIntResponse) -> None:
"""Validate a keyboard interactive authentication response"""
next_challenge = \
await self._conn.validate_kbdint_response(self._username, responses)
self._send_challenge(next_challenge)
def _process_info_response(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a keyboard interactive authentication response"""
num_responses = packet.get_uint32()
responses = []
for _ in range(num_responses):
response_bytes = packet.get_string()
try:
response = response_bytes.decode('utf-8')
except UnicodeDecodeError:
raise ProtocolError('Invalid keyboard interactive '
'info response') from None
responses.append(response)
packet.check_end()
self.create_task(self._validate_response(responses))
_packet_handlers = {
MSG_USERAUTH_INFO_RESPONSE: _process_info_response
}
class _ServerPasswordAuth(ServerAuth):
"""Server side implementation of password auth"""
@classmethod
def supported(cls, conn: 'SSHServerConnection') -> bool:
"""Return whether password authentication is supported"""
return conn.password_auth_supported()
async def _start(self, packet: SSHPacket) -> None:
"""Start server password authentication"""
password_change = packet.get_boolean()
password_bytes = packet.get_string()
new_password_bytes = packet.get_string() if password_change else b''
packet.check_end()
try:
password = saslprep(password_bytes.decode('utf-8'))
new_password = saslprep(new_password_bytes.decode('utf-8'))
except (UnicodeDecodeError, SASLPrepError):
raise ProtocolError('Invalid password auth request') from None
try:
if password_change:
self.logger.debug1('Trying to chsnge password')
result = await self._conn.change_password(self._username,
password,
new_password)
else:
self.logger.debug1('Trying password auth')
result = \
await self._conn.validate_password(self._username, password)
if result:
self.send_success()
else:
self.send_failure()
except PasswordChangeRequired as exc:
self.send_packet(MSG_USERAUTH_PASSWD_CHANGEREQ,
String(exc.prompt), String(exc.lang))
def register_auth_method(alg: bytes, client_handler: Type[ClientAuth],
server_handler: Type[ServerAuth]) -> None:
"""Register an authentication method"""
_auth_methods.append(alg)
_client_auth_handlers[alg] = client_handler
_server_auth_handlers[alg] = server_handler
def get_supported_client_auth_methods() -> Sequence[bytes]:
"""Return a list of supported client auth methods"""
return [method for method in _client_auth_handlers
if method != b'none']
def lookup_client_auth(conn: 'SSHClientConnection',
method: bytes) -> Optional[ClientAuth]:
"""Look up the client authentication method to use"""
if method in _auth_methods:
return _client_auth_handlers[method](conn, method)
else:
return None
def get_supported_server_auth_methods(conn: 'SSHServerConnection') -> \
Sequence[bytes]:
"""Return a list of supported server auth methods"""
auth_methods = []
for method in _auth_methods:
if _server_auth_handlers[method].supported(conn):
auth_methods.append(method)
return auth_methods
def lookup_server_auth(conn: 'SSHServerConnection', username: str,
method: bytes, packet: SSHPacket) -> \
Optional[ServerAuth]:
"""Look up the server authentication method to use"""
handler = _server_auth_handlers.get(method)
if handler and handler.supported(conn):
return handler(conn, username, method, packet)
else:
conn.send_userauth_failure(False)
return None
_auth_method_list = (
(b'none', _ClientNullAuth, _ServerNullAuth),
(b'gssapi-keyex', _ClientGSSKexAuth, _ServerGSSKexAuth),
(b'gssapi-with-mic', _ClientGSSMICAuth, _ServerGSSMICAuth),
(b'hostbased', _ClientHostBasedAuth, _ServerHostBasedAuth),
(b'publickey', _ClientPublicKeyAuth, _ServerPublicKeyAuth),
(b'keyboard-interactive', _ClientKbdIntAuth, _ServerKbdIntAuth),
(b'password', _ClientPasswordAuth, _ServerPasswordAuth)
)
for _args in _auth_method_list:
register_auth_method(*_args)

View File

@@ -0,0 +1,350 @@
# Copyright (c) 2015-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Parser for SSH authorized_keys files"""
from typing import Dict, List, Mapping, Optional, Sequence
from typing import Set, Tuple, Union, cast
try:
# pylint: disable=unused-import
from .crypto import X509Name, X509NamePattern
_x509_available = True
except ImportError: # pragma: no cover
_x509_available = False
from .misc import ip_address, read_file
from .pattern import HostPatternList, WildcardPatternList
from .public_key import KeyImportError, SSHKey
from .public_key import SSHX509Certificate, SSHX509CertificateChain
from .public_key import import_public_key, import_certificate
from .public_key import import_certificate_subject
_EntryOptions = Mapping[str, object]
class _SSHAuthorizedKeyEntry:
"""An entry in an SSH authorized_keys list"""
def __init__(self, line: str):
self.key: Optional[SSHKey] = None
self.cert: Optional[SSHX509Certificate] = None
self.options: Dict[str, object] = {}
try:
self._import_key_or_cert(line)
return
except KeyImportError:
pass
line = self._parse_options(line)
self._import_key_or_cert(line)
def _import_key_or_cert(self, line: str) -> None:
"""Import key or certificate in this entry"""
try:
self.key = import_public_key(line)
return
except KeyImportError:
pass
try:
self.cert = cast(SSHX509Certificate, import_certificate(line))
if ('cert-authority' in self.options and
self.cert.subject != self.cert.issuer):
raise ValueError('X.509 cert-authority entries must '
'contain a root CA certificate')
return
except KeyImportError:
pass
if 'cert-authority' not in self.options:
try:
self.key = None
self.cert = None
self._add_subject('subject', import_certificate_subject(line))
return
except KeyImportError:
pass
raise KeyImportError('Unrecognized key, certificate, or subject')
def _set_string(self, option: str, value: str) -> None:
"""Set an option with a string value"""
self.options[option] = value
def _add_environment(self, option: str, value: str) -> None:
"""Add an environment key/value pair"""
if value.startswith('=') or '=' not in value:
raise ValueError('Invalid environment entry in authorized_keys')
name, value = value.split('=', 1)
cast(Dict[str, str], self.options.setdefault(option, {}))[name] = value
def _add_from(self, option: str, value: str) -> None:
"""Add a from host pattern"""
from_patterns = cast(List[HostPatternList],
self.options.setdefault(option, []))
from_patterns.append(HostPatternList(value))
def _add_permitopen(self, option: str, value: str) -> None:
"""Add a permitopen host/port pair"""
try:
host, port_str = value.rsplit(':', 1)
if host.startswith('[') and host.endswith(']'):
host = host[1:-1]
port = None if port_str == '*' else int(port_str)
except:
raise ValueError('Illegal permitopen value: %s' % value) from None
permitted_opens = cast(Set[Tuple[str, Optional[int]]],
self.options.setdefault(option, set()))
permitted_opens.add((host, port))
def _add_principals(self, option: str, value: str) -> None:
"""Add a principals wildcard pattern list"""
principal_patterns = cast(List[WildcardPatternList],
self.options.setdefault(option, []))
principal_patterns.append(WildcardPatternList(value))
def _add_subject(self, option: str, value: str) -> None:
"""Add an X.509 subject pattern"""
if _x509_available: # pragma: no branch
subject_patterns = cast(List[X509NamePattern],
self.options.setdefault(option, []))
subject_patterns.append(X509NamePattern(value))
_handlers = {
'command': _set_string,
'environment': _add_environment,
'from': _add_from,
'permitopen': _add_permitopen,
'principals': _add_principals,
'subject': _add_subject
}
def _add_option(self) -> None:
"""Add an option value"""
if self._option.startswith('='):
raise ValueError('Missing option name in authorized_keys')
if '=' in self._option:
option, value = self._option.split('=', 1)
handler = self._handlers.get(option)
if handler:
handler(self, option, value)
else:
values = cast(List[str], self.options.setdefault(option, []))
values.append(value)
else:
self.options[self._option] = True
def _parse_options(self, line: str) -> str:
"""Parse options in this entry"""
self._option = ''
idx = 0
quoted = False
escaped = False
for idx, ch in enumerate(line):
if escaped:
self._option += ch
escaped = False
elif ch == '\\':
escaped = True
elif ch == '"':
quoted = not quoted
elif quoted:
self._option += ch
elif ch in ' \t':
break
elif ch == ',':
self._add_option()
self._option = ''
else:
self._option += ch
self._add_option()
if quoted:
raise ValueError('Unbalanced quote in authorized_keys')
elif escaped:
raise ValueError('Unbalanced backslash in authorized_keys')
return line[idx:].strip()
def match_options(self, client_host: str, client_addr: str,
cert_principals: Optional[Sequence[str]],
cert_subject: Optional['X509Name'] = None) -> bool:
"""Match "from", "principals" and "subject" options in entry"""
from_patterns = cast(List[HostPatternList], self.options.get('from'))
if from_patterns:
client_ip = ip_address(client_addr)
if not all(pattern.matches(client_host, client_addr, client_ip)
for pattern in from_patterns):
return False
principal_patterns = cast(List[WildcardPatternList],
self.options.get('principals'))
if cert_principals is not None and principal_patterns is not None:
if not all(any(pattern.matches(principal)
for principal in cert_principals)
for pattern in principal_patterns):
return False
subject_patterns = cast(List['X509NamePattern'],
self.options.get('subject'))
if cert_subject is not None and subject_patterns is not None:
if not all(pattern.matches(cert_subject)
for pattern in subject_patterns):
return False
return True
class SSHAuthorizedKeys:
"""An SSH authorized keys list"""
def __init__(self, authorized_keys: Optional[str] = None):
self._user_entries: List[_SSHAuthorizedKeyEntry] = []
self._ca_entries: List[_SSHAuthorizedKeyEntry] = []
self._x509_entries: List[_SSHAuthorizedKeyEntry] = []
if authorized_keys:
self.load(authorized_keys)
def load(self, authorized_keys: str) -> None:
"""Load authorized keys data into this object"""
for line in authorized_keys.splitlines():
line = line.strip()
if not line or line.startswith('#'):
continue
try:
entry = _SSHAuthorizedKeyEntry(line)
except KeyImportError:
continue
if entry.key:
if 'cert-authority' in entry.options:
self._ca_entries.append(entry)
else:
self._user_entries.append(entry)
else:
self._x509_entries.append(entry)
if (not self._user_entries and not self._ca_entries and
not self._x509_entries):
raise ValueError('No valid entries found')
def validate(self, key: SSHKey, client_host: str, client_addr: str,
cert_principals: Optional[Sequence[str]] = None,
ca: bool = False) -> Optional[Mapping[str, object]]:
"""Return whether a public key or CA is valid for authentication"""
for entry in self._ca_entries if ca else self._user_entries:
if (entry.key == key and
entry.match_options(client_host, client_addr,
cert_principals)):
return entry.options
return None
def validate_x509(self, cert: SSHX509CertificateChain, client_host: str,
client_addr: str) -> Tuple[Optional[_EntryOptions],
Optional[SSHX509Certificate]]:
"""Return whether an X.509 certificate is valid for authentication"""
for entry in self._x509_entries:
if (entry.cert and 'cert-authority' not in entry.options and
(cert.key != entry.cert.key or
cert.subject != entry.cert.subject)):
continue # pragma: no cover (work around bug in coverage tool)
if entry.match_options(client_host, client_addr,
cert.user_principals, cert.subject):
return entry.options, entry.cert
return None, None
def import_authorized_keys(data: str) -> SSHAuthorizedKeys:
"""Import SSH authorized keys
This function imports public keys and associated options in
OpenSSH authorized keys format.
:param data:
The key data to import.
:type data: `str`
:returns: An :class:`SSHAuthorizedKeys` object
"""
return SSHAuthorizedKeys(data)
def read_authorized_keys(filelist: Union[str, Sequence[str]]) -> \
SSHAuthorizedKeys:
"""Read SSH authorized keys from a file or list of files
This function reads public keys and associated options in
OpenSSH authorized_keys format from a file or list of files.
:param filelist:
The file or list of files to read the keys from.
:type filenlist: `str` or `list` of `str`
:returns: An :class:`SSHAuthorizedKeys` object
"""
authorized_keys = SSHAuthorizedKeys()
if isinstance(filelist, str):
files: Sequence[str] = [filelist]
else:
files = filelist
for filename in files:
authorized_keys.load(read_file(filename, 'r'))
return authorized_keys

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,409 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH client protocol handler"""
from typing import TYPE_CHECKING, Optional
from .auth import KbdIntPrompts, KbdIntResponse, PasswordChangeResponse
from .misc import MaybeAwait
from .public_key import KeyPairListArg, SSHKey
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .connection import SSHClientConnection
class SSHClient:
"""SSH client protocol handler
Applications may subclass this when implementing an SSH client
to receive callbacks when certain events occur on the SSH
connection.
For simple password or public key based authentication, nothing
needs to be defined here if the password or client keys are passed
in when the connection is created. However, to prompt interactively
or otherwise dynamically select these values, the methods
:meth:`password_auth_requested` and/or :meth:`public_key_auth_requested`
can be defined. Keyboard-interactive authentication is also supported
via :meth:`kbdint_auth_requested` and :meth:`kbdint_challenge_received`.
If the server sends an authentication banner, the method
:meth:`auth_banner_received` will be called.
If the server requires a password change, the method
:meth:`password_change_requested` will be called, followed by either
:meth:`password_changed` or :meth:`password_change_failed` depending
on whether the password change is successful.
.. note:: The authentication callbacks described here can be
defined as coroutines. However, they may be cancelled if
they are running when the SSH connection is closed by
the server. If they attempt to catch the CancelledError
exception to perform cleanup, they should make sure to
re-raise it to allow AsyncSSH to finish its own cleanup.
"""
# pylint: disable=no-self-use,unused-argument
def connection_made(self, conn: 'SSHClientConnection') -> None:
"""Called when a connection is made
This method is called as soon as the TCP connection completes.
The `conn` parameter should be stored if needed for later use.
:param conn:
The connection which was successfully opened
:type conn: :class:`SSHClientConnection`
"""
def connection_lost(self, exc: Optional[Exception]) -> None:
"""Called when a connection is lost or closed
This method is called when a connection is closed. If the
connection is shut down cleanly, *exc* will be `None`.
Otherwise, it will be an exception explaining the reason for
the disconnect.
:param exc:
The exception which caused the connection to close, or
`None` if the connection closed cleanly
:type exc: :class:`Exception`
"""
def debug_msg_received(self, msg: str, lang: str,
always_display: bool) -> None:
"""A debug message was received on this connection
This method is called when the other end of the connection sends
a debug message. Applications should implement this method if
they wish to process these debug messages.
:param msg:
The debug message sent
:param lang:
The language the message is in
:param always_display:
Whether or not to display the message
:type msg: `str`
:type lang: `str`
:type always_display: `bool`
"""
def validate_host_public_key(self, host: str, addr: str,
port: int, key: SSHKey) -> bool:
"""Return whether key is an authorized key for this host
Server host key validation can be supported by passing known
host keys in the `known_hosts` argument of
:func:`create_connection`. However, for more flexibility
in matching on the allowed set of keys, this method can be
implemented by the application to do the matching itself. It
should return `True` if the specified key is a valid host key
for the server being connected to.
By default, this method returns `False` for all host keys.
.. note:: This function only needs to report whether the
public key provided is a valid key for this
host. If it is, AsyncSSH will verify that the
server possesses the corresponding private key
before allowing the validation to succeed.
:param host:
The hostname of the target host
:param addr:
The IP address of the target host
:param port:
The port number on the target host
:param key:
The public key sent by the server
:type host: `str`
:type addr: `str`
:type port: `int`
:type key: :class:`SSHKey` *public key*
:returns: A `bool` indicating if the specified key is a valid
key for the target host
"""
return False # pragma: no cover
def validate_host_ca_key(self, host: str, addr: str,
port: int, key: SSHKey) -> bool:
"""Return whether key is an authorized CA key for this host
Server host certificate validation can be supported by passing
known host CA keys in the `known_hosts` argument of
:func:`create_connection`. However, for more flexibility
in matching on the allowed set of keys, this method can be
implemented by the application to do the matching itself. It
should return `True` if the specified key is a valid certificate
authority key for the server being connected to.
By default, this method returns `False` for all CA keys.
.. note:: This function only needs to report whether the
public key provided is a valid CA key for this
host. If it is, AsyncSSH will verify that the
certificate is valid, that the host is one of
the valid principals for the certificate, and
that the server possesses the private key
corresponding to the public key in the certificate
before allowing the validation to succeed.
:param host:
The hostname of the target host
:param addr:
The IP address of the target host
:param port:
The port number on the target host
:param key:
The public key which signed the certificate sent by the server
:type host: `str`
:type addr: `str`
:type port: `int`
:type key: :class:`SSHKey` *public key*
:returns: A `bool` indicating if the specified key is a valid
CA key for the target host
"""
return False # pragma: no cover
def auth_banner_received(self, msg: str, lang: str) -> None:
"""An incoming authentication banner was received
This method is called when the server sends a banner to display
during authentication. Applications should implement this method
if they wish to do something with the banner.
:param msg:
The message the server wanted to display
:param lang:
The language the message is in
:type msg: `str`
:type lang: `str`
"""
def auth_completed(self) -> None:
"""Authentication was completed successfully
This method is called when authentication has completed
successfully. Applications may use this method to create
whatever client sessions and direct TCP/IP or UNIX domain
connections are needed and/or set up listeners for incoming
TCP/IP or UNIX domain connections coming from the server.
However, :func:`create_connection` now blocks until
authentication is complete, so any code which wishes to
use the SSH connection can simply follow that call and
doesn't need to be performed in a callback.
"""
def public_key_auth_requested(self) -> \
MaybeAwait[Optional[KeyPairListArg]]:
"""Public key authentication has been requested
This method should return a private key corresponding to
the user that authentication is being attempted for.
This method may be called multiple times and can return a
different key to try each time it is called. When there are
no keys left to try, it should return `None` to indicate
that some other authentication method should be tried.
If client keys were provided when the connection was opened,
they will be tried before this method is called.
If blocking operations need to be performed to determine the
key to authenticate with, this method may be defined as a
coroutine.
:returns: A key as described in :ref:`SpecifyingPrivateKeys`
or `None` to move on to another authentication
method
"""
return None # pragma: no cover
def password_auth_requested(self) -> MaybeAwait[Optional[str]]:
"""Password authentication has been requested
This method should return a string containing the password
corresponding to the user that authentication is being
attempted for. It may be called multiple times and can
return a different password to try each time, but most
servers have a limit on the number of attempts allowed.
When there's no password left to try, this method should
return `None` to indicate that some other authentication
method should be tried.
If a password was provided when the connection was opened,
it will be tried before this method is called.
If blocking operations need to be performed to determine the
password to authenticate with, this method may be defined as
a coroutine.
:returns: A string containing the password to authenticate
with or `None` to move on to another authentication
method
"""
return None # pragma: no cover
def password_change_requested(self, prompt: str, lang: str) -> \
MaybeAwait[PasswordChangeResponse]:
"""A password change has been requested
This method is called when password authentication was
attempted and the user's password was expired on the
server. To request a password change, this method should
return a tuple or two strings containing the old and new
passwords. Otherwise, it should return `NotImplemented`.
If blocking operations need to be performed to determine the
passwords to authenticate with, this method may be defined
as a coroutine.
By default, this method returns `NotImplemented`.
:param prompt:
The prompt requesting that the user enter a new password
:param lang:
The language that the prompt is in
:type prompt: `str`
:type lang: `str`
:returns: A tuple of two strings containing the old and new
passwords or `NotImplemented` if password changes
aren't supported
"""
return NotImplemented # pragma: no cover
def password_changed(self) -> None:
"""The requested password change was successful
This method is called to indicate that a requested password
change was successful. It is generally followed by a call to
:meth:`auth_completed` since this means authentication was
also successful.
"""
def password_change_failed(self) -> None:
"""The requested password change has failed
This method is called to indicate that a requested password
change failed, generally because the requested new password
doesn't meet the password criteria on the remote system.
After this method is called, other forms of authentication
will automatically be attempted.
"""
def kbdint_auth_requested(self) -> MaybeAwait[Optional[str]]:
"""Keyboard-interactive authentication has been requested
This method should return a string containing a comma-separated
list of submethods that the server should use for
keyboard-interactive authentication. An empty string can be
returned to let the server pick the type of keyboard-interactive
authentication to perform. If keyboard-interactive authentication
is not supported, `None` should be returned.
By default, keyboard-interactive authentication is supported
if a password was provided when the :class:`SSHClient` was
created and it hasn't been sent yet. If the challenge is not
a password challenge, this authentication will fail. This
method and the :meth:`kbdint_challenge_received` method can be
overridden if other forms of challenge should be supported.
If blocking operations need to be performed to determine the
submethods to request, this method may be defined as a
coroutine.
:returns: A string containing the submethods the server should
use for authentication or `None` to move on to
another authentication method
"""
return NotImplemented # pragma: no cover
def kbdint_challenge_received(self, name: str, instructions: str,
lang: str, prompts: KbdIntPrompts) -> \
MaybeAwait[Optional[KbdIntResponse]]:
"""A keyboard-interactive auth challenge has been received
This method is called when the server sends a keyboard-interactive
authentication challenge.
The return value should be a list of strings of the same length
as the number of prompts provided if the challenge can be
answered, or `None` to indicate that some other form of
authentication should be attempted.
If blocking operations need to be performed to determine the
responses to authenticate with, this method may be defined
as a coroutine.
By default, this method will look for a challenge consisting
of a single 'Password:' prompt, and call the method
:meth:`password_auth_requested` to provide the response.
It will also ignore challenges with no prompts (generally used
to provide instructions). Any other form of challenge will
cause this method to return `None` to move on to another
authentication method.
:param name:
The name of the challenge
:param instructions:
Instructions to the user about how to respond to the challenge
:param lang:
The language the challenge is in
:param prompts:
The challenges the user should respond to and whether or
not the responses should be echoed when they are entered
:type name: `str`
:type instructions: `str`
:type lang: `str`
:type prompts: `list` of tuples of `str` and `bool`
:returns: List of string responses to the challenge or `None`
to move on to another authentication method
"""
return None # pragma: no cover

View File

@@ -0,0 +1,157 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH compression handlers"""
from typing import Callable, List, Optional
import zlib
_cmp_algs: List[bytes] = []
_default_cmp_algs: List[bytes] = []
_cmp_params = {}
_cmp_compressors = {}
_cmp_decompressors = {}
class Compressor:
"""Base class for data compressor"""
def compress(self, data: bytes) -> Optional[bytes]:
"""Compress data"""
raise NotImplementedError
class Decompressor:
"""Base class for data decompressor"""
def decompress(self, data: bytes) -> Optional[bytes]:
"""Decompress data"""
raise NotImplementedError
_CompressorType = Callable[[], Optional[Compressor]]
_DecompressorType = Callable[[], Optional[Decompressor]]
def _none() -> None:
"""Compressor/decompressor for no compression"""
return None
class _ZLibCompress(Compressor):
"""Wrapper class to force a sync flush and handle exceptions"""
def __init__(self) -> None:
self._comp = zlib.compressobj()
def compress(self, data: bytes) -> Optional[bytes]:
"""Compress data using zlib compression with sync flush"""
try:
return self._comp.compress(data) + \
self._comp.flush(zlib.Z_SYNC_FLUSH)
except zlib.error: # pragma: no cover
return None
class _ZLibDecompress(Decompressor):
"""Wrapper class to handle exceptions"""
def __init__(self) -> None:
self._decomp = zlib.decompressobj()
def decompress(self, data: bytes) -> Optional[bytes]:
"""Decompress data using zlib compression"""
try:
return self._decomp.decompress(data)
except zlib.error: # pragma: no cover
return None
def register_compression_alg(alg: bytes, compressor: _CompressorType,
decompressor: _DecompressorType,
after_auth: bool, default: bool) -> None:
"""Register a compression algorithm"""
_cmp_algs.append(alg)
if default:
_default_cmp_algs.append(alg)
_cmp_params[alg] = after_auth
_cmp_compressors[alg] = compressor
_cmp_decompressors[alg] = decompressor
def get_compression_algs() -> List[bytes]:
"""Return supported compression algorithms"""
return _cmp_algs
def get_default_compression_algs() -> List[bytes]:
"""Return default compression algorithms"""
return _default_cmp_algs
def get_compression_params(alg: bytes) -> bool:
"""Get parameters of a compression algorithm
This function returns whether or not a compression algorithm should
be delayed until after authentication completes.
"""
return _cmp_params[alg]
def get_compressor(alg: bytes) -> Optional[Compressor]:
"""Return an instance of a compressor
This function returns an object that can be used for data compression.
"""
return _cmp_compressors[alg]()
def get_decompressor(alg: bytes) -> Optional[Decompressor]:
"""Return an instance of a decompressor
This function returns an object that can be used for data decompression.
"""
return _cmp_decompressors[alg]()
register_compression_alg(b'zlib@openssh.com',
_ZLibCompress, _ZLibDecompress, True, True)
register_compression_alg(b'zlib',
_ZLibCompress, _ZLibDecompress, False, False)
register_compression_alg(b'none',
_none, _none, False, True)

View File

@@ -0,0 +1,635 @@
# Copyright (c) 2020-2022 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Parser for OpenSSH config files"""
import os
import re
import shlex
import socket
import subprocess
from hashlib import sha1
from pathlib import Path, PurePath
from subprocess import DEVNULL
from typing import Callable, Dict, List, NoReturn, Optional, Sequence
from typing import Set, Tuple, Union, cast
from .constants import DEFAULT_PORT
from .logging import logger
from .misc import DefTuple, FilePath, ip_address
from .pattern import HostPatternList, WildcardPatternList
ConfigPaths = Union[None, FilePath, Sequence[FilePath]]
def _exec(cmd: str) -> bool:
"""Execute a command and return if exit status is 0"""
return subprocess.run(cmd, check=False, shell=True, stdin=DEVNULL,
stdout=DEVNULL, stderr=DEVNULL).returncode == 0
class ConfigParseError(ValueError):
"""Configuration parsing exception"""
class SSHConfig:
"""Settings from an OpenSSH config file"""
_conditionals = {'match'}
_no_split: Set[str] = set()
_percent_expand = {'AuthorizedKeysFile'}
_handlers: Dict[str, Tuple[str, Callable]] = {}
def __init__(self, last_config: Optional['SSHConfig'], reload: bool):
if last_config:
self._last_options = last_config.get_options(reload)
else:
self._last_options = {}
self._default_path = Path('~', '.ssh').expanduser()
self._path = Path()
self._line_no = 0
self._matching = True
self._options = self._last_options.copy()
self._tokens: Dict[str, str] = {}
self.loaded = False
def _error(self, reason: str, *args: object) -> NoReturn:
"""Raise a configuration parsing error"""
raise ConfigParseError('%s line %s: %s' % (self._path, self._line_no,
reason % args))
def _match_val(self, match: str) -> object:
"""Return the value to match against in a match condition"""
raise NotImplementedError
def _set_tokens(self) -> None:
"""Set the tokens available for percent expansion"""
raise NotImplementedError
def _expand_val(self, value: str) -> str:
"""Perform percent token expansion on a string"""
last_idx = 0
result: List[str] = []
for match in re.finditer(r'%', value):
idx = match.start()
if idx < last_idx:
continue
try:
token = value[idx+1]
result.extend([value[last_idx:idx], self._tokens[token]])
last_idx = idx + 2
except IndexError:
raise ConfigParseError('Invalid token substitution') from None
except KeyError:
if token == 'd':
raise ConfigParseError('Home directory is '
'not available') from None
elif token == 'i':
raise ConfigParseError('User id not available') from None
else:
raise ConfigParseError('Invalid token substitution: %s' %
value[idx+1]) from None
result.append(value[last_idx:])
return ''.join(result)
def _include(self, option: str, args: List[str]) -> None:
"""Read config from a list of other config files"""
# pylint: disable=unused-argument
old_path = self._path
for pattern in args:
path = Path(pattern).expanduser()
if path.anchor:
pattern = str(Path(*path.parts[1:]))
path = Path(path.anchor)
else:
path = self._default_path
paths = list(path.glob(pattern))
if not paths:
logger.debug1('Config pattern "%s" matched no files', pattern)
for path in paths:
self.parse(path)
self._path = old_path
args.clear()
def _match(self, option: str, args: List[str]) -> None:
"""Begin a conditional block"""
# pylint: disable=unused-argument
while args:
match = args.pop(0).lower()
if match == 'all':
self._matching = True
continue
match_val = self._match_val(match)
if match != 'exec' and match_val is None:
self._error('Invalid match condition')
try:
if match == 'exec':
self._matching = _exec(args.pop(0))
elif match in ('address', 'localaddress'):
host_pat = HostPatternList(args.pop(0))
ip = ip_address(cast(str, match_val)) \
if match_val else None
self._matching = host_pat.matches(None, match_val, ip)
else:
wild_pat = WildcardPatternList(args.pop(0))
self._matching = wild_pat.matches(match_val)
except IndexError:
self._error('Missing %s match pattern', match)
if not self._matching:
args.clear()
break
def _set_bool(self, option: str, args: List[str]) -> None:
"""Set a boolean config option"""
value_str = args.pop(0).lower()
if value_str in ('yes', 'true'):
value = True
elif value_str in ('no', 'false'):
value = False
else:
self._error('Invalid %s boolean value: %s', option, value_str)
if option not in self._options:
self._options[option] = value
def _set_int(self, option: str, args: List[str]) -> None:
"""Set an integer config option"""
value_str = args.pop(0)
try:
value = int(value_str)
except ValueError:
self._error('Invalid %s integer value: %s', option, value_str)
if option not in self._options:
self._options[option] = value
def _set_string(self, option: str, args: List[str]) -> None:
"""Set a string config option"""
value_str = args.pop(0)
if value_str.lower() == 'none':
value = None
else:
value = value_str
if option not in self._options:
self._options[option] = value
def _append_string(self, option: str, args: List[str]) -> None:
"""Append a string config option to a list"""
value_str = args.pop(0)
if value_str.lower() != 'none':
if option in self._options:
cast(List[str], self._options[option]).append(value_str)
else:
self._options[option] = [value_str]
else:
if option not in self._options:
self._options[option] = []
def _set_string_list(self, option: str, args: List[str]) -> None:
"""Set whitespace-separated string config options as a list"""
if option not in self._options:
if len(args) == 1 and args[0].lower() == 'none':
self._options[option] = []
else:
self._options[option] = args[:]
args.clear()
def _append_string_list(self, option: str, args: List[str]) -> None:
"""Append whitespace-separated string config options to a list"""
if option in self._options:
cast(List[str], self._options[option]).extend(args)
else:
self._options[option] = args[:]
args.clear()
def _set_address_family(self, option: str, args: List[str]) -> None:
"""Set an address family config option"""
value_str = args.pop(0).lower()
if value_str == 'any':
value = socket.AF_UNSPEC
elif value_str == 'inet':
value = socket.AF_INET
elif value_str == 'inet6':
value = socket.AF_INET6
else:
self._error('Invalid %s value: %s', option, value_str)
if option not in self._options:
self._options[option] = value
def _set_rekey_limits(self, option: str, args: List[str]) -> None:
"""Set rekey limits config option"""
byte_limit: Union[str, Tuple[()]] = args.pop(0).lower()
if byte_limit == 'default':
byte_limit = ()
if args:
time_limit: Optional[Union[str, Tuple[()]]] = args.pop(0).lower()
if time_limit == 'none':
time_limit = None
else:
time_limit = ()
if option not in self._options:
self._options[option] = byte_limit, time_limit
def parse(self, path: Path) -> None:
"""Parse an OpenSSH config file and return matching declarations"""
self._path = path
self._line_no = 0
self._matching = True
self._tokens = {'%': '%'}
logger.debug1('Reading config from "%s"', path)
with open(path) as file:
for line in file:
self._line_no += 1
line = line.strip()
if not line or line[0] == '#':
continue
try:
split_args = shlex.split(line)
except ValueError as exc:
self._error(str(exc))
args = []
for arg in split_args:
if arg.startswith('='):
if len(arg) > 1:
args.append(arg[1:])
elif arg.endswith('='):
args.append(arg[:-1])
elif '=' in arg:
arg, val = arg.split('=', 1)
args.append(arg)
args.append(val)
else:
args.append(arg)
option = args.pop(0)
loption = option.lower()
if loption in self._no_split:
args = [line.lstrip()[len(loption):].strip()]
if not self._matching and loption not in self._conditionals:
continue
try:
option, handler = self._handlers[loption]
except KeyError:
continue
if not args:
self._error('Missing %s value', option)
handler(self, option, args)
if args:
self._error('Extra data at end: %s', ' '.join(args))
self._set_tokens()
for option in self._percent_expand:
try:
value = self._options[option]
except KeyError:
pass
else:
if isinstance(value, list):
value = [self._expand_val(item) for item in value]
elif isinstance(value, str):
value = self._expand_val(value)
self._options[option] = value
def get_options(self, reload: bool) -> Dict[str, object]:
"""Return options to base a new config object on"""
return self._last_options.copy() if reload else self._options.copy()
@classmethod
def load(cls, last_config: Optional['SSHConfig'],
config_paths: ConfigPaths, reload: bool,
*args: object) -> 'SSHConfig':
"""Load a list of OpenSSH config files into a config object"""
config = cls(last_config, reload, *args)
if config_paths:
if isinstance(config_paths, (str, PurePath)):
paths: Sequence[FilePath] = [config_paths]
else:
paths = config_paths
for path in paths:
config.parse(Path(path))
config.loaded = True
return config
def get(self, option: str, default: object = None) -> object:
"""Get the value of a config option"""
return self._options.get(option, default)
def get_compression_algs(self) -> DefTuple[str]:
"""Return the compression algorithms to use"""
compression = self.get('Compression')
if compression is None:
return ()
elif compression:
return 'zlib@openssh.com,zlib,none'
else:
return 'none,zlib@openssh.com,zlib'
class SSHClientConfig(SSHConfig):
"""Settings from an OpenSSH client config file"""
_conditionals = {'host', 'match'}
_no_split = {'remotecommand'}
_percent_expand = {'CertificateFile', 'IdentityAgent',
'IdentityFile', 'ProxyCommand', 'RemoteCommand'}
def __init__(self, last_config: 'SSHConfig', reload: bool,
local_user: str, user: str, host: str, port: int) -> None:
super().__init__(last_config, reload)
self._local_user = local_user
self._orig_host = host
if user != ():
self._options['User'] = user
if port != ():
self._options['Port'] = port
def _match_val(self, match: str) -> object:
"""Return the value to match against in a match condition"""
if match == 'host':
return self._options.get('Hostname', self._orig_host)
elif match == 'originalhost':
return self._orig_host
elif match == 'localuser':
return self._local_user
elif match == 'user':
return self._options.get('User', self._local_user)
else:
return None
def _match_host(self, option: str, args: List[str]) -> None:
"""Begin a conditional block matching on host"""
# pylint: disable=unused-argument
pattern = ','.join(args)
self._matching = WildcardPatternList(pattern).matches(self._orig_host)
args.clear()
def _set_hostname(self, option: str, args: List[str]) -> None:
"""Set hostname config option"""
value = args.pop(0)
if option not in self._options:
self._tokens['h'] = \
cast(str, self._options.get(option, self._orig_host))
self._options[option] = self._expand_val(value)
def _set_request_tty(self, option: str, args: List[str]) -> None:
"""Set a pseudo-terminal request config option"""
value_str = args.pop(0).lower()
if value_str in ('yes', 'true'):
value: Union[bool, str] = True
elif value_str in ('no', 'false'):
value = False
elif value_str not in ('force', 'auto'):
self._error('Invalid %s value: %s', option, value_str)
else:
value = value_str
if option not in self._options:
self._options[option] = value
def _set_tokens(self) -> None:
"""Set the tokens available for percent expansion"""
local_host = socket.gethostname()
idx = local_host.find('.')
short_local_host = local_host if idx < 0 else local_host[:idx]
host = cast(str, self._options.get('Hostname', self._orig_host))
port = str(self._options.get('Port', DEFAULT_PORT))
user = cast(str, self._options.get('User') or self._local_user)
home = os.path.expanduser('~')
conn_info = ''.join((local_host, host, port, user))
conn_hash = sha1(conn_info.encode('utf-8')).hexdigest()
self._tokens.update({'C': conn_hash,
'h': host,
'L': short_local_host,
'l': local_host,
'n': self._orig_host,
'p': port,
'r': user,
'u': self._local_user})
if home != '~':
self._tokens['d'] = home
if hasattr(os, 'getuid'):
self._tokens['i'] = str(os.getuid())
_handlers = {option.lower(): (option, handler) for option, handler in (
('Host', _match_host),
('Match', SSHConfig._match),
('Include', SSHConfig._include),
('AddressFamily', SSHConfig._set_address_family),
('BindAddress', SSHConfig._set_string),
('CASignatureAlgorithms', SSHConfig._set_string),
('CertificateFile', SSHConfig._append_string),
('ChallengeResponseAuthentication', SSHConfig._set_bool),
('Ciphers', SSHConfig._set_string),
('Compression', SSHConfig._set_bool),
('ConnectTimeout', SSHConfig._set_int),
('EnableSSHKeySign', SSHConfig._set_bool),
('ForwardAgent', SSHConfig._set_bool),
('ForwardX11Trusted', SSHConfig._set_bool),
('GlobalKnownHostsFile', SSHConfig._set_string_list),
('GSSAPIAuthentication', SSHConfig._set_bool),
('GSSAPIDelegateCredentials', SSHConfig._set_bool),
('GSSAPIKeyExchange', SSHConfig._set_bool),
('HostbasedAuthentication', SSHConfig._set_bool),
('HostKeyAlgorithms', SSHConfig._set_string),
('Hostname', _set_hostname),
('HostKeyAlias', SSHConfig._set_string),
('IdentitiesOnly', SSHConfig._set_bool),
('IdentityAgent', SSHConfig._set_string),
('IdentityFile', SSHConfig._append_string),
('KbdInteractiveAuthentication', SSHConfig._set_bool),
('KexAlgorithms', SSHConfig._set_string),
('MACs', SSHConfig._set_string),
('PasswordAuthentication', SSHConfig._set_bool),
('PKCS11Provider', SSHConfig._set_string),
('PreferredAuthentications', SSHConfig._set_string),
('Port', SSHConfig._set_int),
('ProxyCommand', SSHConfig._set_string_list),
('ProxyJump', SSHConfig._set_string),
('PubkeyAuthentication', SSHConfig._set_bool),
('RekeyLimit', SSHConfig._set_rekey_limits),
('RemoteCommand', SSHConfig._set_string),
('RequestTTY', _set_request_tty),
('SendEnv', SSHConfig._append_string_list),
('ServerAliveCountMax', SSHConfig._set_int),
('ServerAliveInterval', SSHConfig._set_int),
('SetEnv', SSHConfig._append_string_list),
('TCPKeepAlive', SSHConfig._set_bool),
('User', SSHConfig._set_string),
('UserKnownHostsFile', SSHConfig._set_string_list)
)}
class SSHServerConfig(SSHConfig):
"""Settings from an OpenSSH server config file"""
def __init__(self, last_config: 'SSHConfig', reload: bool,
local_addr: str, local_port: int, user: str,
host: str, addr: str) -> None:
super().__init__(last_config, reload)
self._local_addr = local_addr
self._local_port = local_port
self._user = user
self._host = host or addr
self._addr = addr
def _match_val(self, match: str) -> object:
"""Return the value to match against in a match condition"""
if match == 'localaddress':
return self._local_addr
elif match == 'localport':
return str(self._local_port)
elif match == 'user':
return self._user
elif match == 'host':
return self._host
elif match == 'address':
return self._addr
else:
return None
def _set_tokens(self) -> None:
"""Set the tokens available for percent expansion"""
self._tokens.update({'u': self._user})
_handlers = {option.lower(): (option, handler) for option, handler in (
('Match', SSHConfig._match),
('Include', SSHConfig._include),
('AddressFamily', SSHConfig._set_address_family),
('AuthorizedKeysFile', SSHConfig._set_string_list),
('AllowAgentForwarding', SSHConfig._set_bool),
('BindAddress', SSHConfig._set_string),
('CASignatureAlgorithms', SSHConfig._set_string),
('ChallengeResponseAuthentication', SSHConfig._set_bool),
('Ciphers', SSHConfig._set_string),
('ClientAliveCountMax', SSHConfig._set_int),
('ClientAliveInterval', SSHConfig._set_int),
('Compression', SSHConfig._set_bool),
('GSSAPIAuthentication', SSHConfig._set_bool),
('GSSAPIKeyExchange', SSHConfig._set_bool),
('HostbasedAuthentication', SSHConfig._set_bool),
('HostCertificate', SSHConfig._append_string),
('HostKey', SSHConfig._append_string),
('KbdInteractiveAuthentication', SSHConfig._set_bool),
('KexAlgorithms', SSHConfig._set_string),
('LoginGraceTime', SSHConfig._set_int),
('MACs', SSHConfig._set_string),
('PasswordAuthentication', SSHConfig._set_bool),
('PermitTTY', SSHConfig._set_bool),
('Port', SSHConfig._set_int),
('PubkeyAuthentication', SSHConfig._set_bool),
('RekeyLimit', SSHConfig._set_rekey_limits),
('TCPKeepAlive', SSHConfig._set_bool),
('UseDNS', SSHConfig._set_bool)
)}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,363 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH constants"""
# Default language for error messages
DEFAULT_LANG = 'en-US'
# Default SSH listening port
DEFAULT_PORT = 22
# SSH message codes
MSG_DISCONNECT = 1
MSG_IGNORE = 2
MSG_UNIMPLEMENTED = 3
MSG_DEBUG = 4
MSG_SERVICE_REQUEST = 5
MSG_SERVICE_ACCEPT = 6
MSG_EXT_INFO = 7
MSG_KEXINIT = 20
MSG_NEWKEYS = 21
MSG_KEX_FIRST = 30
MSG_KEX_LAST = 49
MSG_USERAUTH_REQUEST = 50
MSG_USERAUTH_FAILURE = 51
MSG_USERAUTH_SUCCESS = 52
MSG_USERAUTH_BANNER = 53
MSG_USERAUTH_FIRST = 60
MSG_USERAUTH_LAST = 79
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100
# Messages 90-92 are excluded here as they relate to opening a new channel
MSG_CHANNEL_FIRST = 93
MSG_CHANNEL_LAST = 127
# SSH disconnect reason codes
DISC_HOST_NOT_ALLOWED_TO_CONNECT = 1
DISC_PROTOCOL_ERROR = 2
DISC_KEY_EXCHANGE_FAILED = 3
DISC_RESERVED = 4
DISC_MAC_ERROR = 5
DISC_COMPRESSION_ERROR = 6
DISC_SERVICE_NOT_AVAILABLE = 7
DISC_PROTOCOL_VERSION_NOT_SUPPORTED = 8
DISC_HOST_KEY_NOT_VERIFIABLE = 9
DISC_CONNECTION_LOST = 10
DISC_BY_APPLICATION = 11
DISC_TOO_MANY_CONNECTIONS = 12
DISC_AUTH_CANCELLED_BY_USER = 13
DISC_NO_MORE_AUTH_METHODS_AVAILABLE = 14
DISC_ILLEGAL_USER_NAME = 15
DISC_HOST_KEY_NOT_VERIFYABLE = 9 # Error in naming, left here to not
# break backward compatibility
# SSH channel open failure reason codes
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4
# Internal failure reason codes
OPEN_REQUEST_X11_FORWARDING_FAILED = 0xfffffffd
OPEN_REQUEST_PTY_FAILED = 0xfffffffe
OPEN_REQUEST_SESSION_FAILED = 0xffffffff
# SFTPv3-v5 packet types
FXP_INIT = 1
FXP_VERSION = 2
FXP_OPEN = 3
FXP_CLOSE = 4
FXP_READ = 5
FXP_WRITE = 6
FXP_LSTAT = 7
FXP_FSTAT = 8
FXP_SETSTAT = 9
FXP_FSETSTAT = 10
FXP_OPENDIR = 11
FXP_READDIR = 12
FXP_REMOVE = 13
FXP_MKDIR = 14
FXP_RMDIR = 15
FXP_REALPATH = 16
FXP_STAT = 17
FXP_RENAME = 18
FXP_READLINK = 19
FXP_SYMLINK = 20
FXP_STATUS = 101
FXP_HANDLE = 102
FXP_DATA = 103
FXP_NAME = 104
FXP_ATTRS = 105
FXP_EXTENDED = 200
FXP_EXTENDED_REPLY = 201
# SFTPv6 packet types
FXP_LINK = 21
FXP_BLOCK = 22
FXP_UNBLOCK = 23
# SFTPv3 open flags
FXF_READ = 0x00000001
FXF_WRITE = 0x00000002
FXF_APPEND = 0x00000004
FXF_CREAT = 0x00000008
FXF_TRUNC = 0x00000010
FXF_EXCL = 0x00000020
# SFTPv4 open flags
FXF_TEXT = 0x00000040
# SFTPv5 open flags
FXF_ACCESS_DISPOSITION = 0x00000007
FXF_CREATE_NEW = 0x00000000
FXF_CREATE_TRUNCATE = 0x00000001
FXF_OPEN_EXISTING = 0x00000002
FXF_OPEN_OR_CREATE = 0x00000003
FXF_TRUNCATE_EXISTING = 0x00000004
FXF_APPEND_DATA = 0x00000008
FXF_APPEND_DATA_ATOMIC = 0x00000010
FXF_TEXT_MODE = 0x00000020
FXF_BLOCK_READ = 0x00000040
FXF_BLOCK_WRITE = 0x00000080
FXF_BLOCK_DELETE = 0x00000100
# SFTPv6 open flags
FXF_BLOCK_ADVISORY = 0x00000200
FXF_NOFOLLOW = 0x00000400
FXF_DELETE_ON_CLOSE = 0x00000800
FXF_ACCESS_AUDIT_ALARM_INFO = 0x00001000
FXF_ACCESS_BACKUP = 0x00002000
FXF_BACKUP_STREAM = 0x00004000
FXF_OVERRIDE_OWNER = 0x00008000
# SFTPv5-v6 ACE mask values used in desired-access
ACE4_READ_DATA = 0x00000001
ACE4_WRITE_DATA = 0x00000002
ACE4_APPEND_DATA = 0x00000004
ACE4_READ_ATTRIBUTES = 0x00000080
ACE4_WRITE_ATTRIBUTES = 0x00000100
# SFTPv3 attribute flags
FILEXFER_ATTR_SIZE = 0x00000001
FILEXFER_ATTR_UIDGID = 0x00000002
FILEXFER_ATTR_PERMISSIONS = 0x00000004
FILEXFER_ATTR_ACMODTIME = 0x00000008
FILEXFER_ATTR_EXTENDED = 0x80000000
FILEXFER_ATTR_DEFINED_V3 = 0x8000000f
# SFTPv4 attribute flags
FILEXFER_ATTR_ACCESSTIME = 0x00000008
FILEXFER_ATTR_CREATETIME = 0x00000010
FILEXFER_ATTR_MODIFYTIME = 0x00000020
FILEXFER_ATTR_ACL = 0x00000040
FILEXFER_ATTR_OWNERGROUP = 0x00000080
FILEXFER_ATTR_SUBSECOND_TIMES = 0x00000100
FILEXFER_ATTR_DEFINED_V4 = 0x800001fd
# SFTPv5 attribute flags
FILEXFER_ATTR_BITS = 0x00000200
FILEXFER_ATTR_DEFINED_V5 = 0x800003fd
# SFTPv6 attribute flags
FILEXFER_ATTR_ALLOCATION_SIZE = 0x00000400
FILEXFER_ATTR_TEXT_HINT = 0x00000800
FILEXFER_ATTR_MIME_TYPE = 0x00001000
FILEXFER_ATTR_LINK_COUNT = 0x00002000
FILEXFER_ATTR_UNTRANSLATED_NAME = 0x00004000
FILEXFER_ATTR_CTIME = 0x00008000
FILEXFER_ATTR_DEFINED_V6 = 0x8000fffd
# SFTPv4 file types
FILEXFER_TYPE_REGULAR = 1
FILEXFER_TYPE_DIRECTORY = 2
FILEXFER_TYPE_SYMLINK = 3
FILEXFER_TYPE_SPECIAL = 4
FILEXFER_TYPE_UNKNOWN = 5
# SFTPv5 file types
FILEXFER_TYPE_SOCKET = 6
FILEXFER_TYPE_CHAR_DEVICE = 7
FILEXFER_TYPE_BLOCK_DEVICE = 8
FILEXFER_TYPE_FIFO = 9
# SFTPv5 attrib bits
FILEXFER_ATTR_BITS_READONLY = 0x00000001
FILEXFER_ATTR_BITS_SYSTEM = 0x00000002
FILEXFER_ATTR_BITS_HIDDEN = 0x00000004
FILEXFER_ATTR_BITS_CASE_INSENSITIVE = 0x00000008
FILEXFER_ATTR_BITS_ARCHIVE = 0x00000010
FILEXFER_ATTR_BITS_ENCRYPTED = 0x00000020
FILEXFER_ATTR_BITS_COMPRESSED = 0x00000040
FILEXFER_ATTR_BITS_SPARSE = 0x00000080
FILEXFER_ATTR_BITS_APPEND_ONLY = 0x00000100
FILEXFER_ATTR_BITS_IMMUTABLE = 0x00000200
FILEXFER_ATTR_BITS_SYNC = 0x00000400
# SFTPv6 attrib bits
FILEXFER_ATTR_BITS_TRANSLATION_ERR = 0x00000800
# SFTPv6 text hint flags
FILEXFER_ATTR_KNOWN_TEXT = 0
FILEXFER_ATTR_GUESSED_TEXT = 1
FILEXFER_ATTR_KNOWN_BINARY = 2
FILEXFER_ATTR_GUESSED_BINARY = 3
# SFTPv5 rename flags
FXR_OVERWRITE = 0x00000001
FXR_ATOMIC = 0x00000002
FXR_NATIVE = 0x00000004
# SFTPv6 realpath control byte
FXRP_NO_CHECK = 1
FXRP_STAT_IF_EXISTS = 2
FXRP_STAT_ALWAYS = 3
# OpenSSH statvfs attribute flags
FXE_STATVFS_ST_RDONLY = 0x1
FXE_STATVFS_ST_NOSUID = 0x2
# SFTPv3 error codes
FX_OK = 0
FX_EOF = 1
FX_NO_SUCH_FILE = 2
FX_PERMISSION_DENIED = 3
FX_FAILURE = 4
FX_BAD_MESSAGE = 5
FX_NO_CONNECTION = 6
FX_CONNECTION_LOST = 7
FX_OP_UNSUPPORTED = 8
FX_V3_END = FX_OP_UNSUPPORTED
# SFTPv4 error codes
FX_INVALID_HANDLE = 9
FX_NO_SUCH_PATH = 10
FX_FILE_ALREADY_EXISTS = 11
FX_WRITE_PROTECT = 12
FX_NO_MEDIA = 13
FX_V4_END = FX_NO_MEDIA
# SFTPv5 error codes
FX_NO_SPACE_ON_FILESYSTEM = 14
FX_QUOTA_EXCEEDED = 15
FX_UNKNOWN_PRINCIPAL = 16
FX_LOCK_CONFLICT = 17
FX_V5_END = FX_LOCK_CONFLICT
# SFTPv6 error codes
FX_DIR_NOT_EMPTY = 18
FX_NOT_A_DIRECTORY = 19
FX_INVALID_FILENAME = 20
FX_LINK_LOOP = 21
FX_CANNOT_DELETE = 22
FX_INVALID_PARAMETER = 23
FX_FILE_IS_A_DIRECTORY = 24
FX_BYTE_RANGE_LOCK_CONFLICT = 25
FX_BYTE_RANGE_LOCK_REFUSED = 26
FX_DELETE_PENDING = 27
FX_FILE_CORRUPT = 28
FX_OWNER_INVALID = 29
FX_GROUP_INVALID = 30
FX_NO_MATCHING_BYTE_RANGE_LOCK = 31
FX_V6_END = FX_NO_MATCHING_BYTE_RANGE_LOCK
# SSH channel data type codes
EXTENDED_DATA_STDERR = 1
# SSH pty mode opcodes
PTY_OP_END = 0
PTY_VINTR = 1
PTY_VQUIT = 2
PTY_VERASE = 3
PTY_VKILL = 4
PTY_VEOF = 5
PTY_VEOL = 6
PTY_VEOL2 = 7
PTY_VSTART = 8
PTY_VSTOP = 9
PTY_VSUSP = 10
PTY_VDSUSP = 11
PTY_VREPRINT = 12
PTY_WERASE = 13
PTY_VLNEXT = 14
PTY_VFLUSH = 15
PTY_VSWTCH = 16
PTY_VSTATUS = 17
PTY_VDISCARD = 18
PTY_IGNPAR = 30
PTY_PARMRK = 31
PTY_INPCK = 32
PTY_ISTRIP = 33
PTY_INLCR = 34
PTY_IGNCR = 35
PTY_ICRNL = 36
PTY_IUCLC = 37
PTY_IXON = 38
PTY_IXANY = 39
PTY_IXOFF = 40
PTY_IMAXBEL = 41
PTY_IUTF8 = 42
PTY_ISIG = 50
PTY_ICANON = 51
PTY_XCASE = 52
PTY_ECHO = 53
PTY_ECHOE = 54
PTY_ECHOK = 55
PTY_ECHONL = 56
PTY_NOFLSH = 57
PTY_TOSTOP = 58
PTY_IEXTEN = 59
PTY_ECHOCTL = 60
PTY_ECHOKE = 61
PTY_PENDIN = 62
PTY_OPOST = 70
PTY_OLCUC = 71
PTY_ONLCR = 72
PTY_OCRNL = 73
PTY_ONOCR = 74
PTY_ONLRET = 75
PTY_CS7 = 90
PTY_CS8 = 91
PTY_PARENB = 92
PTY_PARODD = 93
PTY_OP_ISPEED = 128
PTY_OP_OSPEED = 129
PTY_OP_RESERVED = 160

View File

@@ -0,0 +1,61 @@
# Copyright (c) 2014-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim for accessing cryptographic primitives needed by asyncssh"""
from .cipher import BasicCipher, GCMCipher, register_cipher, get_cipher_params
from .dsa import DSAPrivateKey, DSAPublicKey
from .dh import DH
from .ec import ECDSAPrivateKey, ECDSAPublicKey, ECDH
from .ed import ed25519_available, ed448_available
from .ed import curve25519_available, curve448_available
from .ed import EdDSAPrivateKey, EdDSAPublicKey, Curve25519DH, Curve448DH
from .ec_params import lookup_ec_curve_by_params
from .kdf import pbkdf2_hmac
from .misc import CryptoKey, PyCAKey
from .rsa import RSAPrivateKey, RSAPublicKey
from .sntrup import sntrup761_available
from .sntrup import sntrup761_pubkey_bytes, sntrup761_ciphertext_bytes
from .sntrup import sntrup761_keypair, sntrup761_encaps, sntrup761_decaps
# Import chacha20-poly1305 cipher if available
from .chacha import ChachaCipher, chacha_available
# Import umac cryptographic hash if available
try:
from .umac import umac32, umac64, umac96, umac128
except (ImportError, AttributeError, OSError): # pragma: no cover
pass
# Import X.509 certificate support if available
try:
from .x509 import X509Certificate, X509Name, X509NamePattern
from .x509 import generate_x509_certificate, import_x509_certificate
except ImportError: # pragma: no cover
pass

View File

@@ -0,0 +1,162 @@
# Copyright (c) 2015-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Chacha20-Poly1305 symmetric encryption handler"""
from ctypes import c_ulonglong, create_string_buffer
from typing import Optional, Tuple
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends.openssl import backend
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import ChaCha20
from cryptography.hazmat.primitives.poly1305 import Poly1305
from .cipher import register_cipher
if backend.poly1305_supported():
_CTR_0 = (0).to_bytes(8, 'little')
_CTR_1 = (1).to_bytes(8, 'little')
_POLY1305_KEYBYTES = 32
def chacha20(key: bytes, data: bytes, nonce: bytes, ctr: int) -> bytes:
"""Encrypt/decrypt a block of data with the ChaCha20 cipher"""
return Cipher(ChaCha20(key, (_CTR_1 if ctr else _CTR_0) + nonce),
mode=None).encryptor().update(data)
def poly1305_key(key: bytes, nonce: bytes) -> bytes:
"""Derive a Poly1305 key"""
return chacha20(key, _POLY1305_KEYBYTES * b'\0', nonce, 0)
def poly1305(key: bytes, data: bytes, nonce: bytes) -> bytes:
"""Compute a Poly1305 tag for a block of data"""
return Poly1305.generate_tag(poly1305_key(key, nonce), data)
def poly1305_verify(key: bytes, data: bytes,
nonce: bytes, tag: bytes) -> bool:
"""Verify a Poly1305 tag for a block of data"""
try:
Poly1305.verify_tag(poly1305_key(key, nonce), data, tag)
return True
except InvalidSignature:
return False
chacha_available = True
else: # pragma: no cover
try:
from libnacl import nacl
_chacha20 = nacl.crypto_stream_chacha20
_chacha20_xor_ic = nacl.crypto_stream_chacha20_xor_ic
_POLY1305_BYTES = nacl.crypto_onetimeauth_poly1305_bytes()
_POLY1305_KEYBYTES = nacl.crypto_onetimeauth_poly1305_keybytes()
_poly1305 = nacl.crypto_onetimeauth_poly1305
_poly1305_verify = nacl.crypto_onetimeauth_poly1305_verify
def chacha20(key: bytes, data: bytes, nonce: bytes, ctr: int) -> bytes:
"""Encrypt/decrypt a block of data with the ChaCha20 cipher"""
datalen = len(data)
result = create_string_buffer(datalen)
ull_datalen = c_ulonglong(datalen)
ull_ctr = c_ulonglong(ctr)
_chacha20_xor_ic(result, data, ull_datalen, nonce, ull_ctr, key)
return result.raw
def poly1305_key(key: bytes, nonce: bytes) -> bytes:
"""Derive a Poly1305 key"""
polykey = create_string_buffer(_POLY1305_KEYBYTES)
ull_polykeylen = c_ulonglong(_POLY1305_KEYBYTES)
_chacha20(polykey, ull_polykeylen, nonce, key)
return polykey.raw
def poly1305(key: bytes, data: bytes, nonce: bytes) -> bytes:
"""Compute a Poly1305 tag for a block of data"""
tag = create_string_buffer(_POLY1305_BYTES)
ull_datalen = c_ulonglong(len(data))
polykey = poly1305_key(key, nonce)
_poly1305(tag, data, ull_datalen, polykey)
return tag.raw
def poly1305_verify(key: bytes, data: bytes,
nonce: bytes, tag: bytes) -> bool:
"""Verify a Poly1305 tag for a block of data"""
ull_datalen = c_ulonglong(len(data))
polykey = poly1305_key(key, nonce)
return _poly1305_verify(tag, data, ull_datalen, polykey) == 0
chacha_available = True
except (ImportError, OSError, AttributeError):
chacha_available = False
class ChachaCipher:
"""Shim for Chacha20-Poly1305 symmetric encryption"""
def __init__(self, key: bytes):
keylen = len(key) // 2
self._key = key[:keylen]
self._adkey = key[keylen:]
def encrypt_and_sign(self, header: bytes, data: bytes,
nonce: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign a block of data"""
header = chacha20(self._adkey, header, nonce, 0)
data = chacha20(self._key, data, nonce, 1)
tag = poly1305(self._key, header + data, nonce)
return header + data, tag
def decrypt_header(self, header: bytes, nonce: bytes) -> bytes:
"""Decrypt header data"""
return chacha20(self._adkey, header, nonce, 0)
def verify_and_decrypt(self, header: bytes, data: bytes,
nonce: bytes, tag: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt a block of data"""
if poly1305_verify(self._key, header + data, nonce, tag):
return chacha20(self._key, data, nonce, 1)
else:
return None
if chacha_available: # pragma: no branch
register_cipher('chacha20-poly1305', 64, 0, 1)

View File

@@ -0,0 +1,166 @@
# Copyright (c) 2014-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA for accessing symmetric ciphers needed by AsyncSSH"""
from typing import Any, MutableMapping, Optional, Tuple
import warnings
from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives.ciphers import Cipher, CipherContext
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.ciphers.algorithms import AES, ARC4
from cryptography.hazmat.primitives.ciphers.algorithms import TripleDES
from cryptography.hazmat.primitives.ciphers.modes import CBC, CTR
with warnings.catch_warnings():
warnings.simplefilter('ignore')
from cryptography.hazmat.primitives.ciphers.algorithms import Blowfish
from cryptography.hazmat.primitives.ciphers.algorithms import CAST5
from cryptography.hazmat.primitives.ciphers.algorithms import SEED
_CipherAlgs = Tuple[Any, Any, int]
_CipherParams = Tuple[int, int, int]
_GCM_MAC_SIZE = 16
_cipher_algs: MutableMapping[str, _CipherAlgs] = {}
_cipher_params: MutableMapping[str, _CipherParams] = {}
class BasicCipher:
"""Shim for basic ciphers"""
def __init__(self, cipher_name: str, key: bytes, iv: bytes):
cipher, mode, initial_bytes = _cipher_algs[cipher_name]
self._cipher = Cipher(cipher(key), mode(iv) if mode else None)
self._initial_bytes = initial_bytes
self._encryptor: Optional[CipherContext] = None
self._decryptor: Optional[CipherContext] = None
def encrypt(self, data: bytes) -> bytes:
"""Encrypt a block of data"""
if not self._encryptor:
self._encryptor = self._cipher.encryptor()
if self._initial_bytes:
assert self._encryptor is not None
self._encryptor.update(self._initial_bytes * b'\0')
assert self._encryptor is not None
return self._encryptor.update(data)
def decrypt(self, data: bytes) -> bytes:
"""Decrypt a block of data"""
if not self._decryptor:
self._decryptor = self._cipher.decryptor()
if self._initial_bytes:
assert self._decryptor is not None
self._decryptor.update(self._initial_bytes * b'\0')
assert self._decryptor is not None
return self._decryptor.update(data)
class GCMCipher:
"""Shim for GCM ciphers"""
def __init__(self, cipher_name: str, key: bytes, iv: bytes):
self._cipher = _cipher_algs[cipher_name][0]
self._key = key
self._iv = iv
def _update_iv(self) -> None:
"""Update the IV after each encrypt/decrypt operation"""
invocation = int.from_bytes(self._iv[4:], 'big')
invocation = (invocation + 1) & 0xffffffffffffffff
self._iv = self._iv[:4] + invocation.to_bytes(8, 'big')
def encrypt_and_sign(self, header: bytes,
data: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign a block of data"""
data = AESGCM(self._key).encrypt(self._iv, data, header)
self._update_iv()
return header + data[:-_GCM_MAC_SIZE], data[-_GCM_MAC_SIZE:]
def verify_and_decrypt(self, header: bytes, data: bytes,
mac: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt a block of data"""
try:
decrypted_data: Optional[bytes] = \
AESGCM(self._key).decrypt(self._iv, data + mac, header)
except InvalidTag:
decrypted_data = None
self._update_iv()
return decrypted_data
def register_cipher(cipher_name: str, key_size: int,
iv_size: int, block_size: int) -> None:
"""Register a symmetric cipher"""
_cipher_params[cipher_name] = (key_size, iv_size, block_size)
def get_cipher_params(cipher_name: str) -> _CipherParams:
"""Get parameters of a symmetric cipher"""
return _cipher_params[cipher_name]
_cipher_alg_list = (
('aes128-cbc', AES, CBC, 0, 16, 16, 16),
('aes192-cbc', AES, CBC, 0, 24, 16, 16),
('aes256-cbc', AES, CBC, 0, 32, 16, 16),
('aes128-ctr', AES, CTR, 0, 16, 16, 16),
('aes192-ctr', AES, CTR, 0, 24, 16, 16),
('aes256-ctr', AES, CTR, 0, 32, 16, 16),
('aes128-gcm', None, None, 0, 16, 12, 16),
('aes256-gcm', None, None, 0, 32, 12, 16),
('arcfour', ARC4, None, 0, 16, 1, 1),
('arcfour40', ARC4, None, 0, 5, 1, 1),
('arcfour128', ARC4, None, 1536, 16, 1, 1),
('arcfour256', ARC4, None, 1536, 32, 1, 1),
('blowfish-cbc', Blowfish, CBC, 0, 16, 8, 8),
('cast128-cbc', CAST5, CBC, 0, 16, 8, 8),
('des-cbc', TripleDES, CBC, 0, 8, 8, 8),
('des2-cbc', TripleDES, CBC, 0, 16, 8, 8),
('des3-cbc', TripleDES, CBC, 0, 24, 8, 8),
('seed-cbc', SEED, CBC, 0, 16, 16, 16)
)
for _cipher_name, _cipher, _mode, _initial_bytes, \
_key_size, _iv_size, _block_size in _cipher_alg_list:
_cipher_algs[_cipher_name] = (_cipher, _mode, _initial_bytes)
register_cipher(_cipher_name, _key_size, _iv_size, _block_size)

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2022 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA for Diffie Hellman key exchange"""
from cryptography.hazmat.primitives.asymmetric import dh
class DH:
"""A shim around PyCA for Diffie Hellman key exchange"""
def __init__(self, g: int, p: int):
self._pn = dh.DHParameterNumbers(p, g)
self._priv_key = self._pn.parameters().generate_private_key()
def get_public(self) -> int:
"""Return the public key to send in the handshake"""
pub_key = self._priv_key.public_key()
return pub_key.public_numbers().y
def get_shared(self, peer_public: int) -> int:
"""Return the shared key from the peer's public key"""
peer_key = dh.DHPublicNumbers(peer_public, self._pn).public_key()
shared_key = self._priv_key.exchange(peer_key)
return int.from_bytes(shared_key, 'big')

View File

@@ -0,0 +1,132 @@
# Copyright (c) 2014-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA for DSA public and private keys"""
from typing import Optional, cast
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import dsa
from .misc import CryptoKey, PyCAKey, hashes
# Short variable names are used here, matching names in the spec
# pylint: disable=invalid-name
class _DSAKey(CryptoKey):
"""Base class for shim around PyCA for DSA keys"""
def __init__(self, pyca_key: PyCAKey, params: dsa.DSAParameterNumbers,
pub: dsa.DSAPublicNumbers,
priv: Optional[dsa.DSAPrivateNumbers] = None):
super().__init__(pyca_key)
self._params = params
self._pub = pub
self._priv = priv
@property
def p(self) -> int:
"""Return the DSA public modulus"""
return self._params.p
@property
def q(self) -> int:
"""Return the DSA sub-group order"""
return self._params.q
@property
def g(self) -> int:
"""Return the DSA generator"""
return self._params.g
@property
def y(self) -> int:
"""Return the DSA public value"""
return self._pub.y
@property
def x(self) -> Optional[int]:
"""Return the DSA private value"""
return self._priv.x if self._priv else None
class DSAPrivateKey(_DSAKey):
"""A shim around PyCA for DSA private keys"""
@classmethod
def construct(cls, p: int, q: int, g: int,
y: int, x: int) -> 'DSAPrivateKey':
"""Construct a DSA private key"""
params = dsa.DSAParameterNumbers(p, q, g)
pub = dsa.DSAPublicNumbers(y, params)
priv = dsa.DSAPrivateNumbers(x, pub)
priv_key = priv.private_key()
return cls(priv_key, params, pub, priv)
@classmethod
def generate(cls, key_size: int) -> 'DSAPrivateKey':
"""Generate a new DSA private key"""
priv_key = dsa.generate_private_key(key_size)
priv = priv_key.private_numbers()
pub = priv.public_numbers
params = pub.parameter_numbers
return cls(priv_key, params, pub, priv)
def sign(self, data: bytes, hash_name: str = '') -> bytes:
"""Sign a block of data"""
priv_key = cast('dsa.DSAPrivateKey', self.pyca_key)
return priv_key.sign(data, hashes[hash_name]())
class DSAPublicKey(_DSAKey):
"""A shim around PyCA for DSA public keys"""
@classmethod
def construct(cls, p: int, q: int, g: int, y: int) -> 'DSAPublicKey':
"""Construct a DSA public key"""
params = dsa.DSAParameterNumbers(p, q, g)
pub = dsa.DSAPublicNumbers(y, params)
pub_key = pub.public_key()
return cls(pub_key, params, pub)
def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool:
"""Verify the signature on a block of data"""
try:
pub_key = cast('dsa.DSAPublicKey', self.pyca_key)
pub_key.verify(sig, data, hashes[hash_name]())
return True
except InvalidSignature:
return False

View File

@@ -0,0 +1,205 @@
# Copyright (c) 2015-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA for elliptic curve keys and key exchange"""
from typing import Mapping, Optional, Type, cast
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PublicFormat
from .misc import CryptoKey, PyCAKey, hashes
# Short variable names are used here, matching names in the spec
# pylint: disable=invalid-name
_curves: Mapping[bytes, Type[ec.EllipticCurve]] = {
b'1.3.132.0.10': ec.SECP256K1,
b'nistp256': ec.SECP256R1,
b'nistp384': ec.SECP384R1,
b'nistp521': ec.SECP521R1
}
class _ECKey(CryptoKey):
"""Base class for shim around PyCA for EC keys"""
def __init__(self, pyca_key: PyCAKey, curve_id: bytes,
pub: ec.EllipticCurvePublicNumbers, point: bytes,
priv: Optional[ec.EllipticCurvePrivateNumbers] = None):
super().__init__(pyca_key)
self._curve_id = curve_id
self._pub = pub
self._point = point
self._priv = priv
@classmethod
def lookup_curve(cls, curve_id: bytes) -> Type[ec.EllipticCurve]:
"""Look up curve and hash algorithm"""
try:
return _curves[curve_id]
except KeyError: # pragma: no cover, other curves not registered
raise ValueError('Unknown EC curve %s' %
curve_id.decode()) from None
@property
def curve_id(self) -> bytes:
"""Return the EC curve name"""
return self._curve_id
@property
def x(self) -> int:
"""Return the EC public x coordinate"""
return self._pub.x
@property
def y(self) -> int:
"""Return the EC public y coordinate"""
return self._pub.y
@property
def d(self) -> Optional[int]:
"""Return the EC private value as an integer"""
return self._priv.private_value if self._priv else None
@property
def public_value(self) -> bytes:
"""Return the EC public point value encoded as a byte string"""
return self._point
@property
def private_value(self) -> Optional[bytes]:
"""Return the EC private value encoded as a byte string"""
if self._priv:
keylen = (self._pub.curve.key_size + 7) // 8
return self._priv.private_value.to_bytes(keylen, 'big')
else:
return None
class ECDSAPrivateKey(_ECKey):
"""A shim around PyCA for ECDSA private keys"""
@classmethod
def construct(cls, curve_id: bytes, public_value: bytes,
private_value: int) -> 'ECDSAPrivateKey':
"""Construct an ECDSA private key"""
curve = cls.lookup_curve(curve_id)
priv_key = ec.derive_private_key(private_value, curve())
priv = priv_key.private_numbers()
pub = priv.public_numbers
return cls(priv_key, curve_id, pub, public_value, priv)
@classmethod
def generate(cls, curve_id: bytes) -> 'ECDSAPrivateKey':
"""Generate a new ECDSA private key"""
curve = cls.lookup_curve(curve_id)
priv_key = ec.generate_private_key(curve())
priv = priv_key.private_numbers()
pub_key = priv_key.public_key()
pub = pub_key.public_numbers()
public_value = pub_key.public_bytes(Encoding.X962,
PublicFormat.UncompressedPoint)
return cls(priv_key, curve_id, pub, public_value, priv)
def sign(self, data: bytes, hash_name: str = '') -> bytes:
"""Sign a block of data"""
# pylint: disable=unused-argument
priv_key = cast('ec.EllipticCurvePrivateKey', self.pyca_key)
return priv_key.sign(data, ec.ECDSA(hashes[hash_name]()))
class ECDSAPublicKey(_ECKey):
"""A shim around PyCA for ECDSA public keys"""
@classmethod
def construct(cls, curve_id: bytes,
public_value: bytes) -> 'ECDSAPublicKey':
"""Construct an ECDSA public key"""
curve = cls.lookup_curve(curve_id)
pub_key = ec.EllipticCurvePublicKey.from_encoded_point(curve(),
public_value)
pub = pub_key.public_numbers()
return cls(pub_key, curve_id, pub, public_value)
def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool:
"""Verify the signature on a block of data"""
try:
pub_key = cast('ec.EllipticCurvePublicKey', self.pyca_key)
pub_key.verify(sig, data, ec.ECDSA(hashes[hash_name]()))
return True
except InvalidSignature:
return False
class ECDH:
"""A shim around PyCA for ECDH key exchange"""
def __init__(self, curve_id: bytes):
try:
curve = _curves[curve_id]
except KeyError: # pragma: no cover, other curves not registered
raise ValueError('Unknown EC curve %s' %
curve_id.decode()) from None
self._priv_key = ec.generate_private_key(curve())
def get_public(self) -> bytes:
"""Return the public key to send in the handshake"""
pub_key = self._priv_key.public_key()
return pub_key.public_bytes(Encoding.X962,
PublicFormat.UncompressedPoint)
def get_shared(self, peer_public: bytes) -> int:
"""Return the shared key from the peer's public key"""
peer_key = ec.EllipticCurvePublicKey.from_encoded_point(
self._priv_key.curve, peer_public)
shared_key = self._priv_key.exchange(ec.ECDH(), peer_key)
return int.from_bytes(shared_key, 'big')

View File

@@ -0,0 +1,87 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Functions for looking up named elliptic curves by their parameters"""
_curve_param_map = {}
# Short variable names are used here, matching names in the spec
# pylint: disable=invalid-name
def register_prime_curve(curve_id: bytes, p: int, a: int, b: int,
point: bytes, n: int) -> None:
"""Register an elliptic curve prime domain
This function registers an elliptic curve prime domain by
specifying the SSH identifier for the curve and the set of
parameters describing the curve, generator point, and order.
This allows EC keys encoded with explicit parameters to be
mapped back into their SSH curve IDs.
"""
_curve_param_map[p, a % p, b % p, point, n] = curve_id
def lookup_ec_curve_by_params(p: int, a: int, b: int,
point: bytes, n: int) -> bytes:
"""Look up an elliptic curve by its parameters
This function looks up an elliptic curve by its parameters
and returns the curve's name.
"""
try:
return _curve_param_map[p, a % p, b % p, point, n]
except (KeyError, ValueError):
raise ValueError('Unknown elliptic curve parameters') from None
# pylint: disable=line-too-long
register_prime_curve(b'nistp521',
6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151,
-3,
1093849038073734274511112390766805569936207598951683748994586394495953116150735016013708737573759623248592132296706313309438452531591012912142327488478985984,
b'\x04\x00\xc6\x85\x8e\x06\xb7\x04\x04\xe9\xcd\x9e>\xcbf#\x95\xb4B\x9cd\x819\x05?\xb5!\xf8(\xaf`kM=\xba\xa1K^w\xef\xe7Y(\xfe\x1d\xc1\'\xa2\xff\xa8\xde3H\xb3\xc1\x85jB\x9b\xf9~~1\xc2\xe5\xbdf\x01\x189)jx\x9a;\xc0\x04\\\x8a_\xb4,}\x1b\xd9\x98\xf5DIW\x9bDh\x17\xaf\xbd\x17\'>f,\x97\xeer\x99^\xf4&@\xc5P\xb9\x01?\xad\x07a5<p\x86\xa2r\xc2@\x88\xbe\x94v\x9f\xd1fP',
6864797660130609714981900799081393217269435300143305409394463459185543183397655394245057746333217197532963996371363321113864768612440380340372808892707005449)
register_prime_curve(b'nistp384',
39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319,
-3,
27580193559959705877849011840389048093056905856361568521428707301988689241309860865136260764883745107765439761230575,
b'\x04\xaa\x87\xca"\xbe\x8b\x057\x8e\xb1\xc7\x1e\xf3 \xadtn\x1d;b\x8b\xa7\x9b\x98Y\xf7A\xe0\x82T*8U\x02\xf2]\xbfU)l:T^8rv\n\xb76\x17\xdeJ\x96&,o]\x9e\x98\xbf\x92\x92\xdc)\xf8\xf4\x1d\xbd(\x9a\x14|\xe9\xda1\x13\xb5\xf0\xb8\xc0\n`\xb1\xce\x1d~\x81\x9dzC\x1d|\x90\xea\x0e_',
39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643)
register_prime_curve(b'nistp256',
115792089210356248762697446949407573530086143415290314195533631308867097853951,
-3,
41058363725152142129326129780047268409114441015993725554835256314039467401291,
b'\x04k\x17\xd1\xf2\xe1,BG\xf8\xbc\xe6\xe5c\xa4@\xf2w\x03}\x81-\xeb3\xa0\xf4\xa19E\xd8\x98\xc2\x96O\xe3B\xe2\xfe\x1a\x7f\x9b\x8e\xe7\xebJ|\x0f\x9e\x16+\xce3Wk1^\xce\xcb\xb6@h7\xbfQ\xf5',
115792089210356248762697446949407573529996955224135760342422259061068512044369)
register_prime_curve(b'1.3.132.0.10',
115792089237316195423570985008687907853269984665640564039457584007908834671663,
0,
7,
b'\x04y\xbef~\xf9\xdc\xbb\xacU\xa0b\x95\xce\x87\x0b\x07\x02\x9b\xfc\xdb-\xce(\xd9Y\xf2\x81[\x16\xf8\x17\x98H:\xdaw&\xa3\xc4e]\xa4\xfb\xfc\x0e\x11\x08\xa8\xfd\x17\xb4H\xa6\x85T\x19\x9cG\xd0\x8f\xfb\x10\xd4\xb8',
115792089237316195423570985008687907852837564279074904382605163141518161494337)

View File

@@ -0,0 +1,325 @@
# Copyright (c) 2019-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA and libnacl for Edwards-curve keys and key exchange"""
import ctypes
import os
from typing import Dict, Optional, Union, cast
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends.openssl import backend
from cryptography.hazmat.primitives.asymmetric import ed25519, ed448
from cryptography.hazmat.primitives.asymmetric import x25519, x448
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PrivateFormat
from cryptography.hazmat.primitives.serialization import PublicFormat
from cryptography.hazmat.primitives.serialization import NoEncryption
from .misc import CryptoKey, PyCAKey
_EdPrivateKey = Union[ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey]
_EdPublicKey = Union[ed25519.Ed25519PublicKey, ed448.Ed448PublicKey]
ed25519_available = backend.ed25519_supported()
ed448_available = backend.ed448_supported()
curve25519_available = backend.x25519_supported()
curve448_available = backend.x448_supported()
if ed25519_available or ed448_available: # pragma: no branch
class _EdDSAKey(CryptoKey):
"""Base class for shim around PyCA for EdDSA keys"""
def __init__(self, pyca_key: PyCAKey, pub: bytes,
priv: Optional[bytes] = None):
super().__init__(pyca_key)
self._pub = pub
self._priv = priv
@property
def public_value(self) -> bytes:
"""Return the public value encoded as a byte string"""
return self._pub
@property
def private_value(self) -> Optional[bytes]:
"""Return the private value encoded as a byte string"""
return self._priv
class EdDSAPrivateKey(_EdDSAKey):
"""A shim around PyCA for EdDSA private keys"""
_priv_classes: Dict[bytes, object] = {}
if ed25519_available: # pragma: no branch
_priv_classes[b'ed25519'] = ed25519.Ed25519PrivateKey
if ed448_available: # pragma: no branch
_priv_classes[b'ed448'] = ed448.Ed448PrivateKey
@classmethod
def construct(cls, curve_id: bytes, priv: bytes) -> 'EdDSAPrivateKey':
"""Construct an EdDSA private key"""
priv_cls = cast('_EdPrivateKey', cls._priv_classes[curve_id])
priv_key = priv_cls.from_private_bytes(priv)
pub_key = priv_key.public_key()
pub = pub_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
return cls(priv_key, pub, priv)
@classmethod
def generate(cls, curve_id: bytes) -> 'EdDSAPrivateKey':
"""Generate a new EdDSA private key"""
priv_cls = cast('_EdPrivateKey', cls._priv_classes[curve_id])
priv_key = priv_cls.generate()
priv = priv_key.private_bytes(Encoding.Raw, PrivateFormat.Raw,
NoEncryption())
pub_key = priv_key.public_key()
pub = pub_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
return cls(priv_key, pub, priv)
def sign(self, data: bytes, hash_name: str = '') -> bytes:
"""Sign a block of data"""
# pylint: disable=unused-argument
priv_key = cast('_EdPrivateKey', self.pyca_key)
return priv_key.sign(data)
class EdDSAPublicKey(_EdDSAKey):
"""A shim around PyCA for EdDSA public keys"""
_pub_classes: Dict[bytes, object] = {
b'ed25519': ed25519.Ed25519PublicKey,
b'ed448': ed448.Ed448PublicKey
}
@classmethod
def construct(cls, curve_id: bytes, pub: bytes) -> 'EdDSAPublicKey':
"""Construct an EdDSA public key"""
pub_cls = cast('_EdPublicKey', cls._pub_classes[curve_id])
pub_key = pub_cls.from_public_bytes(pub)
return cls(pub_key, pub)
def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool:
"""Verify the signature on a block of data"""
# pylint: disable=unused-argument
try:
pub_key = cast('_EdPublicKey', self.pyca_key)
pub_key.verify(sig, data)
return True
except InvalidSignature:
return False
else: # pragma: no cover
class _EdDSANaclKey:
"""Base class for shim around libnacl for EdDSA keys"""
def __init__(self, pub: bytes, priv: Optional[bytes] = None):
self._pub = pub
self._priv = priv
@property
def public_value(self) -> bytes:
"""Return the public value encoded as a byte string"""
return self._pub
@property
def private_value(self) -> Optional[bytes]:
"""Return the private value encoded as a byte string"""
return self._priv[:-len(self._pub)] if self._priv else None
class EdDSAPrivateKey(_EdDSANaclKey): # type: ignore
"""A shim around libnacl for EdDSA private keys"""
@classmethod
def construct(cls, curve_id: bytes, priv: bytes) -> 'EdDSAPrivateKey':
"""Construct an EdDSA private key"""
# pylint: disable=unused-argument
return cls(*_ed25519_construct_keypair(priv))
@classmethod
def generate(cls, curve_id: str) -> 'EdDSAPrivateKey':
"""Generate a new EdDSA private key"""
# pylint: disable=unused-argument
return cls(*_ed25519_generate_keypair())
def sign(self, data: bytes, hash_name: str = '') -> bytes:
"""Sign a block of data"""
# pylint: disable=unused-argument
assert self._priv is not None
return _ed25519_sign(data, self._priv)[:-len(data)]
class EdDSAPublicKey(_EdDSANaclKey): # type: ignore
"""A shim around libnacl for EdDSA public keys"""
@classmethod
def construct(cls, curve_id: bytes, pub: bytes) -> 'EdDSAPublicKey':
"""Construct an EdDSA public key"""
# pylint: disable=unused-argument
if len(pub) != _ED25519_PUBLIC_BYTES:
raise ValueError('Invalid EdDSA public key')
return cls(pub)
def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool:
"""Verify the signature on a block of data"""
# pylint: disable=unused-argument
try:
return _ed25519_verify(sig + data, self._pub) == data
except ValueError:
return False
try:
import libnacl
_ED25519_PUBLIC_BYTES = libnacl.crypto_sign_ed25519_PUBLICKEYBYTES
_ed25519_construct_keypair = libnacl.crypto_sign_seed_keypair
_ed25519_generate_keypair = libnacl.crypto_sign_keypair
_ed25519_sign = libnacl.crypto_sign
_ed25519_verify = libnacl.crypto_sign_open
ed25519_available = True
except (ImportError, OSError, AttributeError):
pass
if curve25519_available: # pragma: no branch
class Curve25519DH:
"""Curve25519 Diffie Hellman implementation based on PyCA"""
def __init__(self) -> None:
self._priv_key = x25519.X25519PrivateKey.generate()
def get_public(self) -> bytes:
"""Return the public key to send in the handshake"""
return self._priv_key.public_key().public_bytes(Encoding.Raw,
PublicFormat.Raw)
def get_shared_bytes(self, peer_public: bytes) -> bytes:
"""Return the shared key from the peer's public key"""
peer_key = x25519.X25519PublicKey.from_public_bytes(peer_public)
return self._priv_key.exchange(peer_key)
def get_shared(self, peer_public: bytes) -> int:
"""Return the shared key from the peer's public key as bytes"""
return int.from_bytes(self.get_shared_bytes(peer_public), 'big')
else: # pragma: no cover
class Curve25519DH: # type: ignore
"""Curve25519 Diffie Hellman implementation based on libnacl"""
def __init__(self) -> None:
self._private = os.urandom(_CURVE25519_SCALARBYTES)
def get_public(self) -> bytes:
"""Return the public key to send in the handshake"""
public = ctypes.create_string_buffer(_CURVE25519_BYTES)
if _curve25519_base(public, self._private) != 0:
# This error is never returned by libsodium
raise ValueError('Curve25519 failed') # pragma: no cover
return public.raw
def get_shared_bytes(self, peer_public: bytes) -> bytes:
"""Return the shared key from the peer's public key as bytes"""
if len(peer_public) != _CURVE25519_BYTES:
raise ValueError('Invalid curve25519 public key size')
shared = ctypes.create_string_buffer(_CURVE25519_BYTES)
if _curve25519(shared, self._private, peer_public) != 0:
raise ValueError('Curve25519 failed')
return shared.raw
def get_shared(self, peer_public: bytes) -> int:
"""Return the shared key from the peer's public key"""
return int.from_bytes(self.get_shared_bytes(peer_public), 'big')
try:
from libnacl import nacl
_CURVE25519_BYTES = nacl.crypto_scalarmult_curve25519_bytes()
_CURVE25519_SCALARBYTES = \
nacl.crypto_scalarmult_curve25519_scalarbytes()
_curve25519 = nacl.crypto_scalarmult_curve25519
_curve25519_base = nacl.crypto_scalarmult_curve25519_base
curve25519_available = True
except (ImportError, OSError, AttributeError):
pass
class Curve448DH:
"""Curve448 Diffie Hellman implementation based on PyCA"""
def __init__(self) -> None:
self._priv_key = x448.X448PrivateKey.generate()
def get_public(self) -> bytes:
"""Return the public key to send in the handshake"""
return self._priv_key.public_key().public_bytes(Encoding.Raw,
PublicFormat.Raw)
def get_shared(self, peer_public: bytes) -> int:
"""Return the shared key from the peer's public key"""
peer_key = x448.X448PublicKey.from_public_bytes(peer_public)
shared = self._priv_key.exchange(peer_key)
return int.from_bytes(shared, 'big')

View File

@@ -0,0 +1,33 @@
# Copyright (c) 2017-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA for key derivation functions"""
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from .misc import hashes
def pbkdf2_hmac(hash_name: str, passphrase: bytes, salt: bytes,
count: int, key_size: int) -> bytes:
"""A shim around PyCA for PBKDF2 HMAC key derivation"""
return PBKDF2HMAC(hashes[hash_name](), key_size, salt,
count).derive(passphrase)

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2017-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Miscellaneous PyCA utility classes and functions"""
from typing import Callable, Mapping, Union
from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa
from cryptography.hazmat.primitives.asymmetric import ed25519, ed448
from cryptography.hazmat.primitives.hashes import HashAlgorithm
from cryptography.hazmat.primitives.hashes import MD5, SHA1, SHA224
from cryptography.hazmat.primitives.hashes import SHA256, SHA384, SHA512
PyCAPrivateKey = Union[dsa.DSAPrivateKey, rsa.RSAPrivateKey,
ec.EllipticCurvePrivateKey,
ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey]
PyCAPublicKey = Union[dsa.DSAPublicKey, rsa.RSAPublicKey,
ec.EllipticCurvePublicKey,
ed25519.Ed25519PublicKey, ed448.Ed448PublicKey]
PyCAKey = Union[PyCAPrivateKey, PyCAPublicKey]
hashes: Mapping[str, Callable[[], HashAlgorithm]] = {
str(h.name): h for h in (MD5, SHA1, SHA224, SHA256, SHA384, SHA512)
}
class CryptoKey:
"""Base class for PyCA private/public keys"""
def __init__(self, pyca_key: PyCAKey):
self._pyca_key = pyca_key
@property
def pyca_key(self) -> PyCAKey:
"""Return the PyCA object associated with this key"""
return self._pyca_key
def sign(self, data: bytes, hash_name: str = '') -> bytes:
"""Sign a block of data"""
# pylint: disable=no-self-use
raise RuntimeError # pragma: no cover
def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool:
"""Verify the signature on a block of data"""
# pylint: disable=no-self-use
raise RuntimeError # pragma: no cover

View File

@@ -0,0 +1,169 @@
# Copyright (c) 2014-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA for RSA public and private keys"""
from typing import Optional, cast
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.padding import MGF1, OAEP
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.asymmetric import rsa
from .misc import CryptoKey, PyCAKey, hashes
# Short variable names are used here, matching names in the spec
# pylint: disable=invalid-name
class _RSAKey(CryptoKey):
"""Base class for shim around PyCA for RSA keys"""
def __init__(self, pyca_key: PyCAKey, pub: rsa.RSAPublicNumbers,
priv: Optional[rsa.RSAPrivateNumbers] = None):
super().__init__(pyca_key)
self._pub = pub
self._priv = priv
@property
def n(self) -> int:
"""Return the RSA public modulus"""
return self._pub.n
@property
def e(self) -> int:
"""Return the RSA public exponent"""
return self._pub.e
@property
def d(self) -> Optional[int]:
"""Return the RSA private exponent"""
return self._priv.d if self._priv else None
@property
def p(self) -> Optional[int]:
"""Return the RSA first private prime"""
return self._priv.p if self._priv else None
@property
def q(self) -> Optional[int]:
"""Return the RSA second private prime"""
return self._priv.q if self._priv else None
@property
def dmp1(self) -> Optional[int]:
"""Return d modulo p-1"""
return self._priv.dmp1 if self._priv else None
@property
def dmq1(self) -> Optional[int]:
"""Return q modulo p-1"""
return self._priv.dmq1 if self._priv else None
@property
def iqmp(self) -> Optional[int]:
"""Return the inverse of q modulo p"""
return self._priv.iqmp if self._priv else None
class RSAPrivateKey(_RSAKey):
"""A shim around PyCA for RSA private keys"""
@classmethod
def construct(cls, n: int, e: int, d: int, p: int, q: int,
dmp1: int, dmq1: int, iqmp: int,
skip_validation: bool) -> 'RSAPrivateKey':
"""Construct an RSA private key"""
pub = rsa.RSAPublicNumbers(e, n)
priv = rsa.RSAPrivateNumbers(p, q, d, dmp1, dmq1, iqmp, pub)
priv_key = priv.private_key(
unsafe_skip_rsa_key_validation=skip_validation)
return cls(priv_key, pub, priv)
@classmethod
def generate(cls, key_size: int, exponent: int) -> 'RSAPrivateKey':
"""Generate a new RSA private key"""
priv_key = rsa.generate_private_key(exponent, key_size)
priv = priv_key.private_numbers()
pub = priv.public_numbers
return cls(priv_key, pub, priv)
def decrypt(self, data: bytes, hash_name: str) -> Optional[bytes]:
"""Decrypt a block of data"""
try:
hash_alg = hashes[hash_name]()
priv_key = cast('rsa.RSAPrivateKey', self.pyca_key)
return priv_key.decrypt(data, OAEP(MGF1(hash_alg), hash_alg, None))
except ValueError:
return None
def sign(self, data: bytes, hash_name: str = '') -> bytes:
"""Sign a block of data"""
priv_key = cast('rsa.RSAPrivateKey', self.pyca_key)
return priv_key.sign(data, PKCS1v15(), hashes[hash_name]())
class RSAPublicKey(_RSAKey):
"""A shim around PyCA for RSA public keys"""
@classmethod
def construct(cls, n: int, e: int) -> 'RSAPublicKey':
"""Construct an RSA public key"""
pub = rsa.RSAPublicNumbers(e, n)
pub_key = pub.public_key()
return cls(pub_key, pub)
def encrypt(self, data: bytes, hash_name: str) -> Optional[bytes]:
"""Encrypt a block of data"""
try:
hash_alg = hashes[hash_name]()
pub_key = cast('rsa.RSAPublicKey', self.pyca_key)
return pub_key.encrypt(data, OAEP(MGF1(hash_alg), hash_alg, None))
except ValueError:
return None
def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool:
"""Verify the signature on a block of data"""
try:
pub_key = cast('rsa.RSAPublicKey', self.pyca_key)
pub_key.verify(sig, data, PKCS1v15(), hashes[hash_name]())
return True
except InvalidSignature:
return False

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2022 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around liboqs for Streamlined NTRU Prime post-quantum encryption"""
import ctypes
import ctypes.util
from typing import Tuple
sntrup761_available = False
sntrup761_pubkey_bytes = 1158
sntrup761_privkey_bytes = 1763
sntrup761_ciphertext_bytes = 1039
sntrup761_secret_bytes = 32
for lib in ('oqs', 'liboqs'):
_oqs_lib = ctypes.util.find_library(lib)
if _oqs_lib: # pragma: no branch
break
else: # pragma: no cover
_oqs_lib = None
if _oqs_lib: # pragma: no branch
_oqs = ctypes.cdll.LoadLibrary(_oqs_lib)
_sntrup761_keypair = _oqs.OQS_KEM_ntruprime_sntrup761_keypair
_sntrup761_encaps = _oqs.OQS_KEM_ntruprime_sntrup761_encaps
_sntrup761_decaps = _oqs.OQS_KEM_ntruprime_sntrup761_decaps
sntrup761_available = True
def sntrup761_keypair() -> Tuple[bytes, bytes]:
"""Make a SNTRUP761 key pair"""
pubkey = ctypes.create_string_buffer(sntrup761_pubkey_bytes)
privkey = ctypes.create_string_buffer(sntrup761_privkey_bytes)
_sntrup761_keypair(pubkey, privkey)
return pubkey.raw, privkey.raw
def sntrup761_encaps(pubkey: bytes) -> Tuple[bytes, bytes]:
"""Generate a random secret and encrypt it with a public key"""
if len(pubkey) != sntrup761_pubkey_bytes:
raise ValueError('Invalid SNTRUP761 public key')
ciphertext = ctypes.create_string_buffer(sntrup761_ciphertext_bytes)
secret = ctypes.create_string_buffer(sntrup761_secret_bytes)
_sntrup761_encaps(ciphertext, secret, pubkey)
return secret.raw, ciphertext.raw
def sntrup761_decaps(ciphertext: bytes, privkey: bytes) -> bytes:
"""Decrypt an encrypted secret using a private key"""
if len(ciphertext) != sntrup761_ciphertext_bytes:
raise ValueError('Invalid SNTRUP761 ciphertext')
secret = ctypes.create_string_buffer(sntrup761_secret_bytes)
_sntrup761_decaps(secret, ciphertext, privkey)
return secret.raw

View File

@@ -0,0 +1,139 @@
# Copyright (c) 2016-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""UMAC cryptographic hash (RFC 4418) wrapper for Nettle library"""
import binascii
import ctypes
import ctypes.util
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
_ByteArray = ctypes.Array[ctypes.c_char]
_SetKey = Callable[[_ByteArray, bytes], None]
_SetNonce = Callable[[_ByteArray, ctypes.c_size_t, bytes], None]
_Update = Callable[[_ByteArray, ctypes.c_size_t, bytes], None]
_Digest = Callable[[_ByteArray, ctypes.c_size_t, _ByteArray], None]
_New = Callable[[bytes, Optional[bytes], Optional[bytes]], object]
_UMAC_BLOCK_SIZE = 1024
_UMAC_DEFAULT_CTX_SIZE = 4096
def _build_umac(size: int) -> '_New':
"""Function to build UMAC wrapper for a specific digest size"""
_name = 'umac%d' % size
_prefix = 'nettle_%s_' % _name
try:
_context_size: int = getattr(_nettle, _prefix + '_ctx_size')()
except AttributeError:
_context_size = _UMAC_DEFAULT_CTX_SIZE
_set_key: _SetKey = getattr(_nettle, _prefix + 'set_key')
_set_nonce: _SetNonce = getattr(_nettle, _prefix + 'set_nonce')
_update: _Update = getattr(_nettle, _prefix + 'update')
_digest: _Digest = getattr(_nettle, _prefix + 'digest')
class _UMAC:
"""Wrapper for UMAC cryptographic hash
This class supports the cryptographic hash API defined in PEP 452.
"""
name = _name
block_size = _UMAC_BLOCK_SIZE
digest_size = size // 8
def __init__(self, ctx: '_ByteArray', nonce: Optional[bytes] = None,
msg: Optional[bytes] = None):
self._ctx = ctx
if nonce:
self.set_nonce(nonce)
if msg:
self.update(msg)
@classmethod
def new(cls, key: bytes, msg: Optional[bytes] = None,
nonce: Optional[bytes] = None) -> '_UMAC':
"""Construct a new UMAC hash object"""
ctx = ctypes.create_string_buffer(_context_size)
_set_key(ctx, key)
return cls(ctx, nonce, msg)
def copy(self) -> '_UMAC':
"""Return a new hash object with this object's state"""
ctx = ctypes.create_string_buffer(self._ctx.raw)
return self.__class__(ctx)
def set_nonce(self, nonce: bytes) -> None:
"""Reset the nonce associated with this object"""
_set_nonce(self._ctx, ctypes.c_size_t(len(nonce)), nonce)
def update(self, msg: bytes) -> None:
"""Add the data in msg to the hash"""
_update(self._ctx, ctypes.c_size_t(len(msg)), msg)
def digest(self) -> bytes:
"""Return the hash and increment nonce to begin a new message
.. note:: The hash is reset and the nonce is incremented
when this function is called. This doesn't match
the behavior defined in PEP 452.
"""
result = ctypes.create_string_buffer(self.digest_size)
_digest(self._ctx, ctypes.c_size_t(self.digest_size), result)
return result.raw
def hexdigest(self) -> str:
"""Return the digest as a string of hexadecimal digits"""
return binascii.b2a_hex(self.digest()).decode('ascii')
return _UMAC.new
for lib in ('nettle', 'libnettle', 'libnettle-6'):
_nettle_lib = ctypes.util.find_library(lib)
if _nettle_lib: # pragma: no branch
break
else: # pragma: no cover
_nettle_lib = None
if _nettle_lib: # pragma: no branch
_nettle = ctypes.cdll.LoadLibrary(_nettle_lib)
umac32, umac64, umac96, umac128 = map(_build_umac, (32, 64, 96, 128))

View File

@@ -0,0 +1,417 @@
# Copyright (c) 2017-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""A shim around PyCA and PyOpenSSL for X.509 certificates"""
from datetime import datetime, timezone
import re
import sys
from typing import Iterable, List, Optional, Sequence, Set, Union, cast
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PublicFormat
from cryptography import x509
from OpenSSL import crypto
from ..asn1 import IA5String, der_decode, der_encode
from ..misc import ip_address
from .misc import PyCAKey, PyCAPrivateKey, PyCAPublicKey, hashes
_Comment = Union[None, bytes, str]
_Principals = Union[str, Sequence[str]]
_Purposes = Union[None, str, Sequence[str]]
_PurposeOIDs = Union[None, Set[x509.ObjectIdentifier]]
_GeneralNameList = List[x509.GeneralName]
_NameInit = Union[str, x509.Name, Iterable[x509.RelativeDistinguishedName]]
_purpose_to_oid = {
'serverAuth': x509.ExtendedKeyUsageOID.SERVER_AUTH,
'clientAuth': x509.ExtendedKeyUsageOID.CLIENT_AUTH,
'secureShellClient': x509.ObjectIdentifier('1.3.6.1.5.5.7.3.21'),
'secureShellServer': x509.ObjectIdentifier('1.3.6.1.5.5.7.3.22')}
_purpose_any = '2.5.29.37.0'
_nscomment_oid = x509.ObjectIdentifier('2.16.840.1.113730.1.13')
_datetime_min = datetime.fromtimestamp(0, timezone.utc).replace(microsecond=1)
_datetime_32bit_max = datetime.fromtimestamp(2**31 - 1, timezone.utc)
if sys.platform == 'win32': # pragma: no cover
# Windows' datetime.max is year 9999, but timestamps that large don't work
_datetime_max = datetime.max.replace(year=2999, tzinfo=timezone.utc)
else:
_datetime_max = datetime.max.replace(tzinfo=timezone.utc)
def _to_generalized_time(t: int) -> datetime:
"""Convert a timestamp value to a datetime"""
if t <= 0:
return _datetime_min
else:
try:
return datetime.fromtimestamp(t, timezone.utc)
except (OSError, OverflowError):
try:
# Work around a bug in cryptography which shows up on
# systems with a small time_t.
datetime.fromtimestamp(_datetime_max.timestamp() - 1,
timezone.utc)
return _datetime_max
except (OSError, OverflowError): # pragma: no cover
return _datetime_32bit_max
def _to_purpose_oids(purposes: _Purposes) -> _PurposeOIDs:
"""Convert a list of purposes to purpose OIDs"""
if isinstance(purposes, str):
purposes = [p.strip() for p in purposes.split(',')]
if not purposes or 'any' in purposes or _purpose_any in purposes:
purpose_oids = None
else:
purpose_oids = set(_purpose_to_oid.get(p) or x509.ObjectIdentifier(p)
for p in purposes)
return purpose_oids
def _encode_user_principals(principals: _Principals) -> _GeneralNameList:
"""Encode user principals as e-mail addresses"""
if isinstance(principals, str):
principals = [p.strip() for p in principals.split(',')]
return [x509.RFC822Name(name) for name in principals]
def _encode_host_principals(principals: _Principals) -> _GeneralNameList:
"""Encode host principals as DNS names or IP addresses"""
def _encode_host(name: str) -> x509.GeneralName:
"""Encode a host principal as a DNS name or IP address"""
try:
return x509.IPAddress(ip_address(name))
except ValueError:
return x509.DNSName(name)
if isinstance(principals, str):
principals = [p.strip() for p in principals.split(',')]
return [_encode_host(name) for name in principals]
class X509Name(x509.Name):
"""A shim around PyCA for X.509 distinguished names"""
_escape = re.compile(r'([,+\\])')
_unescape = re.compile(r'\\([,+\\])')
_split_rdn = re.compile(r'(?:[^+\\]+|\\.)+')
_split_name = re.compile(r'(?:[^,\\]+|\\.)+')
_attrs = (
('C', x509.NameOID.COUNTRY_NAME),
('ST', x509.NameOID.STATE_OR_PROVINCE_NAME),
('L', x509.NameOID.LOCALITY_NAME),
('O', x509.NameOID.ORGANIZATION_NAME),
('OU', x509.NameOID.ORGANIZATIONAL_UNIT_NAME),
('CN', x509.NameOID.COMMON_NAME),
('DC', x509.NameOID.DOMAIN_COMPONENT))
_to_oid = dict((k, v) for k, v in _attrs)
_from_oid = dict((v, k) for k, v in _attrs)
def __init__(self, name: _NameInit):
if isinstance(name, str):
rdns = self._parse_name(name)
elif isinstance(name, x509.Name):
rdns = name.rdns
else:
rdns = name
super().__init__(rdns)
def __str__(self) -> str:
return ','.join(self._format_rdn(rdn) for rdn in self.rdns)
def _format_rdn(self, rdn: x509.RelativeDistinguishedName) -> str:
"""Format an X.509 RelativeDistinguishedName as a string"""
return '+'.join(sorted(self._format_attr(nameattr) for nameattr in rdn))
def _format_attr(self, nameattr: x509.NameAttribute) -> str:
"""Format an X.509 NameAttribute as a string"""
attr = self._from_oid.get(nameattr.oid) or nameattr.oid.dotted_string
return attr + '=' + self._escape.sub(r'\\\1', cast(str, nameattr.value))
def _parse_name(self, name: str) -> \
Iterable[x509.RelativeDistinguishedName]:
"""Parse an X.509 distinguished name"""
return [self._parse_rdn(rdn) for rdn in self._split_name.findall(name)]
def _parse_rdn(self, rdn: str) -> x509.RelativeDistinguishedName:
"""Parse an X.509 relative distinguished name"""
return x509.RelativeDistinguishedName(
self._parse_nameattr(av) for av in self._split_rdn.findall(rdn))
def _parse_nameattr(self, av: str) -> x509.NameAttribute:
"""Parse an X.509 name attribute/value pair"""
try:
attr, value = av.split('=', 1)
except ValueError:
raise ValueError('Invalid X.509 name attribute: ' + av) from None
try:
attr = attr.strip()
oid = self._to_oid.get(attr) or x509.ObjectIdentifier(attr)
except ValueError:
raise ValueError('Unknown X.509 attribute: ' + attr) from None
return x509.NameAttribute(oid, self._unescape.sub(r'\1', value))
class X509NamePattern:
"""Match X.509 distinguished names"""
def __init__(self, pattern: str):
if pattern.endswith(',*'):
self._pattern = X509Name(pattern[:-2])
self._prefix_len: Optional[int] = len(self._pattern.rdns)
else:
self._pattern = X509Name(pattern)
self._prefix_len = None
def __eq__(self, other: object) -> bool:
# This isn't protected access - both objects are _RSAKey instances
# pylint: disable=protected-access
if not isinstance(other, X509NamePattern): # pragma: no cover
return NotImplemented
return (self._pattern == other._pattern and
self._prefix_len == other._prefix_len)
def __hash__(self) -> int:
return hash((self._pattern, self._prefix_len))
def matches(self, name: X509Name) -> bool:
"""Return whether an X.509 name matches this pattern"""
return self._pattern.rdns == name.rdns[:self._prefix_len]
class X509Certificate:
"""A shim around PyCA and PyOpenSSL for X.509 certificates"""
def __init__(self, cert: x509.Certificate, data: bytes):
self.data = data
self.subject = X509Name(cert.subject)
self.issuer = X509Name(cert.issuer)
self.key_data = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
self.openssl_cert = crypto.X509.from_cryptography(cert)
self.subject_hash = hex(self.openssl_cert.get_subject().hash())[2:]
self.issuer_hash = hex(self.openssl_cert.get_issuer().hash())[2:]
try:
self.purposes: Optional[Set[bytes]] = \
set(cert.extensions.get_extension_for_class(
x509.ExtendedKeyUsage).value)
except x509.ExtensionNotFound:
self.purposes = None
try:
sans = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName).value
self.user_principals = sans.get_values_for_type(x509.RFC822Name)
self.host_principals = sans.get_values_for_type(x509.DNSName) + \
[str(ip) for ip in sans.get_values_for_type(x509.IPAddress)]
except x509.ExtensionNotFound:
cn = cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
principals = [cast(str, attr.value) for attr in cn]
self.user_principals = principals
self.host_principals = principals
try:
comment = cert.extensions.get_extension_for_oid(_nscomment_oid)
comment_der = cast(x509.UnrecognizedExtension, comment.value).value
self.comment: Optional[bytes] = \
cast(IA5String, der_decode(comment_der)).value
except x509.ExtensionNotFound:
self.comment = None
def __eq__(self, other: object) -> bool:
if not isinstance(other, X509Certificate): # pragma: no cover
return NotImplemented
return self.data == other.data
def __hash__(self) -> int:
return hash(self.data)
def validate(self, trust_store: Sequence['X509Certificate'],
purposes: _Purposes, user_principal: str,
host_principal: str) -> None:
"""Validate an X.509 certificate"""
purpose_oids = _to_purpose_oids(purposes)
if purpose_oids and self.purposes and not purpose_oids & self.purposes:
raise ValueError('Certificate purpose mismatch')
if user_principal and user_principal not in self.user_principals:
raise ValueError('Certificate user principal mismatch')
if host_principal and host_principal not in self.host_principals:
raise ValueError('Certificate host principal mismatch')
x509_store = crypto.X509Store()
for c in trust_store:
x509_store.add_cert(c.openssl_cert)
try:
x509_ctx = crypto.X509StoreContext(x509_store, self.openssl_cert,
None)
x509_ctx.verify_certificate()
except crypto.X509StoreContextError as exc:
raise ValueError(f'X.509 chain validation error: {exc}') from None
def generate_x509_certificate(signing_key: PyCAKey, key: PyCAKey,
subject: _NameInit, issuer: Optional[_NameInit],
serial: Optional[int], valid_after: int,
valid_before: int, ca: bool,
ca_path_len: Optional[int], purposes: _Purposes,
user_principals: _Principals,
host_principals: _Principals,
hash_name: str,
comment: _Comment) -> X509Certificate:
"""Generate a new X.509 certificate"""
builder = x509.CertificateBuilder()
subject = X509Name(subject)
issuer = X509Name(issuer) if issuer else subject
self_signed = subject == issuer
builder = builder.subject_name(subject)
builder = builder.issuer_name(issuer)
if serial is None:
serial = x509.random_serial_number()
builder = builder.serial_number(serial)
builder = builder.not_valid_before(_to_generalized_time(valid_after))
builder = builder.not_valid_after(_to_generalized_time(valid_before))
builder = builder.public_key(cast(PyCAPublicKey, key))
if ca:
basic_constraints = x509.BasicConstraints(ca=True,
path_length=ca_path_len)
key_usage = x509.KeyUsage(digital_signature=False,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False, key_cert_sign=True,
crl_sign=True, encipher_only=False,
decipher_only=False)
else:
basic_constraints = x509.BasicConstraints(ca=False, path_length=None)
key_usage = x509.KeyUsage(digital_signature=True,
content_commitment=False,
key_encipherment=True,
data_encipherment=False,
key_agreement=True, key_cert_sign=False,
crl_sign=False, encipher_only=False,
decipher_only=False)
builder = builder.add_extension(basic_constraints, critical=True)
if ca or not self_signed:
builder = builder.add_extension(key_usage, critical=True)
purpose_oids = _to_purpose_oids(purposes)
if purpose_oids:
builder = builder.add_extension(x509.ExtendedKeyUsage(purpose_oids),
critical=False)
skid = x509.SubjectKeyIdentifier.from_public_key(cast(PyCAPublicKey, key))
builder = builder.add_extension(skid, critical=False)
if not self_signed:
issuer_pk = cast(PyCAPrivateKey, signing_key).public_key()
akid = x509.AuthorityKeyIdentifier.from_issuer_public_key(issuer_pk)
builder = builder.add_extension(akid, critical=False)
sans = _encode_user_principals(user_principals) + \
_encode_host_principals(host_principals)
if sans:
builder = builder.add_extension(x509.SubjectAlternativeName(sans),
critical=False)
if comment:
if isinstance(comment, str):
comment_bytes = comment.encode('utf-8')
else:
comment_bytes = comment
comment_bytes = der_encode(IA5String(comment_bytes))
builder = builder.add_extension(
x509.UnrecognizedExtension(_nscomment_oid, comment_bytes),
critical=False)
try:
hash_alg = hashes[hash_name]() if hash_name else None
except KeyError:
raise ValueError('Unknown hash algorithm') from None
cert = builder.sign(cast(PyCAPrivateKey, signing_key), hash_alg)
data = cert.public_bytes(Encoding.DER)
return X509Certificate(cert, data)
def import_x509_certificate(data: bytes) -> X509Certificate:
"""Construct an X.509 certificate from DER data"""
cert = x509.load_der_x509_certificate(data)
return X509Certificate(cert, data)

View File

@@ -0,0 +1,258 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""DSA public key encryption handler"""
from typing import Optional, Tuple, Union, cast
from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode
from .crypto import DSAPrivateKey, DSAPublicKey
from .misc import all_ints
from .packet import MPInt, String, SSHPacket
from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError
from .public_key import register_public_key_alg, register_certificate_alg
from .public_key import register_x509_certificate_alg
_PrivateKeyArgs = Tuple[int, int, int, int, int]
_PublicKeyArgs = Tuple[int, int, int, int]
class _DSAKey(SSHKey):
"""Handler for DSA public key encryption"""
_key: Union[DSAPrivateKey, DSAPublicKey]
algorithm = b'ssh-dss'
default_x509_hash = 'sha256'
pem_name = b'DSA'
pkcs8_oid = ObjectIdentifier('1.2.840.10040.4.1')
sig_algorithms = (algorithm,)
x509_algorithms = (b'x509v3-' + algorithm,)
all_sig_algorithms = set(sig_algorithms)
def __eq__(self, other: object) -> bool:
# This isn't protected access - both objects are _DSAKey instances
# pylint: disable=protected-access
return (isinstance(other, type(self)) and
self._key.p == other._key.p and
self._key.q == other._key.q and
self._key.g == other._key.g and
self._key.y == other._key.y and
self._key.x == other._key.x)
def __hash__(self) -> int:
return hash((self._key.p, self._key.q, self._key.g,
self._key.y, self._key.x))
@classmethod
def generate(cls, algorithm: bytes) -> '_DSAKey': # type: ignore
"""Generate a new DSA private key"""
# pylint: disable=arguments-differ,unused-argument
return cls(DSAPrivateKey.generate(key_size=1024))
@classmethod
def make_private(cls, key_params: object) -> SSHKey:
"""Construct a DSA private key"""
p, q, g, y, x = cast(_PrivateKeyArgs, key_params)
return cls(DSAPrivateKey.construct(p, q, g, y, x))
@classmethod
def make_public(cls, key_params: object) -> SSHKey:
"""Construct a DSA public key"""
p, q, g, y = cast(_PublicKeyArgs, key_params)
return cls(DSAPublicKey.construct(p, q, g, y))
@classmethod
def decode_pkcs1_private(cls, key_data: object) -> \
Optional[_PrivateKeyArgs]:
"""Decode a PKCS#1 format DSA private key"""
if (isinstance(key_data, tuple) and len(key_data) == 6 and
all_ints(key_data) and key_data[0] == 0):
return cast(_PrivateKeyArgs, key_data[1:])
else:
return None
@classmethod
def decode_pkcs1_public(cls, key_data: object) -> \
Optional[_PublicKeyArgs]:
"""Decode a PKCS#1 format DSA public key"""
if (isinstance(key_data, tuple) and len(key_data) == 4 and
all_ints(key_data)):
y, p, q, g = key_data
return p, q, g, y
else:
return None
@classmethod
def decode_pkcs8_private(cls, alg_params: object,
data: bytes) -> Optional[_PrivateKeyArgs]:
"""Decode a PKCS#8 format DSA private key"""
try:
x = der_decode(data)
except ASN1DecodeError:
return None
if (isinstance(alg_params, tuple) and len(alg_params) == 3 and
all_ints(alg_params) and isinstance(x, int)):
p, q, g = alg_params
y: int = pow(g, x, p)
return p, q, g, y, x
else:
return None
@classmethod
def decode_pkcs8_public(cls, alg_params: object,
data: bytes) -> Optional[_PublicKeyArgs]:
"""Decode a PKCS#8 format DSA public key"""
try:
y = der_decode(data)
except ASN1DecodeError:
return None
if (isinstance(alg_params, tuple) and len(alg_params) == 3 and
all_ints(alg_params) and isinstance(y, int)):
p, q, g = alg_params
return p, q, g, y
else:
return None
@classmethod
def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs:
"""Decode an SSH format DSA private key"""
p = packet.get_mpint()
q = packet.get_mpint()
g = packet.get_mpint()
y = packet.get_mpint()
x = packet.get_mpint()
return p, q, g, y, x
@classmethod
def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs:
"""Decode an SSH format DSA public key"""
p = packet.get_mpint()
q = packet.get_mpint()
g = packet.get_mpint()
y = packet.get_mpint()
return p, q, g, y
def encode_pkcs1_private(self) -> object:
"""Encode a PKCS#1 format DSA private key"""
if not self._key.x:
raise KeyExportError('Key is not private')
return (0, self._key.p, self._key.q, self._key.g,
self._key.y, self._key.x)
def encode_pkcs1_public(self) -> object:
"""Encode a PKCS#1 format DSA public key"""
return (self._key.y, self._key.p, self._key.q, self._key.g)
def encode_pkcs8_private(self) -> Tuple[object, object]:
"""Encode a PKCS#8 format DSA private key"""
if not self._key.x:
raise KeyExportError('Key is not private')
return (self._key.p, self._key.q, self._key.g), der_encode(self._key.x)
def encode_pkcs8_public(self) -> Tuple[object, object]:
"""Encode a PKCS#8 format DSA public key"""
return (self._key.p, self._key.q, self._key.g), der_encode(self._key.y)
def encode_ssh_private(self) -> bytes:
"""Encode an SSH format DSA private key"""
if not self._key.x:
raise KeyExportError('Key is not private')
return b''.join((MPInt(self._key.p), MPInt(self._key.q),
MPInt(self._key.g), MPInt(self._key.y),
MPInt(self._key.x)))
def encode_ssh_public(self) -> bytes:
"""Encode an SSH format DSA public key"""
return b''.join((MPInt(self._key.p), MPInt(self._key.q),
MPInt(self._key.g), MPInt(self._key.y)))
def encode_agent_cert_private(self) -> bytes:
"""Encode DSA certificate private key data for agent"""
if not self._key.x:
raise KeyExportError('Key is not private')
return MPInt(self._key.x)
def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes:
"""Compute an SSH-encoded signature of the specified data"""
# pylint: disable=unused-argument
if not self._key.x:
raise ValueError('Private key needed for signing')
sig = der_decode(self._key.sign(data, 'sha1'))
r, s = cast(Tuple[int, int], sig)
return String(r.to_bytes(20, 'big') + s.to_bytes(20, 'big'))
def verify_ssh(self, data: bytes, sig_algorithm: bytes,
packet: SSHPacket) -> bool:
"""Verify an SSH-encoded signature of the specified data"""
# pylint: disable=unused-argument
sig = packet.get_string()
packet.check_end()
if len(sig) != 40:
return False
r = int.from_bytes(sig[:20], 'big')
s = int.from_bytes(sig[20:], 'big')
return self._key.verify(data, der_encode((r, s)), 'sha1')
register_public_key_alg(b'ssh-dss', _DSAKey, False)
register_certificate_alg(1, b'ssh-dss', b'ssh-dss-cert-v01@openssh.com',
_DSAKey, SSHOpenSSHCertificateV01, False)
for alg in _DSAKey.x509_algorithms:
register_x509_certificate_alg(alg, False)

View File

@@ -0,0 +1,341 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""ECDSA public key encryption handler"""
from typing import Dict, Optional, Tuple, Union, cast
from .asn1 import ASN1DecodeError, BitString, ObjectIdentifier, TaggedDERObject
from .asn1 import der_encode, der_decode
from .crypto import CryptoKey, ECDSAPrivateKey, ECDSAPublicKey
from .crypto import lookup_ec_curve_by_params
from .packet import MPInt, String, SSHPacket
from .public_key import SSHKey, SSHOpenSSHCertificateV01
from .public_key import KeyImportError, KeyExportError
from .public_key import register_public_key_alg, register_certificate_alg
from .public_key import register_x509_certificate_alg
_PrivateKeyArgs = Tuple[bytes, Union[bytes, int], bytes]
_PublicKeyArgs = Tuple[bytes, bytes]
# OID for EC prime fields
PRIME_FIELD = ObjectIdentifier('1.2.840.10045.1.1')
_hash_algs = {b'1.3.132.0.10': 'sha256',
b'nistp256': 'sha256',
b'nistp384': 'sha384',
b'nistp521': 'sha512'}
_alg_oids: Dict[bytes, ObjectIdentifier] = {}
_alg_oid_map: Dict[ObjectIdentifier, bytes] = {}
class _ECKey(SSHKey):
"""Handler for elliptic curve public key encryption"""
_key: Union[ECDSAPrivateKey, ECDSAPublicKey]
default_x509_hash = 'sha256'
pem_name = b'EC'
pkcs8_oid = ObjectIdentifier('1.2.840.10045.2.1')
def __init__(self, key: CryptoKey):
super().__init__(key)
self.algorithm = b'ecdsa-sha2-' + self._key.curve_id
self.sig_algorithms = (self.algorithm,)
self.x509_algorithms = (b'x509v3-' + self.algorithm,)
self.all_sig_algorithms = set(self.sig_algorithms)
self._alg_oid = _alg_oids[self._key.curve_id]
self._hash_alg = _hash_algs[self._key.curve_id]
def __eq__(self, other: object) -> bool:
# This isn't protected access - both objects are _ECKey instances
# pylint: disable=protected-access
return (isinstance(other, type(self)) and
self._key.curve_id == other._key.curve_id and
self._key.x == other._key.x and
self._key.y == other._key.y and
self._key.d == other._key.d)
def __hash__(self) -> int:
return hash((self._key.curve_id, self._key.x,
self._key.y, self._key.d))
@classmethod
def _lookup_curve(cls, alg_params: object) -> bytes:
"""Look up an EC curve matching the specified parameters"""
if isinstance(alg_params, ObjectIdentifier):
try:
curve_id = _alg_oid_map[alg_params]
except KeyError:
raise KeyImportError('Unknown elliptic curve OID %s' %
alg_params) from None
elif (isinstance(alg_params, tuple) and len(alg_params) >= 5 and
alg_params[0] == 1 and isinstance(alg_params[1], tuple) and
len(alg_params[1]) == 2 and alg_params[1][0] == PRIME_FIELD and
isinstance(alg_params[2], tuple) and len(alg_params[2]) >= 2 and
isinstance(alg_params[3], bytes) and
isinstance(alg_params[2][0], bytes) and
isinstance(alg_params[2][1], bytes) and
isinstance(alg_params[4], int)):
p = alg_params[1][1]
a = int.from_bytes(alg_params[2][0], 'big')
b = int.from_bytes(alg_params[2][1], 'big')
point = alg_params[3]
n = alg_params[4]
try:
curve_id = lookup_ec_curve_by_params(p, a, b, point, n)
except ValueError as exc:
raise KeyImportError(str(exc)) from None
else:
raise KeyImportError('Invalid EC curve parameters')
return curve_id
@classmethod
def generate(cls, algorithm: bytes) -> '_ECKey': # type: ignore
"""Generate a new EC private key"""
# pylint: disable=arguments-differ
# Strip 'ecdsa-sha2-' prefix of algorithm to get curve_id
return cls(ECDSAPrivateKey.generate(algorithm[11:]))
@classmethod
def make_private(cls, key_params: object) -> SSHKey:
"""Construct an EC private key"""
curve_id, private_value, public_value = \
cast(_PrivateKeyArgs, key_params)
if isinstance(private_value, bytes):
private_value = int.from_bytes(private_value, 'big')
return cls(ECDSAPrivateKey.construct(curve_id, public_value,
private_value))
@classmethod
def make_public(cls, key_params: object) -> SSHKey:
"""Construct an EC public key"""
curve_id, public_value = cast(_PublicKeyArgs, key_params)
return cls(ECDSAPublicKey.construct(curve_id, public_value))
@classmethod
def decode_pkcs1_private(cls, key_data: object) -> \
Optional[_PrivateKeyArgs]:
"""Decode a PKCS#1 format EC private key"""
if (isinstance(key_data, tuple) and len(key_data) > 2 and
key_data[0] == 1 and isinstance(key_data[1], bytes) and
isinstance(key_data[2], TaggedDERObject) and
key_data[2].tag == 0):
alg_params = key_data[2].value
private_key = key_data[1]
if (len(key_data) > 3 and
isinstance(key_data[3], TaggedDERObject) and
key_data[3].tag == 1 and
isinstance(key_data[3].value, BitString) and
key_data[3].value.unused == 0):
public_key: bytes = key_data[3].value.value
else:
public_key = b''
return cls._lookup_curve(alg_params), private_key, public_key
else:
return None
@classmethod
def decode_pkcs1_public(cls, key_data: object) -> \
Optional[_PublicKeyArgs]:
"""Decode a PKCS#1 format EC public key"""
# pylint: disable=unused-argument
raise KeyImportError('PKCS#1 not supported for EC public keys')
@classmethod
def decode_pkcs8_private(cls, alg_params: object,
data: bytes) -> Optional[_PrivateKeyArgs]:
"""Decode a PKCS#8 format EC private key"""
try:
key_data = der_decode(data)
except ASN1DecodeError:
key_data = None
if (isinstance(key_data, tuple) and len(key_data) > 1 and
key_data[0] == 1 and isinstance(key_data[1], bytes)):
private_key = key_data[1]
if (len(key_data) > 2 and
isinstance(key_data[2], TaggedDERObject) and
key_data[2].tag == 1 and
isinstance(key_data[2].value, BitString) and
key_data[2].value.unused == 0):
public_key = key_data[2].value.value
else:
public_key = b''
return cls._lookup_curve(alg_params), private_key, public_key
else:
return None
@classmethod
def decode_pkcs8_public(cls, alg_params: object,
data: bytes) -> Optional[_PublicKeyArgs]:
"""Decode a PKCS#8 format EC public key"""
if isinstance(alg_params, ObjectIdentifier):
return cls._lookup_curve(alg_params), data
else:
return None
@classmethod
def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs:
"""Decode an SSH format EC private key"""
curve_id = packet.get_string()
public_key = packet.get_string()
private_key = packet.get_mpint()
return curve_id, private_key, public_key
@classmethod
def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs:
"""Decode an SSH format EC public key"""
curve_id = packet.get_string()
public_key = packet.get_string()
return curve_id, public_key
def encode_public_tagged(self) -> object:
"""Encode an EC public key blob as a tagged bitstring"""
return TaggedDERObject(1, BitString(self._key.public_value))
def encode_pkcs1_private(self) -> object:
"""Encode a PKCS#1 format EC private key"""
if not self._key.private_value:
raise KeyExportError('Key is not private')
return (1, self._key.private_value,
TaggedDERObject(0, self._alg_oid),
self.encode_public_tagged())
def encode_pkcs1_public(self) -> object:
"""Encode a PKCS#1 format EC public key"""
raise KeyExportError('PKCS#1 is not supported for EC public keys')
def encode_pkcs8_private(self) -> Tuple[object, object]:
"""Encode a PKCS#8 format EC private key"""
if not self._key.private_value:
raise KeyExportError('Key is not private')
return self._alg_oid, der_encode((1, self._key.private_value,
self.encode_public_tagged()))
def encode_pkcs8_public(self) -> Tuple[object, object]:
"""Encode a PKCS#8 format EC public key"""
return self._alg_oid, self._key.public_value
def encode_ssh_private(self) -> bytes:
"""Encode an SSH format EC private key"""
if not self._key.d:
raise KeyExportError('Key is not private')
return b''.join((String(self._key.curve_id),
String(self._key.public_value),
MPInt(self._key.d)))
def encode_ssh_public(self) -> bytes:
"""Encode an SSH format EC public key"""
return b''.join((String(self._key.curve_id),
String(self._key.public_value)))
def encode_agent_cert_private(self) -> bytes:
"""Encode ECDSA certificate private key data for agent"""
if not self._key.d:
raise KeyExportError('Key is not private')
return MPInt(self._key.d)
def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes:
"""Compute an SSH-encoded signature of the specified data"""
# pylint: disable=unused-argument
if not self._key.private_value:
raise ValueError('Private key needed for signing')
sig = der_decode(self._key.sign(data, self._hash_alg))
r, s = cast(Tuple[int, int], sig)
return String(MPInt(r) + MPInt(s))
def verify_ssh(self, data: bytes, sig_algorithm: bytes,
packet: SSHPacket) -> bool:
"""Verify an SSH-encoded signature of the specified data"""
# pylint: disable=unused-argument
sig = packet.get_string()
packet.check_end()
packet = SSHPacket(sig)
r = packet.get_mpint()
s = packet.get_mpint()
packet.check_end()
return self._key.verify(data, der_encode((r, s)), self._hash_alg)
for _curve_id, _oid_str in ((b'nistp521', '1.3.132.0.35'),
(b'nistp384', '1.3.132.0.34'),
(b'nistp256', '1.2.840.10045.3.1.7'),
(b'1.3.132.0.10', '1.3.132.0.10')):
_algorithm = b'ecdsa-sha2-' + _curve_id
_cert_algorithm = _algorithm + b'-cert-v01@openssh.com'
_x509_algorithm = b'x509v3-' + _algorithm
_oid = ObjectIdentifier(_oid_str)
_alg_oids[_curve_id] = _oid
_alg_oid_map[_oid] = _curve_id
register_public_key_alg(_algorithm, _ECKey, True, (_algorithm,))
register_certificate_alg(1, _algorithm, _cert_algorithm,
_ECKey, SSHOpenSSHCertificateV01, True)
register_x509_certificate_alg(_x509_algorithm, True)

View File

@@ -0,0 +1,220 @@
# Copyright (c) 2019-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""EdDSA public key encryption handler"""
from typing import Optional, Tuple, Union, cast
from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode
from .crypto import EdDSAPrivateKey, EdDSAPublicKey
from .crypto import ed25519_available, ed448_available
from .packet import String, SSHPacket
from .public_key import OMIT, SSHKey, SSHOpenSSHCertificateV01
from .public_key import KeyImportError, KeyExportError
from .public_key import register_public_key_alg, register_certificate_alg
from .public_key import register_x509_certificate_alg
_PrivateKeyArgs = Tuple[bytes]
_PublicKeyArgs = Tuple[bytes]
class _EdKey(SSHKey):
"""Handler for EdDSA public key encryption"""
_key: Union[EdDSAPrivateKey, EdDSAPublicKey]
algorithm = b''
def __eq__(self, other: object) -> bool:
# This isn't protected access - both objects are _EdKey instances
# pylint: disable=protected-access
return (isinstance(other, type(self)) and
self._key.public_value == other._key.public_value and
self._key.private_value == other._key.private_value)
def __hash__(self) -> int:
return hash((self._key.public_value, self._key.private_value))
@classmethod
def generate(cls, algorithm: bytes) -> '_EdKey': # type: ignore
"""Generate a new EdDSA private key"""
# pylint: disable=arguments-differ
# Strip 'ssh-' prefix of algorithm to get curve_id
return cls(EdDSAPrivateKey.generate(algorithm[4:]))
@classmethod
def make_private(cls, key_params: object) -> SSHKey:
"""Construct an EdDSA private key"""
try:
private_value, = cast(_PrivateKeyArgs, key_params)
return cls(EdDSAPrivateKey.construct(cls.algorithm[4:],
private_value))
except (TypeError, ValueError):
raise KeyImportError('Invalid EdDSA private key') from None
@classmethod
def make_public(cls, key_params: object) -> SSHKey:
"""Construct an EdDSA public key"""
try:
public_value, = cast(_PublicKeyArgs, key_params)
return cls(EdDSAPublicKey.construct(cls.algorithm[4:],
public_value))
except (TypeError, ValueError):
raise KeyImportError('Invalid EdDSA public key') from None
@classmethod
def decode_pkcs8_private(cls, alg_params: object,
data: bytes) -> Optional[_PrivateKeyArgs]:
"""Decode a PKCS#8 format EdDSA private key"""
# pylint: disable=unused-argument
try:
return (cast(bytes, der_decode(data)),)
except ASN1DecodeError:
return None
@classmethod
def decode_pkcs8_public(cls, alg_params: object,
data: bytes) -> Optional[_PublicKeyArgs]:
"""Decode a PKCS#8 format EdDSA public key"""
# pylint: disable=unused-argument
return (data,)
@classmethod
def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs:
"""Decode an SSH format EdDSA private key"""
public_value = packet.get_string()
private_value = packet.get_string()
return (private_value[:-len(public_value)],)
@classmethod
def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs:
"""Decode an SSH format EdDSA public key"""
public_value = packet.get_string()
return (public_value,)
def encode_pkcs8_private(self) -> Tuple[object, object]:
"""Encode a PKCS#8 format EdDSA private key"""
if not self._key.private_value:
raise KeyExportError('Key is not private')
return OMIT, der_encode(self._key.private_value)
def encode_pkcs8_public(self) -> Tuple[object, object]:
"""Encode a PKCS#8 format EdDSA public key"""
return OMIT, self._key.public_value
def encode_ssh_private(self) -> bytes:
"""Encode an SSH format EdDSA private key"""
if self._key.private_value is None:
raise KeyExportError('Key is not private')
return b''.join((String(self._key.public_value),
String(self._key.private_value +
self._key.public_value)))
def encode_ssh_public(self) -> bytes:
"""Encode an SSH format EdDSA public key"""
return String(self._key.public_value)
def encode_agent_cert_private(self) -> bytes:
"""Encode EdDSA certificate private key data for agent"""
return self.encode_ssh_private()
def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes:
"""Compute an SSH-encoded signature of the specified data"""
# pylint: disable=unused-argument
if not self._key.private_value:
raise ValueError('Private key needed for signing')
return String(self._key.sign(data))
def verify_ssh(self, data: bytes, sig_algorithm: bytes,
packet: SSHPacket) -> bool:
"""Verify an SSH-encoded signature of the specified data"""
# pylint: disable=unused-argument
sig = packet.get_string()
packet.check_end()
return self._key.verify(data, sig)
class _Ed25519Key(_EdKey):
"""Handler for Curve25519 public key encryption"""
algorithm = b'ssh-ed25519'
pkcs8_oid = ObjectIdentifier('1.3.101.112')
sig_algorithms = (algorithm,)
x509_algorithms = (b'x509v3-' + algorithm,)
all_sig_algorithms = set(sig_algorithms)
class _Ed448Key(_EdKey):
"""Handler for Curve448 public key encryption"""
algorithm = b'ssh-ed448'
pkcs8_oid = ObjectIdentifier('1.3.101.113')
sig_algorithms = (algorithm,)
x509_algorithms = (b'x509v3-' + algorithm,)
all_sig_algorithms = set(sig_algorithms)
if ed25519_available: # pragma: no branch
register_public_key_alg(b'ssh-ed25519', _Ed25519Key, True)
register_certificate_alg(1, b'ssh-ed25519',
b'ssh-ed25519-cert-v01@openssh.com',
_Ed25519Key, SSHOpenSSHCertificateV01, True)
for alg in _Ed25519Key.x509_algorithms:
register_x509_certificate_alg(alg, True)
if ed448_available: # pragma: no branch
register_public_key_alg(b'ssh-ed448', _Ed448Key, True)
register_certificate_alg(1, b'ssh-ed448', b'ssh-ed448-cert-v01@openssh.com',
_Ed448Key, SSHOpenSSHCertificateV01, True)
for alg in _Ed448Key.x509_algorithms:
register_x509_certificate_alg(alg, True)

View File

@@ -0,0 +1,959 @@
# Copyright (c) 2016-2022 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Input line editor"""
import re
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List
from typing import Optional, Set, Tuple, Union, cast
from unicodedata import east_asian_width
from .session import DataType
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .channel import SSHServerChannel
from .session import SSHServerSession
_CharDict = Dict[str, object]
_CharHandler = Callable[['SSHLineEditor'], None]
_KeyHandler = Callable[[str, int], Union[bool, Tuple[str, int]]]
_DEFAULT_WIDTH = 80
_ansi_terminals = ('ansi', 'cygwin', 'linux', 'putty', 'screen', 'teraterm',
'cit80', 'vt100', 'vt102', 'vt220', 'vt320', 'xterm',
'xterm-color', 'xterm-16color', 'xterm-256color', 'rxvt',
'rxvt-color')
def _is_wide(ch: str) -> bool:
"""Return display width of character"""
return east_asian_width(ch) in 'WF'
class SSHLineEditor:
"""Input line editor"""
def __init__(self, chan: 'SSHServerChannel[str]',
session: 'SSHServerSession[str]', line_echo: bool,
history_size: int, max_line_length: int, term_type: str,
width: int):
self._chan = chan
self._session = session
self._line_echo = line_echo
self._line_pending = False
self._history_size = history_size if history_size > 0 else 0
self._max_line_length = max_line_length
self._wrap = term_type in _ansi_terminals
self._width = width or _DEFAULT_WIDTH
self._line_mode = True
self._echo = True
self._start_column = 0
self._end_column = 0
self._cursor = 0
self._left_pos = 0
self._right_pos = 0
self._pos = 0
self._line = ''
self._bell_rung = False
self._early_wrap: Set[int] = set()
self._outbuf: List[str] = []
self._keymap: _CharDict = {}
self._key_state = self._keymap
self._erased = ''
self._history: List[str] = []
self._history_index = 0
for func, keys in self._keylist:
for key in keys:
self._add_key(key, func)
self._build_printable()
def _add_key(self, key: str, func: _CharHandler) -> None:
"""Add a key to the keymap"""
keymap = self._keymap
for ch in key[:-1]:
if ch not in keymap:
keymap[ch] = {}
keymap = cast(_CharDict, keymap[ch])
keymap[key[-1]] = func
def _del_key(self, key: str) -> None:
"""Delete a key from the keymap"""
keymap = self._keymap
for ch in key[:-1]:
if ch not in keymap:
return
keymap = cast(_CharDict, keymap[ch])
keymap.pop(key[-1], None)
def _build_printable(self) -> None:
"""Build a regex of printable ASCII non-registered keys"""
def _escape(c: int) -> str:
"""Backslash escape special characters in regex character range"""
ch = chr(c)
return ('\\' if (ch in '-&|[]\\^~') else '') + ch
def _is_printable(ch: str) -> bool:
"""Return if character is printable and has no handler"""
return ch.isprintable() and ch not in keys
pat: List[str] = []
keys = self._keymap.keys()
start = ord(' ')
limit = 0x10000
while start < limit:
while start < limit and not _is_printable(chr(start)):
start += 1
end = start
while _is_printable(chr(end)):
end += 1
pat.append(_escape(start))
if start != end - 1:
pat.append('-' + _escape(end - 1))
start = end + 1
self._printable = re.compile('[' + ''.join(pat) + ']*')
def _char_width(self, pos: int) -> int:
"""Return width of character at specified position"""
return 1 + _is_wide(self._line[pos]) + ((pos + 1) in self._early_wrap)
def _determine_column(self, data: str, column: int,
pos: Optional[int] = None) -> Tuple[str, int]:
"""Determine new output column after output occurs"""
escaped = False
offset = pos
last_wrap_pos = pos
wrapped_data = []
for ch in data:
if ch == '\b':
column -= 1
elif ch == '\x1b':
escaped = True
elif escaped:
if ch == 'm':
escaped = False
else:
if _is_wide(ch) and (column % self._width) == self._width - 1:
column += 1
if pos is not None:
assert last_wrap_pos is not None
assert offset is not None
wrapped_data.append(data[last_wrap_pos - offset:
pos - offset])
last_wrap_pos = pos
self._early_wrap.add(pos)
else:
if pos is not None:
self._early_wrap.discard(pos)
column += 1 + _is_wide(ch)
if pos is not None:
pos += 1
if pos is not None:
assert last_wrap_pos is not None
assert offset is not None
wrapped_data.append(data[last_wrap_pos - offset:])
return ' '.join(wrapped_data), column
else:
return data, column
def _output(self, data: str, pos: Optional[int] = None) -> None:
"""Generate output and calculate new output column"""
idx = data.rfind('\n')
if idx >= 0:
self._outbuf.append(data[:idx+1])
tail = data[idx+1:]
self._cursor = 0
else:
tail = data
data, self._cursor = self._determine_column(tail, self._cursor, pos)
self._outbuf.append(data)
if self._cursor and self._cursor % self._width == 0:
self._outbuf.append(' \b')
def _ring_bell(self) -> None:
"""Ring the terminal bell"""
if not self._bell_rung:
self._outbuf.append('\a')
self._bell_rung = True
def _update_input_window(self, new_pos: int) -> int:
"""Update visible input window when not wrapping onto multiple lines"""
line_len = len(self._line)
if new_pos < self._left_pos:
self._left_pos = new_pos
else:
if new_pos < line_len:
new_pos += 1
pos = self._pos
column = self._cursor
while pos < new_pos:
column += self._char_width(pos)
pos += 1
if column >= self._width:
while column >= self._width:
column -= self._char_width(self._left_pos)
self._left_pos += 1
else:
while self._left_pos > 0:
column += self._char_width(self._left_pos)
if column < self._width:
self._left_pos -= 1
else:
break
column = self._start_column
self._right_pos = self._left_pos
while self._right_pos < line_len:
ch_width = self._char_width(self._right_pos)
if column + ch_width < self._width:
self._right_pos += 1
column += ch_width
else:
break
return column
def _move_cursor(self, column: int) -> None:
"""Move the cursor to selected position in input line"""
start_row = self._cursor // self._width
start_col = self._cursor % self._width
end_row = column // self._width
end_col = column % self._width
if end_row < start_row:
self._outbuf.append('\x1b[' + str(start_row-end_row) + 'A')
elif end_row > start_row:
self._outbuf.append('\x1b[' + str(end_row-start_row) + 'B')
if end_col > start_col:
self._outbuf.append('\x1b[' + str(end_col-start_col) + 'C')
elif end_col < start_col:
self._outbuf.append('\x1b[' + str(start_col-end_col) + 'D')
self._cursor = column
def _move_back(self, column: int) -> None:
"""Move the cursor backward to selected position in input line"""
if self._wrap:
self._move_cursor(column)
else:
self._outbuf.append('\b' * (self._cursor - column))
self._cursor = column
def _clear_to_end(self) -> None:
"""Clear any remaining characters from previous input line"""
column = self._cursor
remaining = self._end_column - column
if remaining > 0:
self._outbuf.append(' ' * remaining)
self._cursor = self._end_column
if self._cursor % self._width == 0:
self._outbuf.append(' \b')
self._move_back(column)
self._end_column = column
def _erase_input(self) -> None:
"""Erase current input line"""
self._move_cursor(self._start_column)
self._clear_to_end()
self._early_wrap.clear()
def _draw_input(self) -> None:
"""Draw current input line"""
if self._line and self._echo:
if self._wrap:
self._output(self._line[:self._pos], 0)
column = self._cursor
self._output(self._line[self._pos:], self._pos)
else:
self._update_input_window(self._pos)
self._output(self._line[self._left_pos:self._pos])
column = self._cursor
self._output(self._line[self._pos:self._right_pos])
self._end_column = self._cursor
self._move_back(column)
def _reposition(self, new_pos: int, new_column: int) -> None:
"""Reposition the cursor to selected position in input"""
if self._echo:
if self._wrap:
self._move_cursor(new_column)
else:
self._update_input(self._pos, self._cursor, new_pos)
self._pos = new_pos
def _update_input(self, pos: int, column: int, new_pos: int) -> None:
"""Update selected portion of current input line"""
if self._echo:
if self._wrap:
if pos in self._early_wrap:
column -= 1
self._move_cursor(column)
prev_wrap = new_pos in self._early_wrap
self._output(self._line[pos:new_pos], pos)
column = self._cursor
self._output(self._line[new_pos:], new_pos)
column += (new_pos in self._early_wrap) - prev_wrap
else:
self._update_input_window(new_pos)
self._move_back(self._start_column)
self._output(self._line[self._left_pos:new_pos])
column = self._cursor
self._output(self._line[new_pos:self._right_pos])
self._clear_to_end()
self._move_back(column)
self._pos = new_pos
def _reset_line(self) -> None:
"""Reset input line to empty"""
self._line = ''
self._left_pos = 0
self._right_pos = 0
self._pos = 0
self._start_column = self._cursor
self._end_column = self._cursor
def _reset_pending(self) -> None:
"""Reset a pending echoed line if any"""
if self._line_pending:
self._erase_input()
self._reset_line()
self._line_pending = False
def _insert_printable(self, data: str) -> None:
"""Insert data into the input line"""
line_len = len(self._line)
data_len = len(data)
if self._max_line_length:
if line_len + data_len > self._max_line_length:
self._ring_bell()
data_len = self._max_line_length - line_len
data = data[:data_len]
if data:
pos = self._pos
new_pos = pos + data_len
self._line = self._line[:pos] + data + self._line[pos:]
self._update_input(pos, self._cursor, new_pos)
def _end_line(self) -> None:
"""End the current input line and send it to the session"""
line = self._line
need_wrap = (self._echo and not self._wrap and
(self._left_pos > 0 or self._right_pos < len(line)))
if self._line_echo or need_wrap:
if need_wrap:
self._output('\b' * (self._cursor - self._start_column) + line)
else:
self._move_to_end()
self._output('\r\n')
self._reset_line()
else:
self._move_to_end()
self._line_pending = True
if self._echo and self._history_size and line:
self._history.append(line)
self._history = self._history[-self._history_size:]
self._history_index = len(self._history)
self._session.data_received(line + '\n', None)
def _eof_or_delete(self) -> None:
"""Erase character to the right, or send EOF if input line is empty"""
if not self._line:
self._session.soft_eof_received()
else:
self._erase_right()
def _erase_left(self) -> None:
"""Erase character to the left"""
if self._pos > 0:
pos = self._pos - 1
column = self._cursor - self._char_width(pos)
self._line = self._line[:pos] + self._line[pos+1:]
self._update_input(pos, column, pos)
else:
self._ring_bell()
def _erase_right(self) -> None:
"""Erase character to the right"""
if self._pos < len(self._line):
pos = self._pos
self._line = self._line[:pos] + self._line[pos+1:]
self._update_input(pos, self._cursor, pos)
else:
self._ring_bell()
def _erase_line(self) -> None:
"""Erase entire input line"""
self._erased = self._line
self._line = ''
self._update_input(0, self._start_column, 0)
def _erase_to_end(self) -> None:
"""Erase to end of input line"""
pos = self._pos
self._erased = self._line[pos:]
self._line = self._line[:pos]
self._update_input(pos, self._cursor, pos)
def _handle_key(self, key: str, handler: _KeyHandler) -> None:
"""Call an external key handler"""
result = handler(self._line, self._pos)
if result is True:
if key.isprintable():
self._insert_printable(key)
else:
self._ring_bell()
elif result is False:
self._ring_bell()
else:
line, new_pos = cast(Tuple[str, int], result)
if new_pos < 0:
self._session.signal_received(line)
else:
self._line = line
self._update_input(0, self._start_column, new_pos)
def _history_prev(self) -> None:
"""Replace input with previous line in history"""
if self._history_index > 0:
self._history_index -= 1
self._line = self._history[self._history_index]
self._update_input(0, self._start_column, len(self._line))
else:
self._ring_bell()
def _history_next(self) -> None:
"""Replace input with next line in history"""
if self._history_index < len(self._history):
self._history_index += 1
if self._history_index < len(self._history):
self._line = self._history[self._history_index]
else:
self._line = ''
self._update_input(0, self._start_column, len(self._line))
else:
self._ring_bell()
def _move_left(self) -> None:
"""Move left in input line"""
if self._pos > 0:
pos = self._pos - 1
column = self._cursor - self._char_width(pos)
self._reposition(pos, column)
else:
self._ring_bell()
def _move_right(self) -> None:
"""Move right in input line"""
if self._pos < len(self._line):
pos = self._pos
column = self._cursor + self._char_width(pos)
self._reposition(pos + 1, column)
else:
self._ring_bell()
def _move_to_start(self) -> None:
"""Move to start of input line"""
self._reposition(0, self._start_column)
def _move_to_end(self) -> None:
"""Move to end of input line"""
self._reposition(len(self._line), self._end_column)
def _redraw(self) -> None:
"""Redraw input line"""
self._erase_input()
self._draw_input()
def _insert_erased(self) -> None:
"""Insert previously erased input"""
self._insert_printable(self._erased)
def _send_break(self) -> None:
"""Send break to session"""
self._session.break_received(0)
_keylist = ((_end_line, ('\n', '\r', '\x1bOM')),
(_eof_or_delete, ('\x04',)),
(_erase_left, ('\x08', '\x7f')),
(_erase_right, ('\x1b[3~',)),
(_erase_line, ('\x15',)),
(_erase_to_end, ('\x0b',)),
(_history_prev, ('\x10', '\x1b[A', '\x1bOA')),
(_history_next, ('\x0e', '\x1b[B', '\x1bOB')),
(_move_left, ('\x02', '\x1b[D', '\x1bOD')),
(_move_right, ('\x06', '\x1b[C', '\x1bOC')),
(_move_to_start, ('\x01', '\x1b[H', '\x1b[1~')),
(_move_to_end, ('\x05', '\x1b[F', '\x1b[4~')),
(_redraw, ('\x12',)),
(_insert_erased, ('\x19',)),
(_send_break, ('\x03', '\x1b[33~')))
def register_key(self, key: str, handler: _KeyHandler) -> None:
"""Register a handler to be called when a key is pressed"""
self._add_key(key, partial(SSHLineEditor._handle_key,
key=key, handler=handler))
self._build_printable()
def unregister_key(self, key: str) -> None:
"""Remove the handler associated with a key"""
self._del_key(key)
self._build_printable()
def set_input(self, line: str, pos: int) -> None:
"""Set input line and cursor position"""
self._reset_pending()
self._line = line
self._update_input(0, self._start_column, pos)
def set_line_mode(self, line_mode: bool) -> None:
"""Enable/disable input line editing"""
self._reset_pending()
if self._line and not line_mode:
data = self._line
self._erase_input()
self._line = ''
self._session.data_received(data, None)
self._line_mode = line_mode
def set_echo(self, echo: bool) -> None:
"""Enable/disable echoing of input in line mode"""
self._reset_pending()
if self._echo and not echo:
self._erase_input()
self._echo = False
elif echo and not self._echo:
self._echo = True
self._draw_input()
def set_width(self, width: int) -> None:
"""Set terminal line width"""
self._reset_pending()
self._width = width or _DEFAULT_WIDTH
if self._wrap:
_, self._cursor = self._determine_column(self._line,
self._start_column, 0)
self._redraw()
def process_input(self, data: str, datatype: DataType) -> None:
"""Process input from channel"""
if self._line_mode:
data_len = len(data)
idx = 0
while idx < data_len:
self._reset_pending()
ch = data[idx]
idx += 1
if ch in self._key_state:
key_state = self._key_state[ch]
if callable(key_state):
try:
cast(_CharHandler, key_state)(self)
finally:
self._key_state = self._keymap
else:
self._key_state = cast(_CharDict, key_state)
elif self._key_state == self._keymap and ch.isprintable():
match = self._printable.match(data, idx - 1)
assert match is not None
match = match[0]
if match:
self._insert_printable(match)
idx += len(match) - 1
else:
self._insert_printable(ch)
else:
self._key_state = self._keymap
self._ring_bell()
self._bell_rung = False
if self._outbuf:
self._chan.write(''.join(self._outbuf))
self._outbuf.clear()
else:
self._session.data_received(data, datatype)
def process_output(self, data: str) -> None:
"""Process output to channel"""
if self._line_pending:
if data.startswith(self._line):
self._start_column = self._cursor
data = data[len(self._line):]
else:
self._erase_input()
self._reset_line()
self._line_pending = False
data = data.replace('\n', '\r\n')
self._erase_input()
self._output(data)
if not self._wrap:
self._cursor %= self._width
self._start_column = self._cursor
self._end_column = self._cursor
self._draw_input()
self._chan.write(''.join(self._outbuf))
self._outbuf.clear()
class SSHLineEditorChannel:
"""Input line editor channel wrapper
When creating server channels with `line_editor` set to `True`,
this class is wrapped around the channel, providing the caller with
the ability to enable and disable input line editing and echoing.
.. note:: Line editing is only available when a pseudo-terminal
is requested on the server channel and the character
encoding on the channel is not set to `None`.
"""
def __init__(self, orig_chan: 'SSHServerChannel[str]',
orig_session: 'SSHServerSession[str]', line_echo: bool,
history_size: int, max_line_length: int):
self._orig_chan = orig_chan
self._orig_session = orig_session
self._line_echo = line_echo
self._history_size = history_size
self._max_line_length = max_line_length
self._editor: Optional[SSHLineEditor] = None
def __getattr__(self, attr: str):
"""Delegate most channel functions to original channel"""
return getattr(self._orig_chan, attr)
def create_editor(self) -> Optional[SSHLineEditor]:
"""Create input line editor if encoding and terminal type are set"""
encoding, _ = self._orig_chan.get_encoding()
term_type = self._orig_chan.get_terminal_type()
width = self._orig_chan.get_terminal_size()[0]
if encoding and term_type:
self._editor = SSHLineEditor(
self._orig_chan, self._orig_session, self._line_echo,
self._history_size, self._max_line_length, term_type, width)
return self._editor
def register_key(self, key: str, handler: _KeyHandler) -> None:
"""Register a handler to be called when a key is pressed
This method registers a handler function which will be called
when a user presses the specified key while inputting a line.
The handler will be called with arguments of the current
input line and cursor position, and updated versions of these
two values should be returned as a tuple.
The handler can also return a tuple of a signal name and
negative cursor position to cause a signal to be delivered
on the channel. In this case, the current input line is left
unchanged but the signal is delivered before processing any
additional input. This can be used to define "hot keys" that
trigger actions unrelated to editing the input.
If the registered key is printable text, returning `True` will
insert that text at the current cursor position, acting as if
no handler was registered for that key. This is useful if you
want to perform a special action in some cases but not others,
such as based on the current cursor position.
Returning `False` will ring the bell and leave the input
unchanged, indicating the requested action could not be
performed.
:param key:
The key sequence to look for
:param handler:
The handler function to call when the key is pressed
:type key: `str`
:type handler: `callable`
"""
assert self._editor is not None
self._editor.register_key(key, handler)
def unregister_key(self, key: str) -> None:
"""Remove the handler associated with a key
This method removes a handler function associated with
the specified key. If the key sequence is printable,
this will cause it to return to being inserted at the
current position when pressed. Otherwise, it will cause
the bell to ring to signal the key is not understood.
:param key:
The key sequence to look for
:type key: `str`
"""
assert self._editor is not None
self._editor.unregister_key(key)
def clear_input(self) -> None:
"""Clear input line
This method clears the current input line.
"""
assert self._editor is not None
self._editor.set_input('', 0)
def set_input(self, line: str, pos: int) -> None:
"""Clear input line
This method sets the current input line and cursor position.
:param line:
The new input line
:param pos:
The new cursor position within the input line
:type line: `str`
:type pos: `int`
"""
assert self._editor is not None
self._editor.set_input(line, pos)
def set_line_mode(self, line_mode: bool) -> None:
"""Enable/disable input line editing
This method enabled or disables input line editing. When set,
only full lines of input are sent to the session, and each
line of input can be edited before it is sent.
:param line_mode:
Whether or not to process input a line at a time
:type line_mode: `bool`
"""
self._orig_chan.logger.info('%s line editor',
'Enabling' if line_mode else 'Disabling')
assert self._editor is not None
self._editor.set_line_mode(line_mode)
def set_echo(self, echo: bool) -> None:
"""Enable/disable echoing of input in line mode
This method enables or disables echoing of input data when
input line editing is enabled.
:param echo:
Whether or not input to echo input as it is entered
:type echo: `bool`
"""
self._orig_chan.logger.info('%s echo',
'Enabling' if echo else 'Disabling')
assert self._editor is not None
self._editor.set_echo(echo)
def write(self, data: str, datatype: DataType = None) -> None:
"""Process data written to the channel"""
if self._editor and datatype is None:
self._editor.process_output(data)
else:
self._orig_chan.write(data, datatype)
class SSHLineEditorSession:
"""Input line editor session wrapper"""
def __init__(self, chan: SSHLineEditorChannel,
orig_session: 'SSHServerSession[str]'):
self._chan = chan
self._orig_session = orig_session
self._editor: Optional[SSHLineEditor] = None
def __getattr__(self, attr: str):
"""Delegate most channel functions to original session"""
return getattr(self._orig_session, attr)
def session_started(self) -> None:
"""Start a session for this newly opened server channel"""
self._editor = self._chan.create_editor()
self._orig_session.session_started()
def terminal_size_changed(self, width: int, height: int,
pixwidth: int, pixheight: int) -> None:
"""The terminal size has changed"""
if self._editor:
self._editor.set_width(width)
self._orig_session.terminal_size_changed(width, height,
pixwidth, pixheight)
def data_received(self, data: str, datatype: DataType) -> None:
"""Process data received from the channel"""
if self._editor:
self._editor.process_input(data, datatype)
else:
self._orig_session.data_received(data, datatype)
def eof_received(self) -> Optional[bool]:
"""Process EOF received from the channel"""
if self._editor:
self._editor.set_line_mode(False)
return self._orig_session.eof_received()

View File

@@ -0,0 +1,318 @@
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""Symmetric key encryption handlers"""
from typing import Dict, List, Optional, Tuple, Type
from .crypto import BasicCipher, GCMCipher, ChachaCipher, get_cipher_params
from .mac import MAC, get_mac_params, get_mac
from .packet import UInt64
_EncParams = Tuple[int, int, int, int, int, bool]
_EncParamsMap = Dict[bytes, Tuple[Type['Encryption'], str]]
_enc_algs: List[bytes] = []
_default_enc_algs: List[bytes] = []
_enc_params: _EncParamsMap = {}
class Encryption:
"""Parent class for SSH packet encryption objects"""
@classmethod
def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'',
mac_key: bytes = b'', etm: bool = False) -> 'Encryption':
"""Construct a new SSH packet encryption object"""
raise NotImplementedError
@classmethod
def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]:
"""Get parameters of the MAC algorithm used with this encryption"""
return get_mac_params(mac_alg)
def encrypt_packet(self, seq: int, header: bytes,
packet: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign an SSH packet"""
raise NotImplementedError
def decrypt_header(self, seq: int, first_block: bytes,
header_len: int) -> Tuple[bytes, bytes]:
"""Decrypt an SSH packet header"""
raise NotImplementedError
def decrypt_packet(self, seq: int, first: bytes, rest: bytes,
header_len: int, mac: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt an SSH packet"""
raise NotImplementedError
class BasicEncryption(Encryption):
"""Shim for basic encryption"""
def __init__(self, cipher: BasicCipher, mac: MAC):
self._cipher = cipher
self._mac = mac
@classmethod
def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'',
mac_key: bytes = b'', etm: bool = False) -> 'BasicEncryption':
"""Construct a new SSH packet encryption object for basic ciphers"""
cipher = BasicCipher(cipher_name, key, iv)
mac = get_mac(mac_alg, mac_key)
if etm:
return ETMEncryption(cipher, mac)
else:
return cls(cipher, mac)
def encrypt_packet(self, seq: int, header: bytes,
packet: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign an SSH packet"""
packet = header + packet
mac = self._mac.sign(seq, packet) if self._mac else b''
return self._cipher.encrypt(packet), mac
def decrypt_header(self, seq: int, first_block: bytes,
header_len: int) -> Tuple[bytes, bytes]:
"""Decrypt an SSH packet header"""
first_block = self._cipher.decrypt(first_block)
return first_block, first_block[:header_len]
def decrypt_packet(self, seq: int, first: bytes, rest: bytes,
header_len: int, mac: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt an SSH packet"""
packet = first + self._cipher.decrypt(rest)
if self._mac.verify(seq, packet, mac):
return packet[header_len:]
else:
return None
class ETMEncryption(BasicEncryption):
"""Shim for encrypt-then-mac encryption"""
def encrypt_packet(self, seq: int, header: bytes,
packet: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign an SSH packet"""
packet = header + self._cipher.encrypt(packet)
return packet, self._mac.sign(seq, packet)
def decrypt_header(self, seq: int, first_block: bytes,
header_len: int) -> Tuple[bytes, bytes]:
"""Decrypt an SSH packet header"""
return first_block, first_block[:header_len]
def decrypt_packet(self, seq: int, first: bytes, rest: bytes,
header_len: int, mac: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt an SSH packet"""
packet = first + rest
if self._mac.verify(seq, packet, mac):
return self._cipher.decrypt(packet[header_len:])
else:
return None
class GCMEncryption(Encryption):
"""Shim for GCM encryption"""
def __init__(self, cipher: GCMCipher):
self._cipher = cipher
@classmethod
def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'',
mac_key: bytes = b'', etm: bool = False) -> 'GCMEncryption':
"""Construct a new SSH packet encryption object for GCM ciphers"""
return cls(GCMCipher(cipher_name, key, iv))
@classmethod
def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]:
"""Get parameters of the MAC algorithm used with this encryption"""
return 0, 16, True
def encrypt_packet(self, seq: int, header: bytes,
packet: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign an SSH packet"""
return self._cipher.encrypt_and_sign(header, packet)
def decrypt_header(self, seq: int, first_block: bytes,
header_len: int) -> Tuple[bytes, bytes]:
"""Decrypt an SSH packet header"""
return first_block, first_block[:header_len]
def decrypt_packet(self, seq: int, first: bytes, rest: bytes,
header_len: int, mac: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt an SSH packet"""
return self._cipher.verify_and_decrypt(first[:header_len],
first[header_len:] + rest, mac)
class ChachaEncryption(Encryption):
"""Shim for chacha20-poly1305 encryption"""
def __init__(self, cipher: ChachaCipher):
self._cipher = cipher
@classmethod
def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'',
mac_key: bytes = b'', etm: bool = False) -> 'ChachaEncryption':
"""Construct a new SSH packet encryption object for Chacha ciphers"""
return cls(ChachaCipher(key))
@classmethod
def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]:
"""Get parameters of the MAC algorithm used with this encryption"""
return 0, 16, True
def encrypt_packet(self, seq: int, header: bytes,
packet: bytes) -> Tuple[bytes, bytes]:
"""Encrypt and sign an SSH packet"""
return self._cipher.encrypt_and_sign(header, packet, UInt64(seq))
def decrypt_header(self, seq: int, first_block: bytes,
header_len: int) -> Tuple[bytes, bytes]:
"""Decrypt an SSH packet header"""
return (first_block,
self._cipher.decrypt_header(first_block[:header_len],
UInt64(seq)))
def decrypt_packet(self, seq: int, first: bytes, rest: bytes,
header_len: int, mac: bytes) -> Optional[bytes]:
"""Verify the signature of and decrypt an SSH packet"""
return self._cipher.verify_and_decrypt(first[:header_len],
first[header_len:] + rest,
UInt64(seq), mac)
def register_encryption_alg(enc_alg: bytes, encryption: Type[Encryption],
cipher_name: str, default: bool) -> None:
"""Register an encryption algorithm"""
try:
get_cipher_params(cipher_name)
except KeyError:
pass
else:
_enc_algs.append(enc_alg)
if default:
_default_enc_algs.append(enc_alg)
_enc_params[enc_alg] = (encryption, cipher_name)
def get_encryption_algs() -> List[bytes]:
"""Return supported encryption algorithms"""
return _enc_algs
def get_default_encryption_algs() -> List[bytes]:
"""Return default encryption algorithms"""
return _default_enc_algs
def get_encryption_params(enc_alg: bytes,
mac_alg: bytes = b'') -> _EncParams:
"""Get parameters of an encryption and MAC algorithm"""
encryption, cipher_name = _enc_params[enc_alg]
enc_keysize, enc_ivsize, enc_blocksize = get_cipher_params(cipher_name)
mac_keysize, mac_hashsize, etm = encryption.get_mac_params(mac_alg)
return (enc_keysize, enc_ivsize, enc_blocksize,
mac_keysize, mac_hashsize, etm)
def get_encryption(enc_alg: bytes, key: bytes, iv: bytes, mac_alg: bytes = b'',
mac_key: bytes = b'', etm: bool = False) -> Encryption:
"""Return an object which can encrypt and decrypt SSH packets"""
encryption, cipher_name = _enc_params[enc_alg]
return encryption.new(cipher_name, key, iv, mac_alg, mac_key, etm)
_enc_alg_list = (
(b'chacha20-poly1305@openssh.com', ChachaEncryption,
'chacha20-poly1305', True),
(b'aes256-gcm@openssh.com', GCMEncryption,
'aes256-gcm', True),
(b'aes128-gcm@openssh.com', GCMEncryption,
'aes128-gcm', True),
(b'aes256-ctr', BasicEncryption,
'aes256-ctr', True),
(b'aes192-ctr', BasicEncryption,
'aes192-ctr', True),
(b'aes128-ctr', BasicEncryption,
'aes128-ctr', True),
(b'aes256-cbc', BasicEncryption,
'aes256-cbc', False),
(b'aes192-cbc', BasicEncryption,
'aes192-cbc', False),
(b'aes128-cbc', BasicEncryption,
'aes128-cbc', False),
(b'3des-cbc', BasicEncryption,
'des3-cbc', False),
(b'blowfish-cbc', BasicEncryption,
'blowfish-cbc', False),
(b'cast128-cbc', BasicEncryption,
'cast128-cbc', False),
(b'seed-cbc@ssh.com', BasicEncryption,
'seed-cbc', False),
(b'arcfour256', BasicEncryption,
'arcfour256', False),
(b'arcfour128', BasicEncryption,
'arcfour128', False),
(b'arcfour', BasicEncryption,
'arcfour', False)
)
for _enc_alg_args in _enc_alg_list:
register_encryption_alg(*_enc_alg_args)

View File

@@ -0,0 +1,216 @@
# Copyright (c) 2013-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""SSH port forwarding handlers"""
import asyncio
import socket
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, cast
from .misc import ChannelOpenError, SockAddr
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .connection import SSHConnection
SSHForwarderCoro = Callable[..., Awaitable]
class SSHForwarder(asyncio.BaseProtocol):
"""SSH port forwarding connection handler"""
def __init__(self, peer: Optional['SSHForwarder'] = None):
self._peer = peer
self._transport: Optional[asyncio.Transport] = None
self._inpbuf = b''
self._eof_received = False
if peer:
peer.set_peer(self)
def set_peer(self, peer: 'SSHForwarder') -> None:
"""Set the peer forwarder to exchange data with"""
self._peer = peer
def write(self, data: bytes) -> None:
"""Write data to the transport"""
assert self._transport is not None
self._transport.write(data)
def write_eof(self) -> None:
"""Write end of file to the transport"""
assert self._transport is not None
try:
self._transport.write_eof()
except OSError: # pragma: no cover
pass
def was_eof_received(self) -> bool:
"""Return whether end of file has been received or not"""
return self._eof_received
def pause_reading(self) -> None:
"""Pause reading from the transport"""
assert self._transport is not None
self._transport.pause_reading()
def resume_reading(self) -> None:
"""Resume reading on the transport"""
assert self._transport is not None
self._transport.resume_reading()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a newly opened connection"""
self._transport = cast(Optional['asyncio.Transport'], transport)
sock = cast(socket.socket, transport.get_extra_info('socket'))
if sock and sock.family in {socket.AF_INET, socket.AF_INET6}:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle an incoming connection close"""
# pylint: disable=unused-argument
self.close()
def session_started(self) -> None:
"""Handle session start"""
def data_received(self, data: bytes,
datatype: Optional[int] = None) -> None:
"""Handle incoming data from the transport"""
# pylint: disable=unused-argument
if self._peer:
self._peer.write(data)
else:
self._inpbuf += data
def eof_received(self) -> bool:
"""Handle an incoming end of file from the transport"""
self._eof_received = True
if self._peer:
self._peer.write_eof()
return not self._peer.was_eof_received()
else:
return True
def pause_writing(self) -> None:
"""Pause writing by asking peer to pause reading"""
if self._peer: # pragma: no branch
self._peer.pause_reading()
def resume_writing(self) -> None:
"""Resume writing by asking peer to resume reading"""
if self._peer: # pragma: no branch
self._peer.resume_reading()
def close(self) -> None:
"""Close this port forwarder"""
if self._transport:
self._transport.close()
self._transport = None
if self._peer:
peer = self._peer
self._peer = None
peer.close()
class SSHLocalForwarder(SSHForwarder):
"""Local forwarding connection handler"""
def __init__(self, conn: 'SSHConnection', coro: SSHForwarderCoro):
super().__init__()
self._conn = conn
self._coro = coro
async def _forward(self, *args: object) -> None:
"""Begin local forwarding"""
def session_factory() -> SSHForwarder:
"""Return an SSH forwarder"""
return SSHForwarder(self)
try:
await self._coro(session_factory, *args)
except ChannelOpenError as exc:
self.connection_lost(exc)
return
assert self._peer is not None
if self._inpbuf:
self._peer.write(self._inpbuf)
self._inpbuf = b''
if self._eof_received:
self._peer.write_eof()
def forward(self, *args: object) -> None:
"""Start a task to begin local forwarding"""
self._conn.create_task(self._forward(*args))
class SSHLocalPortForwarder(SSHLocalForwarder):
"""Local TCP port forwarding connection handler"""
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a newly opened connection"""
super().connection_made(transport)
peername = cast(SockAddr, transport.get_extra_info('peername'))
if peername: # pragma: no branch
orig_host, orig_port = peername[:2]
self.forward(orig_host, orig_port)
class SSHLocalPathForwarder(SSHLocalForwarder):
"""Local UNIX domain socket forwarding connection handler"""
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a newly opened connection"""
super().connection_made(transport)
self.forward()

View File

@@ -0,0 +1,63 @@
# Copyright (c) 2017-2023 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""GSSAPI wrapper"""
import sys
from typing import Optional
try:
# pylint: disable=unused-import
if sys.platform == 'win32': # pragma: no cover
from .gss_win32 import GSSBase, GSSClient, GSSServer, GSSError
else:
from .gss_unix import GSSBase, GSSClient, GSSServer, GSSError
gss_available = True
except ImportError: # pragma: no cover
gss_available = False
class GSSError(ValueError): # type: ignore
"""Stub class for reporting that GSS is not available"""
def __init__(self, maj_code: int, min_code: int,
token: Optional[bytes] = None):
super().__init__('GSS not available')
self.maj_code = maj_code
self.min_code = min_code
self.token = token
class GSSBase: # type: ignore
"""Base class for reporting that GSS is not available"""
class GSSClient(GSSBase): # type: ignore
"""Stub client class for reporting that GSS is not available"""
def __init__(self, _host: str, _delegate_creds: bool):
raise GSSError(0, 0)
class GSSServer(GSSBase): # type: ignore
"""Stub client class for reporting that GSS is not available"""
def __init__(self, _host: str):
raise GSSError(0, 0)

View File

@@ -0,0 +1,180 @@
# Copyright (c) 2017-2021 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""GSSAPI wrapper for UNIX"""
from typing import Optional, Sequence, SupportsBytes, cast
from gssapi import Credentials, Name, NameType, OID
from gssapi import RequirementFlag, SecurityContext
from gssapi.exceptions import GSSError
from .asn1 import OBJECT_IDENTIFIER
def _mech_to_oid(mech: OID) -> bytes:
"""Return a DER-encoded OID corresponding to the requested GSS mechanism"""
mech_bytes = bytes(cast(SupportsBytes, mech))
return bytes((OBJECT_IDENTIFIER, len(mech_bytes))) + mech_bytes
class GSSBase:
"""GSS base class"""
def __init__(self, host: str):
if '@' in host:
self._host = Name(host)
else:
self._host = Name('host@' + host, NameType.hostbased_service)
self._mechs = [_mech_to_oid(mech) for mech in self._creds.mechs]
self._ctx: Optional[SecurityContext] = None
@property
def _creds(self) -> Credentials:
"""Abstract method to construct GSS credentials"""
raise NotImplementedError
def _init_context(self) -> None:
"""Abstract method to construct GSS security context"""
raise NotImplementedError
@property
def mechs(self) -> Sequence[bytes]:
"""Return GSS mechanisms available for this host"""
return self._mechs
@property
def complete(self) -> bool:
"""Return whether or not GSS negotiation is complete"""
return self._ctx.complete if self._ctx else False
@property
def provides_mutual_auth(self) -> bool:
"""Return whether or not this context provides mutual authentication"""
assert self._ctx is not None
return bool(self._ctx.actual_flags &
RequirementFlag.mutual_authentication)
@property
def provides_integrity(self) -> bool:
"""Return whether or not this context provides integrity protection"""
assert self._ctx is not None
return bool(self._ctx.actual_flags & RequirementFlag.integrity)
@property
def user(self) -> str:
"""Return user principal associated with this context"""
assert self._ctx is not None
return str(self._ctx.initiator_name)
@property
def host(self) -> str:
"""Return host principal associated with this context"""
assert self._ctx is not None
return str(self._ctx.target_name)
def reset(self) -> None:
"""Reset GSS security context"""
self._ctx = None
def step(self, token: Optional[bytes] = None) -> Optional[bytes]:
"""Perform next step in GSS security exchange"""
if not self._ctx:
self._init_context()
assert self._ctx is not None
return self._ctx.step(token)
def sign(self, data: bytes) -> bytes:
"""Sign a block of data"""
assert self._ctx is not None
return self._ctx.get_signature(data)
def verify(self, data: bytes, sig: bytes) -> bool:
"""Verify a signature for a block of data"""
assert self._ctx is not None
try:
self._ctx.verify_signature(data, sig)
return True
except GSSError:
return False
class GSSClient(GSSBase):
"""GSS client"""
def __init__(self, host: str, delegate_creds: bool):
super().__init__(host)
flags = RequirementFlag.mutual_authentication | \
RequirementFlag.integrity
if delegate_creds:
flags |= RequirementFlag.delegate_to_peer
self._flags = flags
@property
def _creds(self) -> Credentials:
"""Abstract method to construct GSS credentials"""
return Credentials(usage='initiate')
def _init_context(self) -> None:
"""Construct GSS client security context"""
self._ctx = SecurityContext(name=self._host, creds=self._creds,
flags=self._flags)
class GSSServer(GSSBase):
"""GSS server"""
@property
def _creds(self) -> Credentials:
"""Abstract method to construct GSS credentials"""
return Credentials(name=self._host, usage='accept')
def _init_context(self) -> None:
"""Construct GSS server security context"""
self._ctx = SecurityContext(creds=self._creds)

Some files were not shown because too many files have changed in this diff Show More