diff --git a/src/common/sae.c b/src/common/sae.c index 49b9e9047..808dca045 100644 --- a/src/common/sae.c +++ b/src/common/sae.c @@ -12,6 +12,8 @@ #include "utils/const_time.h" #include "crypto/crypto.h" #include "crypto/sha256.h" +#include "crypto/sha384.h" +#include "crypto/sha512.h" #include "crypto/random.h" #include "crypto/dh_groups.h" #include "ieee802_11_defs.h" @@ -107,6 +109,7 @@ void sae_clear_temp_data(struct sae_data *sae) crypto_ec_point_deinit(tmp->own_commit_element_ecc, 0); crypto_ec_point_deinit(tmp->peer_commit_element_ecc, 0); wpabuf_free(tmp->anti_clogging_token); + wpabuf_free(tmp->own_rejected_groups); wpabuf_free(tmp->peer_rejected_groups); os_free(tmp->pw_id); bin_clear_free(tmp, sizeof(*tmp)); @@ -526,6 +529,745 @@ fail: } +static int hkdf_extract(size_t hash_len, const u8 *salt, size_t salt_len, + size_t num_elem, const u8 *addr[], const size_t len[], + u8 *prk) +{ + if (hash_len == 32) + return hmac_sha256_vector(salt, salt_len, num_elem, addr, len, + prk); +#ifdef CONFIG_SHA384 + if (hash_len == 48) + return hmac_sha384_vector(salt, salt_len, num_elem, addr, len, + prk); +#endif /* CONFIG_SHA384 */ +#ifdef CONFIG_SHA512 + if (hash_len == 64) + return hmac_sha512_vector(salt, salt_len, num_elem, addr, len, + prk); +#endif /* CONFIG_SHA512 */ + return -1; +} + + +static int hkdf_expand(size_t hash_len, const u8 *prk, size_t prk_len, + const char *info, u8 *okm, size_t okm_len) +{ + size_t info_len = os_strlen(info); + + if (hash_len == 32) + return hmac_sha256_kdf(prk, prk_len, NULL, + (const u8 *) info, info_len, + okm, okm_len); +#ifdef CONFIG_SHA384 + if (hash_len == 48) + return hmac_sha384_kdf(prk, prk_len, NULL, + (const u8 *) info, info_len, + okm, okm_len); +#endif /* CONFIG_SHA384 */ +#ifdef CONFIG_SHA512 + if (hash_len == 64) + return hmac_sha512_kdf(prk, prk_len, NULL, + (const u8 *) info, info_len, + okm, okm_len); +#endif /* CONFIG_SHA512 */ + return -1; +} + + +static int sswu_curve_param(int group, int *z) +{ + switch (group) { + case 19: + case 20: + case 21: + case 28: + *z = -2; + return 0; + case 25: + case 29: + *z = -5; + return 0; + case 26: + *z = -11; + return 0; + case 30: + *z = 2; + return 0; + } + + return -1; +} + + +static void debug_print_bignum(const char *title, const struct crypto_bignum *a, + size_t prime_len) +{ + u8 *bin; + + bin = os_malloc(prime_len); + if (bin && crypto_bignum_to_bin(a, bin, prime_len, prime_len) >= 0) + wpa_hexdump_key(MSG_DEBUG, title, bin, prime_len); + else + wpa_printf(MSG_DEBUG, "Could not print bignum (%s)", title); + bin_clear_free(bin, prime_len); +} + + +static struct crypto_ec_point * sswu(struct crypto_ec *ec, int group, + const struct crypto_bignum *u) +{ + int z_int; + const struct crypto_bignum *a, *b, *prime; + struct crypto_bignum *u2, *t1, *t2, *z, *t, *zero, *one, *two, *three, + *x1a, *x1b, *y = NULL; + struct crypto_bignum *x1 = NULL, *x2, *gx1, *gx2, *v = NULL; + unsigned int m_is_zero, is_qr, is_eq; + size_t prime_len; + u8 bin[SAE_MAX_ECC_PRIME_LEN]; + u8 bin1[SAE_MAX_ECC_PRIME_LEN]; + u8 bin2[SAE_MAX_ECC_PRIME_LEN]; + u8 x_y[2 * SAE_MAX_ECC_PRIME_LEN]; + struct crypto_ec_point *p = NULL; + + if (sswu_curve_param(group, &z_int) < 0) + return NULL; + + prime = crypto_ec_get_prime(ec); + prime_len = crypto_ec_prime_len(ec); + a = crypto_ec_get_a(ec); + b = crypto_ec_get_b(ec); + + u2 = crypto_bignum_init(); + t1 = crypto_bignum_init(); + t2 = crypto_bignum_init(); + z = crypto_bignum_init_uint(abs(z_int)); + t = crypto_bignum_init(); + zero = crypto_bignum_init_uint(0); + one = crypto_bignum_init_uint(1); + two = crypto_bignum_init_uint(2); + three = crypto_bignum_init_uint(3); + x1a = crypto_bignum_init(); + x1b = crypto_bignum_init(); + x2 = crypto_bignum_init(); + gx1 = crypto_bignum_init(); + gx2 = crypto_bignum_init(); + if (!u2 || !t1 || !t2 || !z || !t || !zero || !one || !two || !three || + !x1a || !x1b || !x2 || !gx1 || !gx2) + goto fail; + + if (z_int < 0 && crypto_bignum_sub(prime, z, z) < 0) + goto fail; + + /* m = z^2 * u^4 + z * u^2 */ + /* --> tmp = z * u^2, m = tmp^2 + tmp */ + + /* u2 = u^2 + * t1 = z * u2 + * t2 = t1^2 + * m = t1 = t1 + t2 */ + if (crypto_bignum_sqrmod(u, prime, u2) < 0 || + crypto_bignum_mulmod(z, u2, prime, t1) < 0 || + crypto_bignum_sqrmod(t1, prime, t2) < 0 || + crypto_bignum_addmod(t1, t2, prime, t1) < 0) + goto fail; + debug_print_bignum("SSWU: m", t1, prime_len); + + /* l = CEQ(m, 0) + * t = CSEL(l, 0, inverse(m); where inverse(x) is calculated as + * x^(p-2) modulo p which will handle m == 0 case correctly */ + /* TODO: Make sure crypto_bignum_is_zero() is constant time */ + m_is_zero = const_time_eq(crypto_bignum_is_zero(t1), 1); + /* t = m^(p-2) modulo p */ + if (crypto_bignum_sub(prime, two, t2) < 0 || + crypto_bignum_exptmod(t1, t2, prime, t) < 0) + goto fail; + debug_print_bignum("SSWU: t", t, prime_len); + + /* b / (z * a) */ + if (crypto_bignum_mulmod(z, a, prime, t1) < 0 || + crypto_bignum_inverse(t1, prime, t1) < 0 || + crypto_bignum_mulmod(b, t1, prime, x1a) < 0) + goto fail; + debug_print_bignum("SSWU: x1a = b / (z * a)", x1a, prime_len); + + /* (-b/a) * (1 + t) */ + if (crypto_bignum_sub(prime, b, t1) < 0 || + crypto_bignum_inverse(a, prime, t2) < 0 || + crypto_bignum_mulmod(t1, t2, prime, t1) < 0 || + crypto_bignum_addmod(one, t, prime, t2) < 0 || + crypto_bignum_mulmod(t1, t2, prime, x1b) < 0) + goto fail; + debug_print_bignum("SSWU: x1b = (-b/a) * (1 + t)", x1b, prime_len); + + /* x1 = CSEL(CEQ(m, 0), x1a, x1b) */ + if (crypto_bignum_to_bin(x1a, bin1, sizeof(bin1), prime_len) < 0 || + crypto_bignum_to_bin(x1b, bin2, sizeof(bin2), prime_len) < 0) + goto fail; + const_time_select_bin(m_is_zero, bin1, bin2, prime_len, bin); + x1 = crypto_bignum_init_set(bin, prime_len); + debug_print_bignum("SSWU: x1 = CSEL(l, x1a, x1b)", x1, prime_len); + + /* gx1 = x1^3 + a * x1 + b */ + if (crypto_bignum_exptmod(x1, three, prime, t1) < 0 || + crypto_bignum_mulmod(a, x1, prime, t2) < 0 || + crypto_bignum_addmod(t1, t2, prime, t1) < 0 || + crypto_bignum_addmod(t1, b, prime, gx1) < 0) + goto fail; + debug_print_bignum("SSWU: gx1 = x1^3 + a * x1 + b", gx1, prime_len); + + /* x2 = z * u^2 * x1 */ + if (crypto_bignum_mulmod(z, u2, prime, t1) < 0 || + crypto_bignum_mulmod(t1, x1, prime, x2) < 0) + goto fail; + debug_print_bignum("SSWU: x2 = z * u^2 * x1", x2, prime_len); + + /* gx2 = x2^3 + a * x2 + b */ + if (crypto_bignum_exptmod(x2, three, prime, t1) < 0 || + crypto_bignum_mulmod(a, x2, prime, t2) < 0 || + crypto_bignum_addmod(t1, t2, prime, t1) < 0 || + crypto_bignum_addmod(t1, b, prime, gx2) < 0) + goto fail; + debug_print_bignum("SSWU: gx2 = x2^3 + a * x2 + b", gx2, prime_len); + + /* l = gx1 is a quadratic residue modulo p + * --> gx1^((p-1)/2) modulo p is zero or one */ + if (crypto_bignum_sub(prime, one, t1) < 0 || + crypto_bignum_rshift(t1, 1, t1) < 0 || + crypto_bignum_exptmod(gx1, t1, prime, t1) < 0) + goto fail; + debug_print_bignum("SSWU: gx1^((p-1)/2) modulo p", t1, prime_len); + is_qr = const_time_eq(crypto_bignum_is_zero(t1) | + crypto_bignum_is_one(t1), 1); + + /* v = CSEL(l, gx1, gx2) */ + if (crypto_bignum_to_bin(gx1, bin1, sizeof(bin1), prime_len) < 0 || + crypto_bignum_to_bin(gx2, bin2, sizeof(bin2), prime_len) < 0) + goto fail; + const_time_select_bin(is_qr, bin1, bin2, prime_len, bin); + v = crypto_bignum_init_set(bin, prime_len); + debug_print_bignum("SSWU: v = CSEL(l, gx1, gx2)", v, prime_len); + + /* x = CSEL(l, x1, x2) */ + if (crypto_bignum_to_bin(x1, bin1, sizeof(bin1), prime_len) < 0 || + crypto_bignum_to_bin(x2, bin2, sizeof(bin2), prime_len) < 0) + goto fail; + const_time_select_bin(is_qr, bin1, bin2, prime_len, x_y); + wpa_hexdump_key(MSG_DEBUG, "SSWU: x = CSEL(l, x1, x2)", x_y, prime_len); + + /* y = sqrt(v) */ + y = crypto_bignum_init(); + /* TODO: Remove p = 3 mod 4 check and disable group 26 instead(?) */ + if (crypto_bignum_to_bin(prime, bin1, sizeof(bin1), prime_len) < 0) + goto fail; + if ((bin1[prime_len - 1] & 0x03) == 3) { + /* For prime p such that p = 3 mod 4 --> v^((p+1)/4) */ + if (!y || + crypto_bignum_add(prime, one, t1) < 0 || + crypto_bignum_rshift(t1, 2, t1) < 0 || + crypto_bignum_exptmod(v, t1, prime, y) < 0) + goto fail; + } else { + wpa_printf(MSG_DEBUG, "SSWU: prime does not have p = 3 mod 4"); + if (!y || crypto_bignum_sqrtmod(v, prime, y) < 0) + goto fail; + } + debug_print_bignum("SSWU: y = sqrt(v)", y, prime_len); + + /* l = CEQ(LSB(u), LSB(y)) */ + if (crypto_bignum_to_bin(u, bin1, sizeof(bin1), prime_len) < 0 || + crypto_bignum_to_bin(y, bin2, sizeof(bin2), prime_len) < 0) + goto fail; + is_eq = const_time_eq(bin1[prime_len - 1] & 0x01, + bin2[prime_len - 1] & 0x01); + + /* P = CSEL(l, (x,y), (x, p-y)) */ + if (crypto_bignum_sub(prime, y, t1) < 0) + goto fail; + debug_print_bignum("SSWU: p - y", t1, prime_len); + if (crypto_bignum_to_bin(y, bin1, sizeof(bin1), prime_len) < 0 || + crypto_bignum_to_bin(t1, bin2, sizeof(bin2), prime_len) < 0) + goto fail; + const_time_select_bin(is_eq, bin1, bin2, prime_len, &x_y[prime_len]); + + /* output P */ + wpa_hexdump_key(MSG_DEBUG, "SSWU: P.x", x_y, prime_len); + wpa_hexdump_key(MSG_DEBUG, "SSWU: P.y", &x_y[prime_len], prime_len); + p = crypto_ec_point_from_bin(ec, x_y); + +fail: + crypto_bignum_deinit(u2, 1); + crypto_bignum_deinit(t1, 1); + crypto_bignum_deinit(t2, 1); + crypto_bignum_deinit(z, 0); + crypto_bignum_deinit(t, 1); + crypto_bignum_deinit(x1a, 1); + crypto_bignum_deinit(x1b, 1); + crypto_bignum_deinit(x1, 1); + crypto_bignum_deinit(x2, 1); + crypto_bignum_deinit(gx1, 1); + crypto_bignum_deinit(gx2, 1); + crypto_bignum_deinit(y, 1); + crypto_bignum_deinit(v, 1); + crypto_bignum_deinit(zero, 0); + crypto_bignum_deinit(one, 0); + crypto_bignum_deinit(two, 0); + crypto_bignum_deinit(three, 0); + forced_memzero(bin, sizeof(bin)); + forced_memzero(bin1, sizeof(bin1)); + forced_memzero(bin2, sizeof(bin2)); + forced_memzero(x_y, sizeof(x_y)); + return p; +} + + +static int sae_pwd_seed(size_t hash_len, const u8 *ssid, size_t ssid_len, + const u8 *password, size_t password_len, + const char *identifier, u8 *pwd_seed) +{ + const u8 *addr[2]; + size_t len[2]; + size_t num_elem; + + /* pwd-seed = HKDF-Extract(ssid, password [ || identifier ]) */ + addr[0] = password; + len[0] = password_len; + num_elem = 1; + wpa_hexdump_ascii(MSG_DEBUG, "SAE: SSID", ssid, ssid_len); + wpa_hexdump_ascii_key(MSG_DEBUG, "SAE: password", + password, password_len); + if (identifier) { + wpa_printf(MSG_DEBUG, "SAE: password identifier: %s", + identifier); + addr[num_elem] = (const u8 *) identifier; + len[num_elem] = os_strlen(identifier); + num_elem++; + } + if (hkdf_extract(hash_len, ssid, ssid_len, num_elem, addr, len, + pwd_seed) < 0) + return -1; + wpa_hexdump_key(MSG_DEBUG, "SAE: pwd-seed", pwd_seed, hash_len); + return 0; +} + + +size_t sae_ecc_prime_len_2_hash_len(size_t prime_len) +{ + if (prime_len <= 256 / 8) + return 32; + if (prime_len <= 384 / 8) + return 48; + return 64; +} + + +struct crypto_ec_point * +sae_derive_pt_ecc(struct crypto_ec *ec, int group, + const u8 *ssid, size_t ssid_len, + const u8 *password, size_t password_len, + const char *identifier) +{ + u8 pwd_seed[64]; + u8 pwd_value[SAE_MAX_ECC_PRIME_LEN * 2]; + size_t pwd_value_len, hash_len, prime_len; + const struct crypto_bignum *prime; + struct crypto_bignum *bn = NULL; + struct crypto_ec_point *p1 = NULL, *p2 = NULL, *pt = NULL; + + prime = crypto_ec_get_prime(ec); + prime_len = crypto_ec_prime_len(ec); + if (prime_len > SAE_MAX_ECC_PRIME_LEN) + goto fail; + hash_len = sae_ecc_prime_len_2_hash_len(prime_len); + + /* len = olen(p) + ceil(olen(p)/2) */ + pwd_value_len = prime_len + (prime_len + 1) / 2; + + if (sae_pwd_seed(hash_len, ssid, ssid_len, password, password_len, + identifier, pwd_seed) < 0) + goto fail; + + /* pwd-value = HKDF-Expand(pwd-seed, "SAE Hash to Element u1 P1", len) + */ + if (hkdf_expand(hash_len, pwd_seed, hash_len, + "SAE Hash to Element u1 P1", pwd_value, pwd_value_len) < + 0) + goto fail; + wpa_hexdump_key(MSG_DEBUG, "SAE: pwd-value (u1 P1)", + pwd_value, pwd_value_len); + + /* u1 = pwd-value modulo p */ + bn = crypto_bignum_init_set(pwd_value, pwd_value_len); + if (!bn || crypto_bignum_mod(bn, prime, bn) < 0 || + crypto_bignum_to_bin(bn, pwd_value, sizeof(pwd_value), + prime_len) < 0) + goto fail; + wpa_hexdump_key(MSG_DEBUG, "SAE: u1", pwd_value, prime_len); + + /* P1 = SSWU(u1) */ + p1 = sswu(ec, group, bn); + if (!p1) + goto fail; + + /* pwd-value = HKDF-Expand(pwd-seed, "SAE Hash to Element u2 P2", len) + */ + if (hkdf_expand(hash_len, pwd_seed, hash_len, + "SAE Hash to Element u2 P2", pwd_value, + pwd_value_len) < 0) + goto fail; + wpa_hexdump_key(MSG_DEBUG, "SAE: pwd-value (u2 P2)", + pwd_value, pwd_value_len); + + /* u2 = pwd-value modulo p */ + crypto_bignum_deinit(bn, 1); + bn = crypto_bignum_init_set(pwd_value, pwd_value_len); + if (!bn || crypto_bignum_mod(bn, prime, bn) < 0 || + crypto_bignum_to_bin(bn, pwd_value, sizeof(pwd_value), + prime_len) < 0) + goto fail; + wpa_hexdump_key(MSG_DEBUG, "SAE: u2", pwd_value, prime_len); + + /* P2 = SSWU(u2) */ + p2 = sswu(ec, group, bn); + if (!p2) + goto fail; + + /* PT = elem-op(P1, P2) */ + pt = crypto_ec_point_init(ec); + if (!pt) + goto fail; + if (crypto_ec_point_add(ec, p1, p2, pt) < 0) { + crypto_ec_point_deinit(pt, 1); + pt = NULL; + } + +fail: + forced_memzero(pwd_seed, sizeof(pwd_seed)); + forced_memzero(pwd_value, sizeof(pwd_value)); + crypto_bignum_deinit(bn, 1); + crypto_ec_point_deinit(p1, 1); + crypto_ec_point_deinit(p2, 1); + return pt; +} + + +size_t sae_ffc_prime_len_2_hash_len(size_t prime_len) +{ + if (prime_len <= 2048 / 8) + return 32; + if (prime_len <= 3072 / 8) + return 48; + return 64; +} + + +static struct crypto_bignum * +sae_derive_pt_ffc(const struct dh_group *dh, int group, + const u8 *ssid, size_t ssid_len, + const u8 *password, size_t password_len, + const char *identifier) +{ + size_t hash_len, prime_len, pwd_value_len; + struct crypto_bignum *prime, *order; + struct crypto_bignum *one = NULL, *two = NULL, *bn = NULL, *tmp = NULL, + *pt = NULL; + u8 pwd_seed[64]; + u8 pwd_value[SAE_MAX_PRIME_LEN + SAE_MAX_PRIME_LEN / 2]; + + prime = crypto_bignum_init_set(dh->prime, dh->prime_len); + order = crypto_bignum_init_set(dh->order, dh->order_len); + if (!prime || !order) + goto fail; + prime_len = dh->prime_len; + if (prime_len > SAE_MAX_PRIME_LEN) + goto fail; + hash_len = sae_ffc_prime_len_2_hash_len(prime_len); + + /* len = olen(p) + ceil(olen(p)/2) */ + pwd_value_len = prime_len + (prime_len + 1) / 2; + if (pwd_value_len > sizeof(pwd_value)) + goto fail; + + if (sae_pwd_seed(hash_len, ssid, ssid_len, password, password_len, + identifier, pwd_seed) < 0) + goto fail; + + /* pwd-value = HKDF-Expand(pwd-seed, "SAE Hash to Element", len) */ + if (hkdf_expand(hash_len, pwd_seed, hash_len, + "SAE Hash to Element", pwd_value, pwd_value_len) < 0) + goto fail; + wpa_hexdump_key(MSG_DEBUG, "SAE: pwd-value", + pwd_value, pwd_value_len); + + /* pwd-value = (pwd-value modulo (p-2)) + 2 */ + bn = crypto_bignum_init_set(pwd_value, pwd_value_len); + one = crypto_bignum_init_uint(1); + two = crypto_bignum_init_uint(2); + tmp = crypto_bignum_init(); + if (!bn || !one || !two || !tmp || + crypto_bignum_sub(prime, two, tmp) < 0 || + crypto_bignum_mod(bn, tmp, bn) < 0 || + crypto_bignum_add(bn, two, bn) < 0 || + crypto_bignum_to_bin(bn, pwd_value, sizeof(pwd_value), + prime_len) < 0) + goto fail; + wpa_hexdump_key(MSG_DEBUG, "SAE: pwd-value(reduced)", + pwd_value, prime_len); + + /* PT = pwd-value^((p-1)/q) modulo p */ + pt = crypto_bignum_init(); + if (!pt || + crypto_bignum_sub(prime, one, tmp) < 0 || + crypto_bignum_div(tmp, order, tmp) < 0 || + crypto_bignum_exptmod(bn, tmp, prime, pt) < 0) { + crypto_bignum_deinit(pt, 1); + pt = NULL; + goto fail; + } + debug_print_bignum("SAE: PT", pt, prime_len); + +fail: + forced_memzero(pwd_seed, sizeof(pwd_seed)); + forced_memzero(pwd_value, sizeof(pwd_value)); + crypto_bignum_deinit(bn, 1); + crypto_bignum_deinit(tmp, 1); + crypto_bignum_deinit(one, 0); + crypto_bignum_deinit(two, 0); + crypto_bignum_deinit(prime, 0); + crypto_bignum_deinit(order, 0); + return pt; +} + + +static struct sae_pt * +sae_derive_pt_group(int group, const u8 *ssid, size_t ssid_len, + const u8 *password, size_t password_len, + const char *identifier) +{ + struct sae_pt *pt; + + wpa_printf(MSG_DEBUG, "SAE: Derive PT - group %d", group); + + pt = os_zalloc(sizeof(*pt)); + if (!pt) + return NULL; + + pt->group = group; + pt->ec = crypto_ec_init(group); + if (pt->ec) { + pt->ecc_pt = sae_derive_pt_ecc(pt->ec, group, ssid, ssid_len, + password, password_len, + identifier); + if (!pt->ecc_pt) { + wpa_printf(MSG_DEBUG, "SAE: Failed to derive PT"); + goto fail; + } + + return pt; + } + + pt->dh = dh_groups_get(group); + if (!pt->dh) { + wpa_printf(MSG_DEBUG, "SAE: Unsupported group %d", group); + goto fail; + } + + pt->ffc_pt = sae_derive_pt_ffc(pt->dh, group, ssid, ssid_len, + password, password_len, identifier); + if (!pt->ffc_pt) { + wpa_printf(MSG_DEBUG, "SAE: Failed to derive PT"); + goto fail; + } + + return pt; +fail: + sae_deinit_pt(pt); + return NULL; +} + + +struct sae_pt * sae_derive_pt(int *groups, const u8 *ssid, size_t ssid_len, + const u8 *password, size_t password_len, + const char *identifier) +{ + struct sae_pt *pt = NULL, *last = NULL, *tmp; + int default_groups[] = { 19, 0 }; + int i; + + if (!groups) + groups = default_groups; + for (i = 0; groups[i] > 0; i++) { + tmp = sae_derive_pt_group(groups[i], ssid, ssid_len, password, + password_len, identifier); + if (!tmp) + continue; + + if (last) + last->next = tmp; + else + pt = tmp; + last = tmp; + } + + return pt; +} + + +static void sae_max_min_addr(const u8 *addr[], size_t len[], + const u8 *addr1, const u8 *addr2) +{ + len[0] = ETH_ALEN; + len[1] = ETH_ALEN; + if (os_memcmp(addr1, addr2, ETH_ALEN) > 0) { + addr[0] = addr1; + addr[1] = addr2; + } else { + addr[0] = addr2; + addr[1] = addr1; + } +} + + +struct crypto_ec_point * +sae_derive_pwe_from_pt_ecc(const struct sae_pt *pt, + const u8 *addr1, const u8 *addr2) +{ + u8 bin[SAE_MAX_ECC_PRIME_LEN * 2]; + size_t prime_len; + const u8 *addr[2]; + size_t len[2]; + u8 salt[64], hash[64]; + size_t hash_len; + const struct crypto_bignum *order; + struct crypto_bignum *tmp = NULL, *val = NULL, *one = NULL; + struct crypto_ec_point *pwe = NULL; + + wpa_printf(MSG_DEBUG, "SAE: Derive PWE from PT"); + prime_len = crypto_ec_prime_len(pt->ec); + if (crypto_ec_point_to_bin(pt->ec, pt->ecc_pt, + bin, bin + prime_len) < 0) + return NULL; + wpa_hexdump_key(MSG_DEBUG, "SAE: PT.x", bin, prime_len); + wpa_hexdump_key(MSG_DEBUG, "SAE: PT.y", bin + prime_len, prime_len); + + sae_max_min_addr(addr, len, addr1, addr2); + + /* val = H(0^n, + * MAX(STA-A-MAC, STA-B-MAC) || MIN(STA-A-MAC, STA-B-MAC)) */ + wpa_printf(MSG_DEBUG, "SAE: val = H(0^n, MAX(addrs) || MIN(addrs))"); + hash_len = sae_ecc_prime_len_2_hash_len(prime_len); + os_memset(salt, 0, hash_len); + if (hkdf_extract(hash_len, salt, hash_len, 2, addr, len, hash) < 0) + goto fail; + wpa_hexdump(MSG_DEBUG, "SAE: val", hash, hash_len); + + /* val = val modulo (q - 1) + 1 */ + order = crypto_ec_get_order(pt->ec); + tmp = crypto_bignum_init(); + val = crypto_bignum_init_set(hash, hash_len); + one = crypto_bignum_init_uint(1); + if (!tmp || !val || !one || + crypto_bignum_sub(order, one, tmp) < 0 || + crypto_bignum_mod(val, tmp, val) < 0 || + crypto_bignum_add(val, one, val) < 0) + goto fail; + debug_print_bignum("SAE: val(reduced to 1..q-1)", val, prime_len); + + /* PWE = scalar-op(val, PT) */ + pwe = crypto_ec_point_init(pt->ec); + if (!pwe || + crypto_ec_point_mul(pt->ec, pt->ecc_pt, val, pwe) < 0 || + crypto_ec_point_to_bin(pt->ec, pwe, bin, bin + prime_len) < 0) { + crypto_ec_point_deinit(pwe, 1); + pwe = NULL; + goto fail; + } + wpa_hexdump_key(MSG_DEBUG, "SAE: PWE.x", bin, prime_len); + wpa_hexdump_key(MSG_DEBUG, "SAE: PWE.y", bin + prime_len, prime_len); + +fail: + crypto_bignum_deinit(tmp, 1); + crypto_bignum_deinit(val, 1); + crypto_bignum_deinit(one, 0); + return pwe; +} + + +struct crypto_bignum * +sae_derive_pwe_from_pt_ffc(const struct sae_pt *pt, + const u8 *addr1, const u8 *addr2) +{ + size_t prime_len; + const u8 *addr[2]; + size_t len[2]; + u8 salt[64], hash[64]; + size_t hash_len; + struct crypto_bignum *tmp = NULL, *val = NULL, *one = NULL; + struct crypto_bignum *pwe = NULL, *order = NULL, *prime = NULL; + + wpa_printf(MSG_DEBUG, "SAE: Derive PWE from PT"); + prime = crypto_bignum_init_set(pt->dh->prime, pt->dh->prime_len); + order = crypto_bignum_init_set(pt->dh->order, pt->dh->order_len); + if (!prime || !order) + goto fail; + prime_len = pt->dh->prime_len; + + sae_max_min_addr(addr, len, addr1, addr2); + + /* val = H(0^n, + * MAX(STA-A-MAC, STA-B-MAC) || MIN(STA-A-MAC, STA-B-MAC)) */ + wpa_printf(MSG_DEBUG, "SAE: val = H(0^n, MAX(addrs) || MIN(addrs))"); + hash_len = sae_ffc_prime_len_2_hash_len(prime_len); + os_memset(salt, 0, hash_len); + if (hkdf_extract(hash_len, salt, hash_len, 2, addr, len, hash) < 0) + goto fail; + wpa_hexdump(MSG_DEBUG, "SAE: val", hash, hash_len); + + /* val = val modulo (q - 1) + 1 */ + tmp = crypto_bignum_init(); + val = crypto_bignum_init_set(hash, hash_len); + one = crypto_bignum_init_uint(1); + if (!tmp || !val || !one || + crypto_bignum_sub(order, one, tmp) < 0 || + crypto_bignum_mod(val, tmp, val) < 0 || + crypto_bignum_add(val, one, val) < 0) + goto fail; + debug_print_bignum("SAE: val(reduced to 1..q-1)", val, prime_len); + + /* PWE = scalar-op(val, PT) */ + pwe = crypto_bignum_init(); + if (!pwe || crypto_bignum_exptmod(pt->ffc_pt, val, prime, pwe) < 0) { + crypto_bignum_deinit(pwe, 1); + pwe = NULL; + goto fail; + } + debug_print_bignum("SAE: PWE", pwe, prime_len); + +fail: + crypto_bignum_deinit(tmp, 1); + crypto_bignum_deinit(val, 1); + crypto_bignum_deinit(one, 0); + crypto_bignum_deinit(prime, 0); + crypto_bignum_deinit(order, 0); + return pwe; +} + + +void sae_deinit_pt(struct sae_pt *pt) +{ + struct sae_pt *prev; + + while (pt) { + crypto_ec_point_deinit(pt->ecc_pt, 1); + crypto_bignum_deinit(pt->ffc_pt, 1); + crypto_ec_deinit(pt->ec); + prev = pt; + pt = pt->next; + os_free(prev); + } +} + + static int sae_derive_commit_element_ecc(struct sae_data *sae, struct crypto_bignum *mask) { @@ -605,10 +1347,66 @@ int sae_prepare_commit(const u8 *addr1, const u8 *addr2, identifier) < 0) || (sae->tmp->dh && sae_derive_pwe_ffc(sae, addr1, addr2, password, password_len, - identifier) < 0) || - sae_derive_commit(sae) < 0) + identifier) < 0)) return -1; - return 0; + + sae->tmp->h2e = 0; + return sae_derive_commit(sae); +} + + +int sae_prepare_commit_pt(struct sae_data *sae, const struct sae_pt *pt, + const u8 *addr1, const u8 *addr2, + int *rejected_groups) +{ + if (!sae->tmp) + return -1; + + while (pt) { + if (pt->group == sae->group) + break; + pt = pt->next; + } + if (!pt) { + wpa_printf(MSG_INFO, "SAE: Could not find PT for group %u", + sae->group); + return -1; + } + + sae->tmp->own_addr_higher = os_memcmp(addr1, addr2, ETH_ALEN) > 0; + wpabuf_free(sae->tmp->own_rejected_groups); + sae->tmp->own_rejected_groups = NULL; + if (rejected_groups) { + int count, i; + struct wpabuf *groups; + + count = int_array_len(rejected_groups); + groups = wpabuf_alloc(count * 2); + if (!groups) + return -1; + for (i = 0; i < count; i++) + wpabuf_put_le16(groups, rejected_groups[i]); + sae->tmp->own_rejected_groups = groups; + } + + if (pt->ec) { + crypto_ec_point_deinit(sae->tmp->pwe_ecc, 1); + sae->tmp->pwe_ecc = sae_derive_pwe_from_pt_ecc(pt, addr1, + addr2); + if (!sae->tmp->pwe_ecc) + return -1; + } + + if (pt->dh) { + crypto_bignum_deinit(sae->tmp->pwe_ffc, 1); + sae->tmp->pwe_ffc = sae_derive_pwe_from_pt_ffc(pt, addr1, + addr2); + if (!sae->tmp->pwe_ffc) + return -1; + } + + sae->tmp->h2e = 1; + return sae_derive_commit(sae); } diff --git a/src/common/sae.h b/src/common/sae.h index fee0b197c..fdc2fd84b 100644 --- a/src/common/sae.h +++ b/src/common/sae.h @@ -43,7 +43,20 @@ struct sae_temporary_data { char *pw_id; int vlan_id; u8 bssid[ETH_ALEN]; + struct wpabuf *own_rejected_groups; struct wpabuf *peer_rejected_groups; + unsigned int h2e:1; + unsigned int own_addr_higher:1; +}; + +struct sae_pt { + struct sae_pt *next; + int group; + struct crypto_ec *ec; + struct crypto_ec_point *ecc_pt; + + const struct dh_group *dh; + struct crypto_bignum *ffc_pt; }; enum sae_state { @@ -69,6 +82,9 @@ void sae_clear_data(struct sae_data *sae); int sae_prepare_commit(const u8 *addr1, const u8 *addr2, const u8 *password, size_t password_len, const char *identifier, struct sae_data *sae); +int sae_prepare_commit_pt(struct sae_data *sae, const struct sae_pt *pt, + const u8 *addr1, const u8 *addr2, + int *rejected_groups); int sae_process_commit(struct sae_data *sae); void sae_write_commit(struct sae_data *sae, struct wpabuf *buf, const struct wpabuf *token, const char *identifier); @@ -79,5 +95,15 @@ void sae_write_confirm(struct sae_data *sae, struct wpabuf *buf); int sae_check_confirm(struct sae_data *sae, const u8 *data, size_t len); u16 sae_group_allowed(struct sae_data *sae, int *allowed_groups, u16 group); const char * sae_state_txt(enum sae_state state); +struct sae_pt * sae_derive_pt(int *groups, const u8 *ssid, size_t ssid_len, + const u8 *password, size_t password_len, + const char *identifier); +struct crypto_ec_point * +sae_derive_pwe_from_pt_ecc(const struct sae_pt *pt, + const u8 *addr1, const u8 *addr2); +struct crypto_bignum * +sae_derive_pwe_from_pt_ffc(const struct sae_pt *pt, + const u8 *addr1, const u8 *addr2); +void sae_deinit_pt(struct sae_pt *pt); #endif /* SAE_H */