diff --git a/pybtc/functions/shamir.py b/pybtc/functions/shamir.py index f283ac0..aae75e6 100644 --- a/pybtc/functions/shamir.py +++ b/pybtc/functions/shamir.py @@ -115,13 +115,17 @@ def split_secret(threshold, total, secret, index_bits=8): q = [b] for i in range(threshold - 1): - if e_i < len(e): - a = e[e_i] - e_i += 1 - else: - e = generate_entropy(hex=False) - a = e[0] - e_i = 1 + is_leading_coefficient = i == threshold - 2 + while True: + if e_i < len(e): + a = e[e_i] + e_i += 1 + else: + e = generate_entropy(hex=False) + a = e[0] + e_i = 1 + if not is_leading_coefficient or a != 0: + break q.append(a) for z in shares_indexes: diff --git a/tests/test_shamir_functions.py b/tests/test_shamir_functions.py index dd9efdc..dd8048b 100644 --- a/tests/test_shamir_functions.py +++ b/tests/test_shamir_functions.py @@ -79,6 +79,19 @@ def test_secret_spliting(): shamir.split_secret(20, 20, secret, index_bits = 2) +def test_split_secret_regenerates_zero_leading_coefficient(monkeypatch): + entropy_chunks = iter([bytes([7, 0, 11])]) + monkeypatch.setattr(shamir, "generate_entropy", lambda hex=False: next(entropy_chunks)) + + secret = b"*" + shares = shamir.split_secret(3, 3, secret) + keys = list(shares) + two_shares = {keys[0]: shares[keys[0]], keys[1]: shares[keys[1]]} + + assert shamir.restore_secret(two_shares) != secret + assert shamir.restore_secret(shares) == secret + + def test__interpolation():