Source code for stegx.shamir
from __future__ import annotations
import os
from typing import List, Sequence, Tuple
from .exceptions import InsufficientSharesError
def _build_gf_tables() -> Tuple[List[int], List[int]]:
    """Build log/antilog lookup tables for GF(2^8) reduced by 0x11D.

    The field polynomial is x^8 + x^4 + x^3 + x^2 + 1 (0x11D). Starting from 1
    and repeatedly multiplying by x (== 2) visits every nonzero element once
    in 255 steps, because 2 generates the multiplicative group modulo this
    polynomial.

    Returns ``(log_table, exp_table)``:

    * ``exp_table[i]`` is 2**i in the field. The table has 512 entries and
      repeats with period 255. Callers can therefore index it with the sum of
      two logarithms (at most 254 + 254 = 508) without reducing mod 255.
    * ``log_table[exp_table[i]] == i`` for ``0 <= i < 255``.
      ``log_table[0]`` stays 0 and is meaningless: zero has no logarithm, so
      callers must special-case it.

    Building the tables inside a function keeps loop temporaries out of the
    module namespace.
    """
    log_table = [0] * 256
    exp_table = [0] * 512
    value = 1
    for power in range(255):
        exp_table[power] = value
        log_table[value] = power
        value <<= 1  # multiply by the generator x (== 2)
        if value & 0x100:  # degree-8 overflow: reduce modulo the field polynomial
            value ^= 0x11D
    # Second period (plus two entries): lets gf_mul skip the "% 255" reduction.
    for power in range(255, 512):
        exp_table[power] = exp_table[power - 255]
    return log_table, exp_table


_LOG, _EXP = _build_gf_tables()
[docs]
def gf_mul(a: int, b: int) -> int:
    """Multiply two GF(2^8) elements using the log/antilog tables.

    Zero has no discrete logarithm, so any product with a 0 factor returns 0
    before the tables are consulted.
    """
    if 0 in (a, b):
        return 0
    log_sum = _LOG[a] + _LOG[b]
    return _EXP[log_sum]
[docs]
def gf_div(a: int, b: int) -> int:
    """Divide *a* by *b* in GF(2^8).

    Returns 0 when *a* is 0. Otherwise returns 2**(log a - log b) taken
    mod 255.

    Raises:
        ZeroDivisionError: if *b* is 0. The divisor is checked before the
            dividend, so 0 / 0 also raises.
    """
    if not b:
        raise ZeroDivisionError("GF(256) division by zero")
    # Python's % is non-negative for a positive modulus, so the log
    # difference lands in 0..254 without adding a bias first.
    return _EXP[(_LOG[a] - _LOG[b]) % 255] if a else 0
def _eval_poly(coeffs: Sequence[int], x: int) -> int:
    """Evaluate a GF(2^8) polynomial at *x* with Horner's rule.

    ``coeffs[0]`` is the constant term and ``coeffs[-1]`` is the
    highest-degree coefficient. Field addition is XOR. An empty coefficient
    sequence evaluates to 0.
    """
    result = 0
    for coefficient in reversed(coeffs):
        # result <- result * x + coefficient, in field arithmetic
        result = coefficient ^ gf_mul(result, x)
    return result
[docs]
def split_secret(secret: bytes, k: int, n: int) -> List[bytes]:
    """Split *secret* into *n* Shamir shares, any *k* of which recover it.

    Each secret byte is the constant term of its own random polynomial of
    degree ``k - 1`` over GF(2^8). Share number ``x`` (1-based) stores each
    polynomial evaluated at that ``x``.

    Every share has the layout ``[x][k][y_0][y_1]...``:

    * one byte for the x-coordinate;
    * one byte for the threshold;
    * one y-byte per secret byte.

    This is the layout ``combine_shares`` parses.

    Args:
        secret: Non-empty bytes to split.
        k: Number of shares required for reconstruction; 2 <= k <= n.
        n: Number of shares to produce; n <= 255. Only 255 nonzero
            x-coordinates exist in GF(2^8), and x = 0 would reveal the secret.

    Returns:
        ``n`` share byte strings, each ``len(secret) + 2`` bytes long.

    Raises:
        ValueError: if *k* or *n* is out of range, or *secret* is empty.
    """
    # k == 1 is rejected for two reasons. The polynomial would be the
    # constant secret byte, so every share would carry the secret in plain
    # text. Also, combine_shares refuses thresholds below 2, so such shares
    # could never be recombined.
    if not (2 <= k <= n <= 255):
        raise ValueError("Require 2 <= k <= n <= 255 for Shamir over GF(256)")
    if not secret:
        raise ValueError("Secret must be non-empty")
    shares: List[bytearray] = [bytearray([x, k]) for x in range(1, n + 1)]
    for byte in secret:
        # Fresh random coefficients for every byte, drawn from os.urandom.
        coeffs = [byte] + list(os.urandom(k - 1))
        for x, share in enumerate(shares, start=1):
            share.append(_eval_poly(coeffs, x))
    return [bytes(share) for share in shares]
