342 lines
12 KiB
Python
342 lines
12 KiB
Python
# Copyright (c) 2013-2021 by Ron Frederick <ronf@timeheart.net> and others.
|
|
#
|
|
# This program and the accompanying materials are made available under
|
|
# the terms of the Eclipse Public License v2.0 which accompanies this
|
|
# distribution and is available at:
|
|
#
|
|
# http://www.eclipse.org/legal/epl-2.0/
|
|
#
|
|
# This program may also be made available under the following secondary
|
|
# licenses when the conditions for such availability set forth in the
|
|
# Eclipse Public License v2.0 are satisfied:
|
|
#
|
|
# GNU General Public License, Version 2.0, or any later versions of
|
|
# that license
|
|
#
|
|
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
|
|
#
|
|
# Contributors:
|
|
# Ron Frederick - initial implementation, API, and documentation
|
|
|
|
"""ECDSA public key encryption handler"""
|
|
|
|
from typing import Dict, Optional, Tuple, Union, cast
|
|
|
|
from .asn1 import ASN1DecodeError, BitString, ObjectIdentifier, TaggedDERObject
|
|
from .asn1 import der_encode, der_decode
|
|
from .crypto import CryptoKey, ECDSAPrivateKey, ECDSAPublicKey
|
|
from .crypto import lookup_ec_curve_by_params
|
|
from .packet import MPInt, String, SSHPacket
|
|
from .public_key import SSHKey, SSHOpenSSHCertificateV01
|
|
from .public_key import KeyImportError, KeyExportError
|
|
from .public_key import register_public_key_alg, register_certificate_alg
|
|
from .public_key import register_x509_certificate_alg
|
|
|
|
|
|
_PrivateKeyArgs = Tuple[bytes, Union[bytes, int], bytes]
|
|
_PublicKeyArgs = Tuple[bytes, bytes]
|
|
|
|
|
|
# OID for EC prime fields
|
|
PRIME_FIELD = ObjectIdentifier('1.2.840.10045.1.1')
|
|
|
|
_hash_algs = {b'1.3.132.0.10': 'sha256',
|
|
b'nistp256': 'sha256',
|
|
b'nistp384': 'sha384',
|
|
b'nistp521': 'sha512'}
|
|
|
|
_alg_oids: Dict[bytes, ObjectIdentifier] = {}
|
|
_alg_oid_map: Dict[ObjectIdentifier, bytes] = {}
|
|
|
|
|
|
class _ECKey(SSHKey):
|
|
"""Handler for elliptic curve public key encryption"""
|
|
|
|
_key: Union[ECDSAPrivateKey, ECDSAPublicKey]
|
|
|
|
default_x509_hash = 'sha256'
|
|
pem_name = b'EC'
|
|
pkcs8_oid = ObjectIdentifier('1.2.840.10045.2.1')
|
|
|
|
def __init__(self, key: CryptoKey):
|
|
super().__init__(key)
|
|
|
|
self.algorithm = b'ecdsa-sha2-' + self._key.curve_id
|
|
self.sig_algorithms = (self.algorithm,)
|
|
self.x509_algorithms = (b'x509v3-' + self.algorithm,)
|
|
self.all_sig_algorithms = set(self.sig_algorithms)
|
|
|
|
self._alg_oid = _alg_oids[self._key.curve_id]
|
|
self._hash_alg = _hash_algs[self._key.curve_id]
|
|
|
|
def __eq__(self, other: object) -> bool:
|
|
# This isn't protected access - both objects are _ECKey instances
|
|
# pylint: disable=protected-access
|
|
|
|
return (isinstance(other, type(self)) and
|
|
self._key.curve_id == other._key.curve_id and
|
|
self._key.x == other._key.x and
|
|
self._key.y == other._key.y and
|
|
self._key.d == other._key.d)
|
|
|
|
def __hash__(self) -> int:
|
|
return hash((self._key.curve_id, self._key.x,
|
|
self._key.y, self._key.d))
|
|
|
|
@classmethod
|
|
def _lookup_curve(cls, alg_params: object) -> bytes:
|
|
"""Look up an EC curve matching the specified parameters"""
|
|
|
|
if isinstance(alg_params, ObjectIdentifier):
|
|
try:
|
|
curve_id = _alg_oid_map[alg_params]
|
|
except KeyError:
|
|
raise KeyImportError('Unknown elliptic curve OID %s' %
|
|
alg_params) from None
|
|
elif (isinstance(alg_params, tuple) and len(alg_params) >= 5 and
|
|
alg_params[0] == 1 and isinstance(alg_params[1], tuple) and
|
|
len(alg_params[1]) == 2 and alg_params[1][0] == PRIME_FIELD and
|
|
isinstance(alg_params[2], tuple) and len(alg_params[2]) >= 2 and
|
|
isinstance(alg_params[3], bytes) and
|
|
isinstance(alg_params[2][0], bytes) and
|
|
isinstance(alg_params[2][1], bytes) and
|
|
isinstance(alg_params[4], int)):
|
|
p = alg_params[1][1]
|
|
a = int.from_bytes(alg_params[2][0], 'big')
|
|
b = int.from_bytes(alg_params[2][1], 'big')
|
|
point = alg_params[3]
|
|
n = alg_params[4]
|
|
|
|
try:
|
|
curve_id = lookup_ec_curve_by_params(p, a, b, point, n)
|
|
except ValueError as exc:
|
|
raise KeyImportError(str(exc)) from None
|
|
else:
|
|
raise KeyImportError('Invalid EC curve parameters')
|
|
|
|
return curve_id
|
|
|
|
@classmethod
|
|
def generate(cls, algorithm: bytes) -> '_ECKey': # type: ignore
|
|
"""Generate a new EC private key"""
|
|
|
|
# pylint: disable=arguments-differ
|
|
|
|
# Strip 'ecdsa-sha2-' prefix of algorithm to get curve_id
|
|
return cls(ECDSAPrivateKey.generate(algorithm[11:]))
|
|
|
|
@classmethod
|
|
def make_private(cls, key_params: object) -> SSHKey:
|
|
"""Construct an EC private key"""
|
|
|
|
curve_id, private_value, public_value = \
|
|
cast(_PrivateKeyArgs, key_params)
|
|
|
|
if isinstance(private_value, bytes):
|
|
private_value = int.from_bytes(private_value, 'big')
|
|
|
|
return cls(ECDSAPrivateKey.construct(curve_id, public_value,
|
|
private_value))
|
|
|
|
@classmethod
|
|
def make_public(cls, key_params: object) -> SSHKey:
|
|
"""Construct an EC public key"""
|
|
|
|
curve_id, public_value = cast(_PublicKeyArgs, key_params)
|
|
|
|
return cls(ECDSAPublicKey.construct(curve_id, public_value))
|
|
|
|
@classmethod
|
|
def decode_pkcs1_private(cls, key_data: object) -> \
|
|
Optional[_PrivateKeyArgs]:
|
|
"""Decode a PKCS#1 format EC private key"""
|
|
|
|
if (isinstance(key_data, tuple) and len(key_data) > 2 and
|
|
key_data[0] == 1 and isinstance(key_data[1], bytes) and
|
|
isinstance(key_data[2], TaggedDERObject) and
|
|
key_data[2].tag == 0):
|
|
alg_params = key_data[2].value
|
|
private_key = key_data[1]
|
|
|
|
if (len(key_data) > 3 and
|
|
isinstance(key_data[3], TaggedDERObject) and
|
|
key_data[3].tag == 1 and
|
|
isinstance(key_data[3].value, BitString) and
|
|
key_data[3].value.unused == 0):
|
|
public_key: bytes = key_data[3].value.value
|
|
else:
|
|
public_key = b''
|
|
|
|
return cls._lookup_curve(alg_params), private_key, public_key
|
|
else:
|
|
return None
|
|
|
|
@classmethod
|
|
def decode_pkcs1_public(cls, key_data: object) -> \
|
|
Optional[_PublicKeyArgs]:
|
|
"""Decode a PKCS#1 format EC public key"""
|
|
|
|
# pylint: disable=unused-argument
|
|
|
|
raise KeyImportError('PKCS#1 not supported for EC public keys')
|
|
|
|
@classmethod
|
|
def decode_pkcs8_private(cls, alg_params: object,
|
|
data: bytes) -> Optional[_PrivateKeyArgs]:
|
|
"""Decode a PKCS#8 format EC private key"""
|
|
|
|
try:
|
|
key_data = der_decode(data)
|
|
except ASN1DecodeError:
|
|
key_data = None
|
|
|
|
if (isinstance(key_data, tuple) and len(key_data) > 1 and
|
|
key_data[0] == 1 and isinstance(key_data[1], bytes)):
|
|
private_key = key_data[1]
|
|
|
|
if (len(key_data) > 2 and
|
|
isinstance(key_data[2], TaggedDERObject) and
|
|
key_data[2].tag == 1 and
|
|
isinstance(key_data[2].value, BitString) and
|
|
key_data[2].value.unused == 0):
|
|
public_key = key_data[2].value.value
|
|
else:
|
|
public_key = b''
|
|
|
|
return cls._lookup_curve(alg_params), private_key, public_key
|
|
else:
|
|
return None
|
|
|
|
@classmethod
|
|
def decode_pkcs8_public(cls, alg_params: object,
|
|
data: bytes) -> Optional[_PublicKeyArgs]:
|
|
"""Decode a PKCS#8 format EC public key"""
|
|
|
|
if isinstance(alg_params, ObjectIdentifier):
|
|
return cls._lookup_curve(alg_params), data
|
|
else:
|
|
return None
|
|
|
|
@classmethod
|
|
def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs:
|
|
"""Decode an SSH format EC private key"""
|
|
|
|
curve_id = packet.get_string()
|
|
public_key = packet.get_string()
|
|
private_key = packet.get_mpint()
|
|
|
|
return curve_id, private_key, public_key
|
|
|
|
@classmethod
|
|
def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs:
|
|
"""Decode an SSH format EC public key"""
|
|
|
|
curve_id = packet.get_string()
|
|
public_key = packet.get_string()
|
|
|
|
return curve_id, public_key
|
|
|
|
def encode_public_tagged(self) -> object:
|
|
"""Encode an EC public key blob as a tagged bitstring"""
|
|
|
|
return TaggedDERObject(1, BitString(self._key.public_value))
|
|
|
|
def encode_pkcs1_private(self) -> object:
|
|
"""Encode a PKCS#1 format EC private key"""
|
|
|
|
if not self._key.private_value:
|
|
raise KeyExportError('Key is not private')
|
|
|
|
return (1, self._key.private_value,
|
|
TaggedDERObject(0, self._alg_oid),
|
|
self.encode_public_tagged())
|
|
|
|
def encode_pkcs1_public(self) -> object:
|
|
"""Encode a PKCS#1 format EC public key"""
|
|
|
|
raise KeyExportError('PKCS#1 is not supported for EC public keys')
|
|
|
|
def encode_pkcs8_private(self) -> Tuple[object, object]:
|
|
"""Encode a PKCS#8 format EC private key"""
|
|
|
|
if not self._key.private_value:
|
|
raise KeyExportError('Key is not private')
|
|
|
|
return self._alg_oid, der_encode((1, self._key.private_value,
|
|
self.encode_public_tagged()))
|
|
|
|
def encode_pkcs8_public(self) -> Tuple[object, object]:
|
|
"""Encode a PKCS#8 format EC public key"""
|
|
|
|
return self._alg_oid, self._key.public_value
|
|
|
|
def encode_ssh_private(self) -> bytes:
|
|
"""Encode an SSH format EC private key"""
|
|
|
|
if not self._key.d:
|
|
raise KeyExportError('Key is not private')
|
|
|
|
return b''.join((String(self._key.curve_id),
|
|
String(self._key.public_value),
|
|
MPInt(self._key.d)))
|
|
|
|
def encode_ssh_public(self) -> bytes:
|
|
"""Encode an SSH format EC public key"""
|
|
|
|
return b''.join((String(self._key.curve_id),
|
|
String(self._key.public_value)))
|
|
|
|
def encode_agent_cert_private(self) -> bytes:
|
|
"""Encode ECDSA certificate private key data for agent"""
|
|
|
|
if not self._key.d:
|
|
raise KeyExportError('Key is not private')
|
|
|
|
return MPInt(self._key.d)
|
|
|
|
def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes:
|
|
"""Compute an SSH-encoded signature of the specified data"""
|
|
|
|
# pylint: disable=unused-argument
|
|
|
|
if not self._key.private_value:
|
|
raise ValueError('Private key needed for signing')
|
|
|
|
sig = der_decode(self._key.sign(data, self._hash_alg))
|
|
r, s = cast(Tuple[int, int], sig)
|
|
return String(MPInt(r) + MPInt(s))
|
|
|
|
def verify_ssh(self, data: bytes, sig_algorithm: bytes,
|
|
packet: SSHPacket) -> bool:
|
|
"""Verify an SSH-encoded signature of the specified data"""
|
|
|
|
# pylint: disable=unused-argument
|
|
|
|
sig = packet.get_string()
|
|
packet.check_end()
|
|
|
|
packet = SSHPacket(sig)
|
|
r = packet.get_mpint()
|
|
s = packet.get_mpint()
|
|
packet.check_end()
|
|
|
|
return self._key.verify(data, der_encode((r, s)), self._hash_alg)
|
|
|
|
|
|
for _curve_id, _oid_str in ((b'nistp521', '1.3.132.0.35'),
|
|
(b'nistp384', '1.3.132.0.34'),
|
|
(b'nistp256', '1.2.840.10045.3.1.7'),
|
|
(b'1.3.132.0.10', '1.3.132.0.10')):
|
|
_algorithm = b'ecdsa-sha2-' + _curve_id
|
|
_cert_algorithm = _algorithm + b'-cert-v01@openssh.com'
|
|
_x509_algorithm = b'x509v3-' + _algorithm
|
|
|
|
_oid = ObjectIdentifier(_oid_str)
|
|
_alg_oids[_curve_id] = _oid
|
|
_alg_oid_map[_oid] = _curve_id
|
|
|
|
register_public_key_alg(_algorithm, _ECKey, True, (_algorithm,))
|
|
register_certificate_alg(1, _algorithm, _cert_algorithm,
|
|
_ECKey, SSHOpenSSHCertificateV01, True)
|
|
register_x509_certificate_alg(_x509_algorithm, True)
|