from __future__ import annotations
import logging
import struct as _struct
from dataclasses import dataclass
from typing import Optional
from argon2.low_level import Type, hash_secret_raw
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.hmac import HMAC as _CryptoHMAC
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from .fips import ban_if_fips
# Four-byte tags that _mix_factors hands to _frame_factor, one per
# authentication factor (password, keyfile, YubiKey response).
_TAG_PASSWORD = b"PWD0"
_TAG_KEYFILE = b"KFL0"
_TAG_YUBIKEY = b"YKR0"
# Largest factor payload _frame_factor accepts, in bytes (16 MiB).
_MAX_FACTOR_LEN = 16 * 1024 * 1024
# KDF selectors; derive_master_key compares KdfParams.kdf_id against these.
KDF_ARGON2ID = 0x02
KDF_PBKDF2 = 0x01
# Byte length of the key derive_master_key produces and hkdf_subkey requires.
MASTER_KEY_LEN = 32
# Argon2id settings returned by KdfParams.default_argon2id.
# The memory cost is expressed in KiB (65536 KiB = 64 MiB).
ARGON2_TIME_COST = 3
ARGON2_MEMORY_COST_KIB = 65536
ARGON2_PARALLELISM = 4
# Iteration count returned by KdfParams.default_pbkdf2.
PBKDF2_ITERATIONS = 600_000
# Byte-string labels that are not referenced by the code visible here.
# NOTE(review): presumably passed as the `info` argument of hkdf_subkey to
# derive one sub-key per purpose -- confirm against the callers.
HKDF_INFO_AES = b"stegx/v2/aes-256-gcm"
HKDF_INFO_CHACHA = b"stegx/v2/chacha20-poly1305"
HKDF_INFO_SEED = b"stegx/v2/pixel-shuffle-seed"
HKDF_INFO_SENTINEL = b"stegx/v2/sentinel"
HKDF_INFO_DECOY_SEED = b"stegx/v2/decoy-shuffle-seed"
[docs]
@dataclass(frozen=True)
class KdfParams:
    """Immutable selection of a password KDF together with its cost settings.

    ``derive_master_key`` reads ``time_cost``, ``memory_cost_kib`` and
    ``parallelism`` when ``kdf_id`` selects Argon2id, and ``iterations`` when
    it selects PBKDF2.  Fields that a factory does not set stay at zero.
    """

    # Selector compared against KDF_ARGON2ID / KDF_PBKDF2.
    kdf_id: int
    # Argon2id cost settings (memory expressed in KiB).
    time_cost: int = 0
    memory_cost_kib: int = 0
    parallelism: int = 0
    # PBKDF2 iteration count.
    iterations: int = 0

    @classmethod
    def default_argon2id(cls) -> "KdfParams":
        """Return Argon2id parameters built from the module-level defaults."""
        argon2_costs = {
            "time_cost": ARGON2_TIME_COST,
            "memory_cost_kib": ARGON2_MEMORY_COST_KIB,
            "parallelism": ARGON2_PARALLELISM,
        }
        return cls(kdf_id=KDF_ARGON2ID, **argon2_costs)

    @classmethod
    def default_pbkdf2(cls) -> "KdfParams":
        """Return PBKDF2 parameters using the module-level iteration count."""
        rounds = PBKDF2_ITERATIONS
        return cls(kdf_id=KDF_PBKDF2, iterations=rounds)
def _frame_factor(tag: bytes, data: bytes) -> bytes:
    """Serialize one factor as ``tag || big-endian uint32 length || data``.

    Args:
        tag: Identifier for the factor; must be exactly four bytes long.
        data: Factor payload, at most ``_MAX_FACTOR_LEN`` bytes.

    Returns:
        The tag, the payload length packed with ``"!I"``, then the payload.

    Raises:
        ValueError: If the tag is not four bytes or the payload is too large.
    """
    if len(tag) != 4:
        raise ValueError("Factor tag must be exactly 4 bytes.")

    payload_len = len(data)
    if payload_len > _MAX_FACTOR_LEN:
        # Tags are expected to be ASCII; anything else is shown as U+FFFD.
        tag_label = tag.decode("ascii", "replace")
        raise ValueError(
            f"Factor '{tag_label}' exceeds maximum "
            f"size of {_MAX_FACTOR_LEN} bytes."
        )

    length_prefix = _struct.pack("!I", payload_len)
    return tag + length_prefix + data