[docs]
def combine_shares(shares: Sequence[bytes]) -> bytes:
    """Reconstruct the secret from shares in the ``[x][k][y...]`` layout.

    Share requirements:

    * all shares agree on length and on the threshold byte ``k``;
    * x-coordinates are nonzero and pairwise distinct;
    * at least ``k`` shares are supplied.

    Every supplied share takes part in the Lagrange interpolation at x = 0.

    Raises:
        ValueError: on an empty sequence, a share shorter than 3 bytes, a
            threshold byte below 2, inconsistent lengths or thresholds, or a
            zero or duplicated x-coordinate.
        InsufficientSharesError: if fewer than ``k`` shares are given.
    """
    if len(shares) < 1:
        raise ValueError("Need at least one share to read the threshold")
    if len(shares[0]) < 3:
        raise ValueError("Share too short — expected [x][k][y…] format")
    k_required = shares[0][1]
    if k_required < 2:
        raise ValueError(f"Share header encodes an invalid threshold ({k_required} < 2)")
    if len(shares) < k_required:
        raise InsufficientSharesError(
            f"Need at least {k_required} shares to reconstruct the secret "
            f"(only {len(shares)} provided)"
        )
    share_len = len(shares[0])
    xs: List[int] = []
    for s in shares:
        if len(s) != share_len:
            raise ValueError("Shares have inconsistent length")
        if s[0] == 0:
            raise ValueError("Share has invalid x-coordinate 0")
        if s[1] != k_required:
            raise ValueError(
                f"Shares have inconsistent threshold: expected {k_required}, got {s[1]}"
            )
        xs.append(s[0])
    if len(set(xs)) != len(xs):
        raise ValueError("Shares must have distinct x-coordinates")
    # The basis values L_i(0) depend only on the x-coordinates. Computing
    # them once avoids redoing O(m^2) field multiplications per secret byte.
    weights = _lagrange_weights_at_zero(xs)
    secret = bytearray(share_len - 2)
    for byte_idx in range(2, share_len):
        acc = 0
        for s, weight in zip(shares, weights):
            acc ^= gf_mul(s[byte_idx], weight)
        secret[byte_idx - 2] = acc
    return bytes(secret)


def _lagrange_weights_at_zero(xs: Sequence[int]) -> List[int]:
    """Return the Lagrange basis values ``L_i(0)`` for distinct nonzero *xs*.

    In GF(2^8) subtraction is XOR, so
    ``L_i(0) = prod_{j != i} xs[j] / prod_{j != i} (xs[j] ^ xs[i])``.
    """
    weights: List[int] = []
    for i, x_i in enumerate(xs):
        num = 1
        den = 1
        for j, x_j in enumerate(xs):
            if i != j:
                num = gf_mul(num, x_j)
                den = gf_mul(den, x_j ^ x_i)
        weights.append(gf_div(num, den))
    return weights
def _lagrange_at_zero(xs: Sequence[int], ys: Sequence[int]) -> int:
    """Interpolate the points ``(xs[i], ys[i])`` over GF(2^8); return f(0).

    Uses the Lagrange form with XOR as both addition and subtraction:
    ``f(0) = XOR_i ys[i] * prod_{j != i} xs[j] / prod_{j != i} (xs[j] ^ xs[i])``.
    A repeated x-coordinate makes a denominator 0, and gf_div then raises
    ZeroDivisionError.
    """
    total = 0
    for i, x_i in enumerate(xs):
        numerator = denominator = 1
        for j, x_j in enumerate(xs):
            if j == i:
                continue
            numerator = gf_mul(numerator, x_j)
            denominator = gf_mul(denominator, x_j ^ x_i)
        total ^= gf_mul(ys[i], gf_div(numerator, denominator))
    return total
[docs]
def encode_share(share: bytes) -> bytes:
    """Return *share* unchanged after checking it has at least 3 bytes.

    Three bytes is the minimum: x-coordinate, threshold and one y-byte. No
    copy or re-encoding is made; the very same object is returned.

    Raises:
        ValueError: if *share* is shorter than 3 bytes.
    """
    if len(share) >= 3:
        return share
    raise ValueError("share too short")
[docs]
def decode_share(buf: bytes) -> Tuple[int, int, bytes]:
    """Split a ``[x][k][y...]`` share into ``(x, k, y_bytes)``.

    ``y_bytes`` is ``buf[2:]``, so its type follows *buf*'s slicing
    (``bytes`` in, ``bytes`` out).

    Raises:
        ValueError: if *buf* is shorter than 3 bytes.
    """
    if len(buf) >= 3:
        x_coord, threshold = buf[0], buf[1]
        return x_coord, threshold, buf[2:]
    raise ValueError("share too short")