From 60215e1100876694f528ff1097fe8bdb53104f42 Mon Sep 17 00:00:00 2001 From: Adam Gibson Date: Mon, 29 Mar 2021 13:44:26 +0100 Subject: [PATCH] BIP78 input ordering correct for > 2 inputs Prior to this commit, the receiver code was assuming only 2 inputs always, when it decided how to change the input ordering (randomly doing a reversal), but this is not correct according to the BIP78 spec, which requires that the receiver's inputs are *inserted* randomly, without changing the ordering of the existing (sender) inputs. After this commit, the BIP78 protocol is adhered to for any number of inputs. Added test for random_insert and for payjoin with 3 inputs --- jmbase/jmbase/__init__.py | 2 +- jmbase/jmbase/support.py | 9 +++++++++ jmbase/test/test_base_support.py | 29 ++++++++++++++++++++++++++--- jmclient/jmclient/payjoin.py | 13 ++++--------- jmclient/test/test_payjoin.py | 21 +++++++++++++++------ 5 files changed, 55 insertions(+), 19 deletions(-) diff --git a/jmbase/jmbase/__init__.py b/jmbase/jmbase/__init__.py index 8bf80a9..d0b0e61 100644 --- a/jmbase/jmbase/__init__.py +++ b/jmbase/jmbase/__init__.py @@ -8,7 +8,7 @@ from .support import (get_log, chunks, debug_silence, jmprint, EXIT_SUCCESS, hexbin, dictchanger, listchanger, JM_WALLET_NAME_PREFIX, JM_APP_NAME, IndentedHelpFormatterWithNL, wrapped_urlparse, - bdict_sdict_convert) + bdict_sdict_convert, random_insert) from .proof_of_work import get_pow, verify_pow from .twisted_utils import (stop_reactor, is_hs_uri, get_tor_agent, get_nontor_agent, JMHiddenService, diff --git a/jmbase/jmbase/support.py b/jmbase/jmbase/support.py index 90ad7b4..373863d 100644 --- a/jmbase/jmbase/support.py +++ b/jmbase/jmbase/support.py @@ -1,6 +1,7 @@ import logging, sys import binascii +import random from getpass import getpass from os import path, environ from functools import wraps @@ -324,3 +325,11 @@ def bdict_sdict_convert(d, output_binary=False): newv = [a.decode("utf-8") for a in v] newd[k.decode("utf-8")] = newv return newd + +def random_insert(old, new): + """ Insert elements of new at random indices in + the old list, without changing the ordering of the old list. + """ + for n in new: + insertion_index = random.randint(0, len(old)) + old[:] = old[:insertion_index] + [n] + old[insertion_index:] diff --git a/jmbase/test/test_base_support.py b/jmbase/test/test_base_support.py index 7730dde..68734e5 100644 --- a/jmbase/test/test_base_support.py +++ b/jmbase/test/test_base_support.py @@ -1,9 +1,32 @@ #! /usr/bin/env python +import pytest +import copy +from jmbase import random_insert def test_color_coded_logging(): # TODO pass - - - +@pytest.mark.parametrize('list1, list2', [ + [[1,2,3],[4,5,6]], + [["a", "b", "c", "d", "e", "f", "g"], [1,2]], +]) +def test_random_insert(list1, list2): + l1 = len(list1) + l2 = len(list2) + # make a copy of the old version so we can + # check ordering: + old_list1 = copy.deepcopy(list1) + random_insert(list1, list2) + assert len(list1) == l1+l2 + assert all([x in list1 for x in list2]) + assert all([x in list1 for x in old_list1]) + # check the order of every element in the original + # list is preserved: + for x, y in [(old_list1[i], old_list1[i+1]) for i in range( + len(old_list1)-1)]: + # no need to catch ValueError, it should never throw + # so that's a fail anyway. + i_x = list1.index(x) + i_y = list1.index(y) + assert i_y > i_x diff --git a/jmclient/jmclient/payjoin.py b/jmclient/jmclient/payjoin.py index 59c8324..d326c7c 100644 --- a/jmclient/jmclient/payjoin.py +++ b/jmclient/jmclient/payjoin.py @@ -4,8 +4,7 @@ try: except ImportError: pass import json -import random -from jmbase import bintohex, jmprint +from jmbase import bintohex, jmprint, random_insert from .configure import get_log, jm_single import jmbitcoin as btc from .wallet import PSBTWalletMixin, SegwitLegacyWallet, SegwitWallet @@ -767,7 +766,9 @@ class PayjoinConverter(object): # construct unsigned tx for payjoin-psbt: payjoin_tx_inputs = [(x.prevout.hash[::-1], x.prevout.n) for x in payment_psbt.unsigned_tx.vin] - payjoin_tx_inputs.extend(receiver_utxos.keys()) + # See https://github.com/bitcoin/bips/blob/master/bip-0078.mediawiki#Protocol + random_insert(payjoin_tx_inputs, receiver_utxos.keys()) + pay_out = {"value": self.manager.pay_out.nValue, "address": str(btc.CCoinAddress.from_scriptPubKey( self.manager.pay_out.scriptPubKey))} @@ -846,11 +847,6 @@ class PayjoinConverter(object): # intended: outs[self.manager.change_out_index]["value"] -= our_fee_bump - # TODO this only works for 2 input transactions, otherwise - # reversal [::-1] will not be valid as per BIP78 ordering requirement. - # (For outputs, we do nothing since we aren't batching in other payments). - if random.random() < 0.5: - payjoin_tx_inputs = payjoin_tx_inputs[::-1] unsigned_payjoin_tx = btc.mktx(payjoin_tx_inputs, outs, version=payment_psbt.unsigned_tx.nVersion, locktime=payment_psbt.unsigned_tx.nLockTime) @@ -886,7 +882,6 @@ class PayjoinConverter(object): # respect the sender's fixed sequence number, if it was used (we checked # in the initial sanity check) - # TODO consider RBF if we implement it in Joinmarket payments. if self.manager.fixed_sequence_number: for inp in unsigned_payjoin_tx.vin: inp.nSequence = self.manager.fixed_sequence_number diff --git a/jmclient/test/test_payjoin.py b/jmclient/test/test_payjoin.py index 2738bdb..7ae8526 100644 --- a/jmclient/test/test_payjoin.py +++ b/jmclient/test/test_payjoin.py @@ -57,22 +57,25 @@ class PayjoinTestBase(object): implicitly testing all the BIP78 rules (failures are caught by the JMPayjoinManager and PayjoinConverter rules). """ + # the indices in our wallets to populate + wallet_structure = [1, 3, 0, 0, 0] + # the mean amount of each deposit in the above indices, in btc + mean_amt = 2.0 def setUp(self): load_test_config() jm_single().bc_interface.tick_forward_chain_interval = 5 jm_single().bc_interface.simulate_blocks() - def do_test_payment(self, wc1, wc2): - wallet_structures = [[1, 3, 0, 0, 0]] * 2 - mean_amt = 2.0 + def do_test_payment(self, wc1, wc2, amt=1.1): + wallet_structures = [self.wallet_structure] * 2 wallet_cls = (wc1, wc2) self.wallet_services = [] self.wallet_services.append(make_wallets_to_list(make_wallets( 1, wallet_structures=[wallet_structures[0]], - mean_amt=mean_amt, wallet_cls=wallet_cls[0]))[0]) + mean_amt=self.mean_amt, wallet_cls=wallet_cls[0]))[0]) self.wallet_services.append(make_wallets_to_list(make_wallets( 1, wallet_structures=[wallet_structures[1]], - mean_amt=mean_amt, wallet_cls=wallet_cls[1]))[0]) + mean_amt=self.mean_amt, wallet_cls=wallet_cls[1]))[0]) jm_single().bc_interface.tickchain() sync_wallets(self.wallet_services) @@ -81,7 +84,7 @@ class PayjoinTestBase(object): self.rsb = getbals(self.wallet_services[0], 0) self.ssb = getbals(self.wallet_services[1], 0) - self.cj_amount = int(1.1 * 10**8) + self.cj_amount = int(amt * 10**8) def cbStopListening(): return self.port.stopListening() b78rm = JMBIP78ReceiverManager(self.wallet_services[0], 0, @@ -135,6 +138,12 @@ class TrialTestPayjoin2(PayjoinTestBase, unittest.TestCase): def test_bech32_payment(self): return self.do_test_payment(SegwitWallet, SegwitWallet) +class TrialTestPayjoin3(PayjoinTestBase, unittest.TestCase): + def test_multi_input(self): + # wallet structure and amt are chosen so that the sender + # will need 3 utxos rather than 1 (to pay 4.5 from 2,2,2). + self.wallet_structure = [3, 1, 0, 0, 0] + return self.do_test_payment(SegwitWallet, SegwitWallet, amt=4.5) def bip78_receiver_response(response, manager): d = readBody(response)