def _mix_factors(
    password: bytes,
    keyfile_bytes: Optional[bytes],
    yubikey_response: Optional[bytes] = None,
) -> bytes:
    """Concatenate the framed password, keyfile and YubiKey factors.

    Every factor is framed by ``_frame_factor`` in a fixed order, so all three
    frames are always present.  A keyfile or YubiKey response that is ``None``
    or empty is framed as a zero-length payload.

    Returns:
        The three frames joined in password, keyfile, YubiKey order.
    """
    tagged_factors = (
        (_TAG_PASSWORD, password),
        (_TAG_KEYFILE, keyfile_bytes if keyfile_bytes else b""),
        (_TAG_YUBIKEY, yubikey_response if yubikey_response else b""),
    )
    return b"".join(
        _frame_factor(tag, payload) for tag, payload in tagged_factors
    )
[docs]
def derive_master_key(
    password: str,
    salt: bytes,
    params: KdfParams,
    keyfile_bytes: Optional[bytes] = None,
    yubikey_response: Optional[bytes] = None,
    *,
    header_salt: Optional[bytes] = None,
) -> bytes:
    """Derive the ``MASTER_KEY_LEN``-byte master key from the supplied factors.

    The UTF-8 encoded password, the keyfile and the YubiKey response are
    combined by ``_mix_factors``.  When ``header_salt`` is given, the combined
    material is first passed through ``hkdf_extract`` keyed with that salt.
    The result is then stretched with the KDF selected by ``params.kdf_id``.

    Args:
        password: Non-empty passphrase.
        salt: Salt handed to Argon2id or PBKDF2.
        params: KDF selector and cost settings.
        keyfile_bytes: Optional keyfile contents.
        yubikey_response: Optional YubiKey response.
        header_salt: Optional salt for the ``hkdf_extract`` pre-processing step.

    Returns:
        The derived master key.

    Raises:
        ValueError: If ``password`` is empty or ``params.kdf_id`` is unknown.
    """
    if not password:
        raise ValueError("Password cannot be empty.")

    secret = _mix_factors(
        password.encode("utf-8"), keyfile_bytes, yubikey_response
    )
    if header_salt is not None:
        secret = hkdf_extract(salt=header_salt, ikm=secret)

    selected_kdf = params.kdf_id
    if selected_kdf == KDF_ARGON2ID:
        # NOTE(review): ban_if_fips presumably rejects Argon2id when FIPS
        # mode is active -- confirm in the .fips module.
        ban_if_fips("Argon2id KDF")
        argon2_settings = {
            "time_cost": params.time_cost,
            "memory_cost": params.memory_cost_kib,
            "parallelism": params.parallelism,
            "hash_len": MASTER_KEY_LEN,
            "type": Type.ID,
        }
        return hash_secret_raw(secret=secret, salt=salt, **argon2_settings)
    elif selected_kdf == KDF_PBKDF2:
        return PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=MASTER_KEY_LEN,
            salt=salt,
            iterations=params.iterations,
        ).derive(secret)

    raise ValueError(f"Unknown KDF id: 0x{selected_kdf:02x}")
[docs]
def hkdf_subkey(master_key: bytes, info: bytes, length: int = 32) -> bytes:
    """Expand the master key into a purpose-specific sub-key.

    Runs HKDF-Expand (SHA-256) over ``master_key`` with ``info`` as the
    context label; no extract step is performed here.

    Args:
        master_key: Key of exactly ``MASTER_KEY_LEN`` bytes.
        info: Context label that separates one sub-key from another.
        length: Number of output bytes.

    Returns:
        ``length`` bytes of derived key material.

    Raises:
        ValueError: If ``master_key`` does not have the required length.
    """
    supplied_len = len(master_key)
    if supplied_len == MASTER_KEY_LEN:
        return HKDFExpand(
            algorithm=hashes.SHA256(), length=length, info=info
        ).derive(master_key)

    raise ValueError(
        f"Master key must be exactly {MASTER_KEY_LEN} bytes "
        f"(got {supplied_len})."
    )
