diff --git a/src/ap/ieee802_11.c b/src/ap/ieee802_11.c index 0be9b49dd..4252fccaa 100644 --- a/src/ap/ieee802_11.c +++ b/src/ap/ieee802_11.c @@ -1417,6 +1417,11 @@ prepare_auth_resp_fils(struct hostapd_data *hapd, } pmk = pmk_buf; + /* Don't use DHss in PTK derivation if PMKSA caching is not + * used. */ + wpabuf_clear_free(sta->fils_dh_ss); + sta->fils_dh_ss = NULL; + if (sta->fils_erp_pmkid_set) { /* TODO: get PMKLifetime from WPA parameters */ unsigned int dot11RSNAConfigPMKLifetime = 43200; @@ -1449,6 +1454,10 @@ prepare_auth_resp_fils(struct hostapd_data *hapd, if (fils_auth_pmk_to_ptk(sta->wpa_sm, pmk, pmk_len, sta->fils_snonce, fils_nonce, + sta->fils_dh_ss ? + wpabuf_head(sta->fils_dh_ss) : NULL, + sta->fils_dh_ss ? + wpabuf_len(sta->fils_dh_ss) : 0, sta->fils_g_sta, pub) < 0) { *resp = WLAN_STATUS_UNSPECIFIED_FAILURE; wpabuf_free(data); diff --git a/src/ap/wpa_auth.c b/src/ap/wpa_auth.c index 56857e855..e5b66b397 100644 --- a/src/ap/wpa_auth.c +++ b/src/ap/wpa_auth.c @@ -2113,6 +2113,7 @@ static int wpa_derive_ptk(struct wpa_state_machine *sm, const u8 *snonce, int fils_auth_pmk_to_ptk(struct wpa_state_machine *sm, const u8 *pmk, size_t pmk_len, const u8 *snonce, const u8 *anonce, + const u8 *dhss, size_t dhss_len, struct wpabuf *g_sta, struct wpabuf *g_ap) { u8 ick[FILS_ICK_MAX_LEN]; @@ -2122,7 +2123,8 @@ int fils_auth_pmk_to_ptk(struct wpa_state_machine *sm, const u8 *pmk, size_t fils_ft_len = 0; res = fils_pmk_to_ptk(pmk, pmk_len, sm->addr, sm->wpa_auth->addr, - snonce, anonce, &sm->PTK, ick, &ick_len, + snonce, anonce, dhss, dhss_len, + &sm->PTK, ick, &ick_len, sm->wpa_key_mgmt, sm->pairwise, fils_ft, &fils_ft_len); if (res < 0) diff --git a/src/ap/wpa_auth.h b/src/ap/wpa_auth.h index df4fac5cf..5e8a4cc84 100644 --- a/src/ap/wpa_auth.h +++ b/src/ap/wpa_auth.h @@ -400,6 +400,7 @@ int wpa_auth_ensure_group(struct wpa_authenticator *wpa_auth, int vlan_id); int wpa_auth_release_group(struct wpa_authenticator *wpa_auth, int vlan_id); int fils_auth_pmk_to_ptk(struct wpa_state_machine *sm, const u8 *pmk, size_t pmk_len, const u8 *snonce, const u8 *anonce, + const u8 *dhss, size_t dhss_len, struct wpabuf *g_sta, struct wpabuf *g_ap); int fils_decrypt_assoc(struct wpa_state_machine *sm, const u8 *fils_session, const struct ieee80211_mgmt *mgmt, size_t frame_len, diff --git a/src/common/wpa_common.c b/src/common/wpa_common.c index 68e788377..4efb04604 100644 --- a/src/common/wpa_common.c +++ b/src/common/wpa_common.c @@ -389,19 +389,22 @@ int fils_pmkid_erp(int akmp, const u8 *reauth, size_t reauth_len, int fils_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const u8 *spa, const u8 *aa, - const u8 *snonce, const u8 *anonce, struct wpa_ptk *ptk, + const u8 *snonce, const u8 *anonce, const u8 *dhss, + size_t dhss_len, struct wpa_ptk *ptk, u8 *ick, size_t *ick_len, int akmp, int cipher, u8 *fils_ft, size_t *fils_ft_len) { - u8 data[2 * ETH_ALEN + 2 * FILS_NONCE_LEN]; + u8 *data, *pos; + size_t data_len; u8 tmp[FILS_ICK_MAX_LEN + WPA_KEK_MAX_LEN + WPA_TK_MAX_LEN + FILS_FT_MAX_LEN]; size_t key_data_len; const char *label = "FILS PTK Derivation"; + int ret = -1; /* * FILS-Key-Data = PRF-X(PMK, "FILS PTK Derivation", - * SPA || AA || SNonce || ANonce) + * SPA || AA || SNonce || ANonce [ || DHss ]) * ICK = L(FILS-Key-Data, 0, ICK_bits) * KEK = L(FILS-Key-Data, ICK_bits, KEK_bits) * TK = L(FILS-Key-Data, ICK_bits + KEK_bits, TK_bits) @@ -409,10 +412,21 @@ int fils_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const u8 *spa, const u8 *aa, * FILS-FT = L(FILS-Key-Data, ICK_bits + KEK_bits + TK_bits, * FILS-FT_bits) */ - os_memcpy(data, spa, ETH_ALEN); - os_memcpy(data + ETH_ALEN, aa, ETH_ALEN); - os_memcpy(data + 2 * ETH_ALEN, snonce, FILS_NONCE_LEN); - os_memcpy(data + 2 * ETH_ALEN + FILS_NONCE_LEN, anonce, FILS_NONCE_LEN); + data_len = 2 * ETH_ALEN + 2 * FILS_NONCE_LEN + dhss_len; + data = os_malloc(data_len); + if (!data) + goto err; + pos = data; + os_memcpy(pos, spa, ETH_ALEN); + pos += ETH_ALEN; + os_memcpy(pos, aa, ETH_ALEN); + pos += ETH_ALEN; + os_memcpy(pos, snonce, FILS_NONCE_LEN); + pos += FILS_NONCE_LEN; + os_memcpy(pos, anonce, FILS_NONCE_LEN); + pos += FILS_NONCE_LEN; + if (dhss) + os_memcpy(pos, dhss, dhss_len); ptk->kck_len = 0; ptk->kek_len = wpa_kek_len(akmp, pmk_len); @@ -422,7 +436,7 @@ int fils_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const u8 *spa, const u8 *aa, else if (wpa_key_mgmt_sha256(akmp)) *ick_len = 32; else - return -1; + goto err; key_data_len = *ick_len + ptk->kek_len + ptk->tk_len; if (fils_ft && fils_ft_len) { @@ -439,20 +453,22 @@ int fils_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const u8 *spa, const u8 *aa, if (wpa_key_mgmt_sha384(akmp)) { wpa_printf(MSG_DEBUG, "FILS: PTK derivation using PRF(SHA384)"); - if (sha384_prf(pmk, pmk_len, label, data, sizeof(data), + if (sha384_prf(pmk, pmk_len, label, data, data_len, tmp, key_data_len) < 0) - return -1; + goto err; } else { wpa_printf(MSG_DEBUG, "FILS: PTK derivation using PRF(SHA256)"); - if (sha256_prf(pmk, pmk_len, label, data, sizeof(data), + if (sha256_prf(pmk, pmk_len, label, data, data_len, tmp, key_data_len) < 0) - return -1; + goto err; } wpa_printf(MSG_DEBUG, "FILS: PTK derivation - SPA=" MACSTR " AA=" MACSTR, MAC2STR(spa), MAC2STR(aa)); wpa_hexdump(MSG_DEBUG, "FILS: SNonce", snonce, FILS_NONCE_LEN); wpa_hexdump(MSG_DEBUG, "FILS: ANonce", anonce, FILS_NONCE_LEN); + if (dhss) + wpa_hexdump_key(MSG_DEBUG, "FILS: DHss", dhss, dhss_len); wpa_hexdump_key(MSG_DEBUG, "FILS: PMK", pmk, pmk_len); wpa_hexdump_key(MSG_DEBUG, "FILS: FILS-Key-Data", tmp, key_data_len); @@ -473,7 +489,10 @@ int fils_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const u8 *spa, const u8 *aa, } os_memset(tmp, 0, sizeof(tmp)); - return 0; + ret = 0; +err: + bin_clear_free(data, data_len); + return ret; } diff --git a/src/common/wpa_common.h b/src/common/wpa_common.h index b4d6c1384..2f11d2f34 100644 --- a/src/common/wpa_common.h +++ b/src/common/wpa_common.h @@ -347,7 +347,8 @@ int fils_rmsk_to_pmk(int akmp, const u8 *rmsk, size_t rmsk_len, int fils_pmkid_erp(int akmp, const u8 *reauth, size_t reauth_len, u8 *pmkid); int fils_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const u8 *spa, const u8 *aa, - const u8 *snonce, const u8 *anonce, struct wpa_ptk *ptk, + const u8 *snonce, const u8 *anonce, const u8 *dhss, + size_t dhss_len, struct wpa_ptk *ptk, u8 *ick, size_t *ick_len, int akmp, int cipher, u8 *fils_ft, size_t *fils_ft_len); int fils_key_auth_sk(const u8 *ick, size_t ick_len, const u8 *snonce, diff --git a/src/rsn_supp/wpa.c b/src/rsn_supp/wpa.c index ab2302b8c..739689d18 100644 --- a/src/rsn_supp/wpa.c +++ b/src/rsn_supp/wpa.c @@ -3639,8 +3639,12 @@ int fils_process_auth(struct wpa_sm *sm, const u8 *bssid, const u8 *data, dh_ss ? wpabuf_len(dh_ss) : 0, sm->pmk, &sm->pmk_len); os_memset(rmsk, 0, sizeof(rmsk)); + + /* Don't use DHss in PTK derivation if PMKSA caching is not + * used. */ wpabuf_clear_free(dh_ss); dh_ss = NULL; + if (res) goto fail; @@ -3665,12 +3669,19 @@ int fils_process_auth(struct wpa_sm *sm, const u8 *bssid, const u8 *data, } if (fils_pmk_to_ptk(sm->pmk, sm->pmk_len, sm->own_addr, sm->bssid, - sm->fils_nonce, sm->fils_anonce, &sm->ptk, - ick, &ick_len, sm->key_mgmt, sm->pairwise_cipher, + sm->fils_nonce, sm->fils_anonce, + dh_ss ? wpabuf_head(dh_ss) : NULL, + dh_ss ? wpabuf_len(dh_ss) : 0, + &sm->ptk, ick, &ick_len, + sm->key_mgmt, sm->pairwise_cipher, sm->fils_ft, &sm->fils_ft_len) < 0) { wpa_printf(MSG_DEBUG, "FILS: Failed to derive PTK"); goto fail; } + + wpabuf_clear_free(dh_ss); + dh_ss = NULL; + sm->ptk_set = 1; sm->tptk_set = 0; os_memset(&sm->tptk, 0, sizeof(sm->tptk)); diff --git a/wlantest/rx_mgmt.c b/wlantest/rx_mgmt.c index b15f561e3..08ebd88cc 100644 --- a/wlantest/rx_mgmt.c +++ b/wlantest/rx_mgmt.c @@ -351,7 +351,8 @@ static int try_rmsk(struct wlantest *wt, struct wlantest_bss *bss, return -1; if (fils_pmk_to_ptk(pmk_buf, pmk_len, sta->addr, bss->bssid, - sta->snonce, sta->anonce, &ptk, ick, &ick_len, + sta->snonce, sta->anonce, NULL, 0, + &ptk, ick, &ick_len, sta->key_mgmt, sta->pairwise_cipher, NULL, NULL) < 0) return -1;