Skip to content

Commit

Permalink
pyln: Implement sphinx onion packet generation in python
Browse files Browse the repository at this point in the history
Suggested-by: Rusty Russell <@rustyrussell>
Signed-off-by: Christian Decker <@cdecker>
  • Loading branch information
cdecker committed Jul 31, 2020
1 parent fc97268 commit d5c9e85
Show file tree
Hide file tree
Showing 2 changed files with 518 additions and 6 deletions.
255 changes: 251 additions & 4 deletions contrib/pyln-proto/pyln/proto/onion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from .primitives import varint_decode, varint_encode
from io import BytesIO, SEEK_CUR
from .primitives import varint_decode, varint_encode, Secret
from .wire import PrivateKey, PublicKey, ecdh
from binascii import hexlify, unhexlify
from collections import namedtuple
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from hashlib import sha256
from io import BytesIO, SEEK_CUR
from typing import List, Optional, Union
import coincurve
import os
import struct


Expand Down Expand Up @@ -50,14 +59,19 @@ def __init__(self, amt_to_forward, outgoing_cltv_value,
self.outgoing_cltv_value = outgoing_cltv_value

if isinstance(short_channel_id, str) and 'x' in short_channel_id:
# Convert the short_channel_id from its string representation to its numeric representation
# Convert the short_channel_id from its string representation to
# its numeric representation
block, tx, out = short_channel_id.split('x')
num_scid = int(block) << 40 | int(tx) << 16 | int(out)
self.short_channel_id = num_scid
elif isinstance(short_channel_id, int):
self.short_channel_id = short_channel_id
else:
raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id))
raise ValueError(
"short_channel_id format cannot be recognized: {}".format(
short_channel_id
)
)

@classmethod
def from_bytes(cls, b):
Expand Down Expand Up @@ -242,6 +256,239 @@ class SignatureField(TlvField):
pass


VERSION_SIZE = 1
REALM_SIZE = 1
HMAC_SIZE = 32
PUBKEY_SIZE = 33
ROUTING_INFO_SIZE = 1300
TOTAL_PACKET_SIZE = VERSION_SIZE + PUBKEY_SIZE + HMAC_SIZE + ROUTING_INFO_SIZE


class RoutingOnion(object):
def __init__(
self, version: int,
ephemeralkey: PublicKey,
payloads: bytes,
hmac: bytes
):
assert(len(payloads) == ROUTING_INFO_SIZE)
self.version = version
self.payloads = payloads
self.ephemeralkey = ephemeralkey
self.hmac = hmac

@classmethod
def from_bin(cls, b: bytes):
if len(b) != TOTAL_PACKET_SIZE:
raise ValueError(
"Encoded binary RoutingOnion size mismatch: {} != {}".format(
len(b), TOTAL_PACKET_SIZE
)
)

version = int(b[0])
ephemeralkey = PublicKey(b[1:34])
payloads = b[34:1334]
hmac = b[1334:]

assert(len(payloads) == ROUTING_INFO_SIZE and
len(hmac) == HMAC_SIZE)
return cls(version=version, ephemeralkey=ephemeralkey,
payloads=payloads, hmac=hmac)

@classmethod
def from_hex(cls, s: str):
return cls.from_bin(unhexlify(s))

def to_bin(self) -> bytes:
ephkey = self.ephemeralkey.to_bytes()

return struct.pack("b", self.version) + \
ephkey + \
self.payloads + \
self.hmac

def to_hex(self):
return hexlify(self.to_bin())


KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi'])


def xor_inplace(d: Union[bytearray, memoryview],
a: Union[bytearray, memoryview],
b: Union[bytearray, memoryview]):
"""Compute a xor b and store the result in d
"""
assert(len(a) == len(b) and len(d) == len(b))
for i in range(len(a)):
d[i] = a[i] ^ b[i]


def xor(a: Union[bytearray, memoryview],
b: Union[bytearray, memoryview]) -> bytearray:
assert(len(a) == len(b))
d = bytearray(len(a))
xor_inplace(d, a, b)
return d


def generate_key(secret: bytes, prefix: bytes):
h = hmac.HMAC(prefix, hashes.SHA256(), backend=default_backend())
h.update(secret)
return h.finalize()


def generate_keyset(secret: Secret) -> KeySet:
types = [bytes(f, 'ascii') for f in KeySet._fields]
keys = [generate_key(secret.data, t) for t in types]
return KeySet(*keys)


class SphinxHopParam(object):
def __init__(self, secret: Secret, ephemeralkey: PublicKey):
self.secret = secret
self.ephemeralkey = ephemeralkey
self.blind = blind(self.ephemeralkey, self.secret)
self.keys = generate_keyset(self.secret)