[docs]
def seed_int_from_subkey(subkey: bytes) -> int:
    """Interpret the first eight bytes of *subkey* as a big-endian integer.

    Args:
        subkey: Key material at least eight bytes long; extra bytes are ignored.

    Returns:
        An unsigned integer in the range ``0 .. 2**64 - 1``.

    Raises:
        ValueError: If fewer than eight bytes are supplied.
    """
    seed_width = 8
    if len(subkey) >= seed_width:
        leading = subkey[:seed_width]
        return int.from_bytes(leading, byteorder="big")
    raise ValueError("Sub-key too short to derive seed.")
[docs]
def derive_legacy_seed_from_password(password: str) -> int:
    """Derive the legacy 64-bit seed directly from a password.

    Uses PBKDF2-HMAC-SHA256 with the fixed salt ``b"stegx_pixel_shuffle_v1"``
    and 390,000 iterations, then reads the eight output bytes as a big-endian
    integer.  The salt and iteration count are hard-coded, so they must not be
    changed if previously derived seeds are to remain reproducible.

    Args:
        password: Passphrase; encoded as UTF-8 before derivation.

    Returns:
        An unsigned integer in the range ``0 .. 2**64 - 1``.
    """
    seed_bytes = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=8,
        salt=b"stegx_pixel_shuffle_v1",
        iterations=390_000,
    ).derive(password.encode("utf-8"))
    return int.from_bytes(seed_bytes, byteorder="big")
[docs]
def calibrate_argon2_for_target_ms(target_ms: int = 500) -> KdfParams:
    """Choose an Argon2id memory cost whose hash time reaches *target_ms*.

    One Argon2id hash of a fixed dummy secret is timed at 32, 64, 128 and
    256 MiB in ascending order, using the default time cost and parallelism.
    The first memory cost whose wall-clock time is at least *target_ms*
    milliseconds is selected.  If none of the probed costs reaches the
    target, the largest probed cost is selected.

    Args:
        target_ms: Desired minimum duration of a single hash, in milliseconds.

    Returns:
        Argon2id ``KdfParams`` carrying the selected ``memory_cost_kib`` and
        the default time cost and parallelism.
    """
    import time

    # NOTE(review): ban_if_fips presumably rejects Argon2id when FIPS mode is
    # active -- confirm in the .fips module.
    ban_if_fips("Argon2id calibration")

    defaults = KdfParams.default_argon2id()
    probe_salt = b"\x00" * 16
    probe_sizes_kib = (32_768, 65_536, 131_072, 262_144)

    # Falling back to the 64 MiB default when the target is never reached
    # would hand a fast host weaker parameters than a slower host that hit
    # the target at 128 MiB, so the fallback is the largest probed size.
    chosen_kib = probe_sizes_kib[-1]
    for memory_kib in probe_sizes_kib:
        started = time.perf_counter()
        hash_secret_raw(
            secret=b"calibration",
            salt=probe_salt,
            time_cost=defaults.time_cost,
            memory_cost=memory_kib,
            parallelism=defaults.parallelism,
            # Same output length as derive_master_key, so the probe measures
            # the hash that is actually run in production.
            hash_len=MASTER_KEY_LEN,
            type=Type.ID,
        )
        elapsed_ms = (time.perf_counter() - started) * 1000
        logging.debug(
            "Argon2id calibration: memory=%d KiB took %.1f ms",
            memory_kib,
            elapsed_ms,
        )
        if elapsed_ms >= target_ms:
            chosen_kib = memory_kib
            break

    return KdfParams(
        kdf_id=KDF_ARGON2ID,
        time_cost=defaults.time_cost,
        memory_cost_kib=chosen_kib,
        parallelism=defaults.parallelism,
    )