SAE: Avoid branches in is_quadratic_residue_blind()

Make the non-failure path in the function proceed without branches based
on r_odd and in constant time to minimize risk of observable differences
in timing or cache use. (CVE-2019-9494)

Signed-off-by: Jouni Malinen <jouni@codeaurora.org>
This commit is contained in:
Jouni Malinen 2019-02-26 19:34:38 +02:00 committed by Jouni Malinen
parent 6513db3e96
commit 362704dda0

View file

@ -232,12 +232,14 @@ get_rand_1_to_p_1(const u8 *prime, size_t prime_len, size_t prime_bits,
static int is_quadratic_residue_blind(struct sae_data *sae, static int is_quadratic_residue_blind(struct sae_data *sae,
const u8 *prime, size_t bits, const u8 *prime, size_t bits,
const struct crypto_bignum *qr, const u8 *qr, const u8 *qnr,
const struct crypto_bignum *qnr,
const struct crypto_bignum *y_sqr) const struct crypto_bignum *y_sqr)
{ {
struct crypto_bignum *r, *num; struct crypto_bignum *r, *num, *qr_or_qnr = NULL;
int r_odd, check, res = -1; int r_odd, check, res = -1;
u8 qr_or_qnr_bin[SAE_MAX_ECC_PRIME_LEN];
size_t prime_len = sae->tmp->prime_len;
unsigned int mask;
/* /*
* Use the blinding technique to mask y_sqr while determining * Use the blinding technique to mask y_sqr while determining
@ -248,7 +250,7 @@ static int is_quadratic_residue_blind(struct sae_data *sae,
* r = a random number between 1 and p-1, inclusive * r = a random number between 1 and p-1, inclusive
* num = (v * r * r) modulo p * num = (v * r * r) modulo p
*/ */
r = get_rand_1_to_p_1(prime, sae->tmp->prime_len, bits, &r_odd); r = get_rand_1_to_p_1(prime, prime_len, bits, &r_odd);
if (!r) if (!r)
return -1; return -1;
@ -258,41 +260,45 @@ static int is_quadratic_residue_blind(struct sae_data *sae,
crypto_bignum_mulmod(num, r, sae->tmp->prime, num) < 0) crypto_bignum_mulmod(num, r, sae->tmp->prime, num) < 0)
goto fail; goto fail;
if (r_odd) {
/* /*
* Need to minimize differences in handling different cases, so try to
* avoid branches and timing differences.
*
* If r_odd:
* num = (num * qr) module p * num = (num * qr) module p
* LGR(num, p) = 1 ==> quadratic residue * LGR(num, p) = 1 ==> quadratic residue
*/ * else:
if (crypto_bignum_mulmod(num, qr, sae->tmp->prime, num) < 0)
goto fail;
check = 1;
} else {
/*
* num = (num * qnr) module p * num = (num * qnr) module p
* LGR(num, p) = -1 ==> quadratic residue * LGR(num, p) = -1 ==> quadratic residue
*/ */
if (crypto_bignum_mulmod(num, qnr, sae->tmp->prime, num) < 0) mask = const_time_is_zero(r_odd);
const_time_select_bin(mask, qnr, qr, prime_len, qr_or_qnr_bin);
qr_or_qnr = crypto_bignum_init_set(qr_or_qnr_bin, prime_len);
if (!qr_or_qnr ||
crypto_bignum_mulmod(num, qr_or_qnr, sae->tmp->prime, num) < 0)
goto fail; goto fail;
check = -1; /* r_odd is 0 or 1; branchless version of check = r_odd ? 1 : -1, */
} check = const_time_select_int(mask, -1, 1);
res = crypto_bignum_legendre(num, sae->tmp->prime); res = crypto_bignum_legendre(num, sae->tmp->prime);
if (res == -2) { if (res == -2) {
res = -1; res = -1;
goto fail; goto fail;
} }
res = res == check; /* branchless version of res = res == check
* (res is -1, 0, or 1; check is -1 or 1) */
mask = const_time_eq(res, check);
res = const_time_select_int(mask, 1, 0);
fail: fail:
crypto_bignum_deinit(num, 1); crypto_bignum_deinit(num, 1);
crypto_bignum_deinit(r, 1); crypto_bignum_deinit(r, 1);
crypto_bignum_deinit(qr_or_qnr, 1);
return res; return res;
} }
static int sae_test_pwd_seed_ecc(struct sae_data *sae, const u8 *pwd_seed, static int sae_test_pwd_seed_ecc(struct sae_data *sae, const u8 *pwd_seed,
const u8 *prime, const u8 *prime, const u8 *qr, const u8 *qnr,
const struct crypto_bignum *qr,
const struct crypto_bignum *qnr,
u8 *pwd_value) u8 *pwd_value)
{ {
struct crypto_bignum *y_sqr, *x_cand; struct crypto_bignum *y_sqr, *x_cand;
@ -452,6 +458,8 @@ static int sae_derive_pwe_ecc(struct sae_data *sae, const u8 *addr1,
struct crypto_bignum *x = NULL, *qr = NULL, *qnr = NULL; struct crypto_bignum *x = NULL, *qr = NULL, *qnr = NULL;
u8 x_bin[SAE_MAX_ECC_PRIME_LEN]; u8 x_bin[SAE_MAX_ECC_PRIME_LEN];
u8 x_cand_bin[SAE_MAX_ECC_PRIME_LEN]; u8 x_cand_bin[SAE_MAX_ECC_PRIME_LEN];
u8 qr_bin[SAE_MAX_ECC_PRIME_LEN];
u8 qnr_bin[SAE_MAX_ECC_PRIME_LEN];
size_t bits; size_t bits;
int res = -1; int res = -1;
u8 found = 0; /* 0 (false) or 0xff (true) to be used as const_time_* u8 found = 0; /* 0 (false) or 0xff (true) to be used as const_time_*
@ -476,7 +484,9 @@ static int sae_derive_pwe_ecc(struct sae_data *sae, const u8 *addr1,
* (qnr) modulo p for blinding purposes during the loop. * (qnr) modulo p for blinding purposes during the loop.
*/ */
if (get_random_qr_qnr(prime, prime_len, sae->tmp->prime, bits, if (get_random_qr_qnr(prime, prime_len, sae->tmp->prime, bits,
&qr, &qnr) < 0) &qr, &qnr) < 0 ||
crypto_bignum_to_bin(qr, qr_bin, sizeof(qr_bin), prime_len) < 0 ||
crypto_bignum_to_bin(qnr, qnr_bin, sizeof(qnr_bin), prime_len) < 0)
goto fail; goto fail;
wpa_hexdump_ascii_key(MSG_DEBUG, "SAE: password", wpa_hexdump_ascii_key(MSG_DEBUG, "SAE: password",
@ -527,7 +537,7 @@ static int sae_derive_pwe_ecc(struct sae_data *sae, const u8 *addr1,
break; break;
res = sae_test_pwd_seed_ecc(sae, pwd_seed, res = sae_test_pwd_seed_ecc(sae, pwd_seed,
prime, qr, qnr, x_cand_bin); prime, qr_bin, qnr_bin, x_cand_bin);
const_time_select_bin(found, x_bin, x_cand_bin, prime_len, const_time_select_bin(found, x_bin, x_cand_bin, prime_len,
x_bin); x_bin);
pwd_seed_odd = const_time_select_u8( pwd_seed_odd = const_time_select_u8(