|
|
|
|
@ -677,6 +677,8 @@ class LNWallet(LNWorker):
|
|
|
|
|
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys |
|
|
|
|
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] |
|
|
|
|
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout |
|
|
|
|
self.payment_bundles = [] # lists of hashes. todo:persist |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def has_deterministic_node_id(self) -> bool: |
|
|
|
|
return bool(self.db.get('lightning_xprv')) |
|
|
|
|
@ -1862,6 +1864,14 @@ class LNWallet(LNWorker):
|
|
|
|
|
self.wallet.save_db() |
|
|
|
|
return payment_hash |
|
|
|
|
|
|
|
|
|
def bundle_payments(self, hash_list): |
|
|
|
|
self.payment_bundles.append(hash_list) |
|
|
|
|
|
|
|
|
|
def get_payment_bundle(self, payment_hash): |
|
|
|
|
for hash_list in self.payment_bundles: |
|
|
|
|
if payment_hash in hash_list: |
|
|
|
|
return hash_list |
|
|
|
|
|
|
|
|
|
def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True): |
|
|
|
|
assert sha256(preimage) == payment_hash |
|
|
|
|
self.preimages[payment_hash.hex()] = preimage.hex() |
|
|
|
|
@ -1901,45 +1911,87 @@ class LNWallet(LNWorker):
|
|
|
|
|
""" return MPP status: True (accepted), False (expired) or None (waiting) |
|
|
|
|
""" |
|
|
|
|
payment_hash = htlc.payment_hash |
|
|
|
|
preimage = self.get_preimage(payment_hash) |
|
|
|
|
callback = self.hold_invoice_callbacks.get(payment_hash) |
|
|
|
|
if not preimage and callback: |
|
|
|
|
cb, timeout = callback |
|
|
|
|
if int(time.time()) < timeout: |
|
|
|
|
cb(payment_hash) |
|
|
|
|
return None |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
amt_to_forward = htlc.amount_msat # check this |
|
|
|
|
if amt_to_forward >= expected_msat: |
|
|
|
|
# not multi-part |
|
|
|
|
return True |
|
|
|
|
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat) |
|
|
|
|
is_expired, is_accepted = self.get_mpp_status(payment_secret) |
|
|
|
|
if not is_accepted and not is_expired: |
|
|
|
|
bundle = self.get_payment_bundle(payment_hash) |
|
|
|
|
payment_hashes = bundle or [payment_hash] |
|
|
|
|
payment_secrets = [self.get_payment_secret(h) for h in bundle] if bundle else [payment_secret] |
|
|
|
|
first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets]) |
|
|
|
|
if self.get_payment_status(payment_hash) == PR_PAID: |
|
|
|
|
is_accepted = True |
|
|
|
|
elif self.stopping_soon: |
|
|
|
|
is_expired = True # try to time out pending HTLCs before shutting down |
|
|
|
|
elif time.time() - first_timestamp > self.MPP_EXPIRY: |
|
|
|
|
is_expired = True |
|
|
|
|
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]): |
|
|
|
|
preimage = self.get_preimage(payment_hash) |
|
|
|
|
hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash) |
|
|
|
|
if not preimage and hold_invoice_callback: |
|
|
|
|
# for hold invoices, trigger callback |
|
|
|
|
cb, timeout = hold_invoice_callback |
|
|
|
|
if int(time.time()) < timeout: |
|
|
|
|
cb(payment_hash) |
|
|
|
|
else: |
|
|
|
|
is_expired = True |
|
|
|
|
elif bundle is not None: |
|
|
|
|
is_accepted = all([bool(self.get_preimage(x)) for x in bundle]) |
|
|
|
|
else: |
|
|
|
|
# trampoline forwarding needs this to return True |
|
|
|
|
is_accepted = True |
|
|
|
|
|
|
|
|
|
# set status for the bundle |
|
|
|
|
if is_expired or is_accepted: |
|
|
|
|
for x in payment_secrets: |
|
|
|
|
if x in self.received_mpp_htlcs: |
|
|
|
|
self.set_mpp_status(x, is_expired, is_accepted) |
|
|
|
|
|
|
|
|
|
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set())) |
|
|
|
|
if self.get_payment_status(payment_hash) == PR_PAID: |
|
|
|
|
# payment_status is persisted |
|
|
|
|
is_accepted = True |
|
|
|
|
is_expired = False |
|
|
|
|
self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc) |
|
|
|
|
return True if is_accepted else (False if is_expired else None) |
|
|
|
|
|
|
|
|
|
def update_mpp_with_received_htlc(self, payment_secret, short_channel_id, htlc, expected_msat): |
|
|
|
|
# add new htlc to set |
|
|
|
|
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, expected_msat, set())) |
|
|
|
|
assert expected_msat == _expected_msat |
|
|
|
|
key = (short_channel_id, htlc) |
|
|
|
|
if key not in htlc_set: |
|
|
|
|
htlc_set.add(key) |
|
|
|
|
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set |
|
|
|
|
|
|
|
|
|
def get_mpp_status(self, payment_secret): |
|
|
|
|
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] |
|
|
|
|
return is_expired, is_accepted |
|
|
|
|
|
|
|
|
|
def set_mpp_status(self, payment_secret, is_expired, is_accepted): |
|
|
|
|
_is_expired, _is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] |
|
|
|
|
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set |
|
|
|
|
|
|
|
|
|
def is_mpp_amount_reached(self, payment_secret): |
|
|
|
|
mpp = self.received_mpp_htlcs.get(payment_secret) |
|
|
|
|
if not mpp: |
|
|
|
|
return False |
|
|
|
|
is_expired, is_accepted, _expected_msat, htlc_set = mpp |
|
|
|
|
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) |
|
|
|
|
return total >= _expected_msat |
|
|
|
|
|
|
|
|
|
def get_first_timestamp_of_mpp(self, payment_secret): |
|
|
|
|
mpp = self.received_mpp_htlcs.get(payment_secret) |
|
|
|
|
if not mpp: |
|
|
|
|
return int(time.time()) |
|
|
|
|
is_expired, is_accepted, _expected_msat, htlc_set = mpp |
|
|
|
|
return min([_htlc.timestamp for scid, _htlc in htlc_set]) |
|
|
|
|
|
|
|
|
|
def maybe_cleanup_mpp_status(self, payment_secret, short_channel_id, htlc): |
|
|
|
|
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] |
|
|
|
|
if not is_accepted and not is_expired: |
|
|
|
|
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) |
|
|
|
|
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) |
|
|
|
|
if self.stopping_soon: |
|
|
|
|
is_expired = True # try to time out pending HTLCs before shutting down |
|
|
|
|
elif time.time() - first_timestamp > self.MPP_EXPIRY: |
|
|
|
|
is_expired = True |
|
|
|
|
elif total == expected_msat: |
|
|
|
|
is_accepted = True |
|
|
|
|
if is_accepted or is_expired: |
|
|
|
|
htlc_set.remove(key) |
|
|
|
|
return |
|
|
|
|
key = (short_channel_id, htlc) |
|
|
|
|
htlc_set.remove(key) |
|
|
|
|
if len(htlc_set) > 0: |
|
|
|
|
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set |
|
|
|
|
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set |
|
|
|
|
elif payment_secret in self.received_mpp_htlcs: |
|
|
|
|
self.received_mpp_htlcs.pop(payment_secret) |
|
|
|
|
return True if is_accepted else (False if is_expired else None) |
|
|
|
|
|
|
|
|
|
def get_payment_status(self, payment_hash: bytes) -> int: |
|
|
|
|
info = self.get_payment_info(payment_hash) |
|
|
|
|
|