Browse Source

lnaddr: add get_features method

master
SomberNight 5 years ago
parent
commit
4069b455d0
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 9
      electrum/lnaddr.py
  2. 12
      electrum/lnworker.py
  3. 3
      electrum/tests/test_bolt11.py
  4. 2
      electrum/tests/test_lnpeer.py

9
electrum/lnaddr.py

@ -6,7 +6,7 @@ import time
from hashlib import sha256 from hashlib import sha256
from binascii import hexlify from binascii import hexlify
from decimal import Decimal from decimal import Decimal
from typing import Optional from typing import Optional, TYPE_CHECKING
import random import random
import bitstring import bitstring
@ -17,6 +17,9 @@ from . import constants
from . import ecc from . import ecc
from .bitcoin import COIN from .bitcoin import COIN
if TYPE_CHECKING:
from .lnutil import LnFeatures
# BOLT #11: # BOLT #11:
# #
@ -315,6 +318,10 @@ class LnAddr(object):
return None return None
return int(self.amount * COIN * 1000) return int(self.amount * COIN * 1000)
def get_features(self) -> 'LnFeatures':
from .lnutil import LnFeatures
return LnFeatures(self.get_tag('9') or 0)
def __str__(self): def __str__(self):
return "LnAddr[{}, amount={}{} tags=[{}]]".format( return "LnAddr[{}, amount={}{} tags=[{}]]".format(
hexlify(self.pubkey.serialize()).decode('utf-8') if self.pubkey else None, hexlify(self.pubkey.serialize()).decode('utf-8') if self.pubkey else None,

12
electrum/lnworker.py

@ -961,7 +961,7 @@ class LNWallet(LNWorker):
invoice_pubkey=decoded_invoice.pubkey.serialize(), invoice_pubkey=decoded_invoice.pubkey.serialize(),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
r_tags=decoded_invoice.get_routing_info('r'), r_tags=decoded_invoice.get_routing_info('r'),
invoice_features=decoded_invoice.get_tag('9') or 0, invoice_features=decoded_invoice.get_features(),
trampoline_fee_level=0, trampoline_fee_level=0,
use_two_trampolines=False, use_two_trampolines=False,
payment_hash=decoded_invoice.paymenthash, payment_hash=decoded_invoice.paymenthash,
@ -981,7 +981,7 @@ class LNWallet(LNWorker):
key = payment_hash.hex() key = payment_hash.hex()
payment_secret = lnaddr.payment_secret payment_secret = lnaddr.payment_secret
invoice_pubkey = lnaddr.pubkey.serialize() invoice_pubkey = lnaddr.pubkey.serialize()
invoice_features = LnFeatures(lnaddr.get_tag('9') or 0) invoice_features = lnaddr.get_features()
r_tags = lnaddr.get_routing_info('r') r_tags = lnaddr.get_routing_info('r')
amount_to_pay = lnaddr.get_amount_msat() amount_to_pay = lnaddr.get_amount_msat()
status = self.get_payment_status(payment_hash) status = self.get_payment_status(payment_hash)
@ -1625,11 +1625,11 @@ class LNWallet(LNWorker):
self.received_htlcs.pop(payment_secret) self.received_htlcs.pop(payment_secret)
return True if is_accepted else (False if is_expired else None) return True if is_accepted else (False if is_expired else None)
def get_payment_status(self, payment_hash): def get_payment_status(self, payment_hash: bytes) -> int:
info = self.get_payment_info(payment_hash) info = self.get_payment_info(payment_hash)
return info.status if info else PR_UNPAID return info.status if info else PR_UNPAID
def get_invoice_status(self, invoice): def get_invoice_status(self, invoice: LNInvoice) -> int:
key = invoice.rhash key = invoice.rhash
log = self.logs[key] log = self.logs[key]
if key in self.inflight_payments: if key in self.inflight_payments:
@ -1640,7 +1640,7 @@ class LNWallet(LNWorker):
status = PR_FAILED status = PR_FAILED
return status return status
def set_invoice_status(self, key, status): def set_invoice_status(self, key: str, status: int) -> None:
if status == PR_INFLIGHT: if status == PR_INFLIGHT:
self.inflight_payments.add(key) self.inflight_payments.add(key)
elif key in self.inflight_payments: elif key in self.inflight_payments:
@ -1649,7 +1649,7 @@ class LNWallet(LNWorker):
self.set_payment_status(bfh(key), status) self.set_payment_status(bfh(key), status)
util.trigger_callback('invoice_status', self.wallet, key) util.trigger_callback('invoice_status', self.wallet, key)
def set_payment_status(self, payment_hash: bytes, status): def set_payment_status(self, payment_hash: bytes, status: int) -> None:
info = self.get_payment_info(payment_hash) info = self.get_payment_info(payment_hash)
if info is None: if info is None:
# if we are forwarding # if we are forwarding

3
electrum/tests/test_bolt11.py

@ -6,7 +6,7 @@ import unittest
from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode, u5_to_bitarray, bitarray_to_u5 from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode, u5_to_bitarray, bitarray_to_u5
from electrum.segwit_addr import bech32_encode, bech32_decode from electrum.segwit_addr import bech32_encode, bech32_decode
from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage, LnFeatures
from . import ElectrumTestCase from . import ElectrumTestCase
@ -128,6 +128,7 @@ class TestBolt11(ElectrumTestCase):
def test_features(self): def test_features(self):
lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9qzsze992adudgku8p05pstl6zh7av6rx2f297pv89gu5q93a0hf3g7lynl3xq56t23dpvah6u7y9qey9lccrdml3gaqwc6nxsl5ktzm464sq73t7cl") lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9qzsze992adudgku8p05pstl6zh7av6rx2f297pv89gu5q93a0hf3g7lynl3xq56t23dpvah6u7y9qey9lccrdml3gaqwc6nxsl5ktzm464sq73t7cl")
self.assertEqual(514, lnaddr.get_tag('9')) self.assertEqual(514, lnaddr.get_tag('9'))
self.assertEqual(LnFeatures(514), lnaddr.get_features())
with self.assertRaises(UnknownEvenFeatureBits): with self.assertRaises(UnknownEvenFeatureBits):
lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9q4pqqqqqqqqqqqqqqqqqqszk3ed62snp73037h4py4gry05eltlp0uezm2w9ajnerhmxzhzhsu40g9mgyx5v3ad4aqwkmvyftzk4k9zenz90mhjcy9hcevc7r3lx2sphzfxz7") lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9q4pqqqqqqqqqqqqqqqqqqszk3ed62snp73037h4py4gry05eltlp0uezm2w9ajnerhmxzhzhsu40g9mgyx5v3ad4aqwkmvyftzk4k9zenz90mhjcy9hcevc7r3lx2sphzfxz7")

2
electrum/tests/test_lnpeer.py

@ -724,7 +724,7 @@ class TestPeer(ElectrumTestCase):
await group.spawn(peer.htlc_switch()) await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
invoice_features = LnFeatures(lndecode(pay_req).get_tag('9') or 0) invoice_features = lndecode(pay_req).get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(pay_req)) await group.spawn(pay(pay_req))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):

Loading…
Cancel
Save