class SphinxHop(object):
def __init__(self, pubkey: PublicKey, payload: bytes):
self.pubkey = pubkey
self.payload = payload
self.hmac: Optional[bytes] = None

def __len__(self):
return len(self.payload) + HMAC_SIZE


def blind(pubkey, sharedsecret) -> Secret:
m = sha256()
m.update(pubkey.to_bytes())
m.update(sharedsecret.to_bytes())
return Secret(m.digest())


def blind_group_element(pubkey, blind: Secret) -> PublicKey:
pubkey = coincurve.PublicKey(data=pubkey.to_bytes())
blinded = pubkey.multiply(blind.to_bytes(), update=False)
return PublicKey(blinded.format(compressed=True))


def chacha20_stream(key: bytes, dest: Union[bytearray, memoryview]):
algorithm = algorithms.ChaCha20(key, b'\x00'*16)
cipher = Cipher(algorithm, None, backend=default_backend())
encryptor = cipher.encryptor()
encryptor.update_into(dest, dest)


class SphinxPath(object):
def __init__(self, hops: List[SphinxHop], assocdata: bytes = None,
session_key: Optional[Secret] = None):
self.hops = hops
self.assocdata: Optional[bytes] = assocdata
if session_key is not None:
self.session_key = session_key
else:
self.session_key = Secret(os.urandom(32))

def get_filler(self) -> memoryview:
filler_size = sum(len(h) for h in self.hops[1:])
filler = memoryview(bytearray(filler_size))
params = self.get_hop_params()

for i in range(len(self.hops[:-1])):
h = self.hops[i]
p = params[i]
filler_offset = sum(len(sph) for sph in self.hops[:i])

filler_start = ROUTING_INFO_SIZE - filler_offset
filler_end = ROUTING_INFO_SIZE + len(h)
filler_len = filler_end-filler_start
stream = bytearray(filler_end)
chacha20_stream(p.keys.rho, stream)
xor_inplace(filler[:filler_len], filler[:filler_len],
stream[filler_start:filler_end])

return filler

def compile(self) -> RoutingOnion:
buf = bytearray(ROUTING_INFO_SIZE)

# Prefill the buffer with the pseudorandom stream to avoid telling the
# last hop the real payload size through zero ranges.
padkey = generate_key(self.session_key.data, b'pad')
params = self.get_hop_params()
chacha20_stream(padkey, buf)

filler = self.get_filler()
nexthmac = bytes(32)
for i, h, p in zip(
range(len(self.hops)),
reversed(self.hops),
reversed(params)):
h.hmac = nexthmac
shift_size = len(h)
assert(shift_size == len(h.payload) + HMAC_SIZE)
buf[shift_size:] = buf[:ROUTING_INFO_SIZE-shift_size]
buf[:shift_size] = h.payload + h.hmac

# Encrypt
chacha20_stream(p.keys.rho, buf)

if i == 0:
# Place the filler at the correct position
buf[ROUTING_INFO_SIZE-len(filler):] = filler

# Finally compute the hmac that the next hop will use to verify
# the onion's integrity.
hh = hmac.HMAC(p.keys.mu, hashes.SHA256(),
backend=default_backend())
hh.update(buf)
if self.assocdata is not None:
hh.update(self.assocdata)
nexthmac = hh.finalize()

return RoutingOnion(
version=0,
ephemeralkey=params[0].ephemeralkey,
hmac=nexthmac,
payloads=buf,
)

def get_hop_params(self) -> List[SphinxHopParam]:
assert(self.session_key is not None)
secret = ecdh(PrivateKey(self.session_key.data),
self.hops[0].pubkey)
sph = SphinxHopParam(
ephemeralkey=PrivateKey(self.session_key.data).public_key(),
secret=ecdh(PrivateKey(self.session_key.data),
self.hops[0].pubkey)
)

params = [sph]
for i, h in enumerate(self.hops[1:]):
prev = params[-1]
ek = blind_group_element(prev.ephemeralkey,
prev.blind)

# Start by blinding the current hop's pubkey with the session_key
temp = blind_group_element(h.pubkey, self.session_key)

# Then apply blind for all previous hops
for p in params:
temp = blind_group_element(temp, p.blind)

# Finally hash the compressed resulting pubkey to get the secret
secret = Secret(sha256(temp.to_bytes()).digest())

sph = SphinxHopParam(secret=secret, ephemeralkey=ek)
params.append(sph)

return params


# A mapping of known TLV types
tlv_types = {
2: (Tu64Field, 'amt_to_forward'),
Expand Down
Loading

0 comments on commit d5c9e85

Please sign in to comment.