diff --git a/src/crypto/crypto_internal-cipher.c b/src/crypto/crypto_internal-cipher.c index 4bd711d09..ad0930a5a 100644 --- a/src/crypto/crypto_internal-cipher.c +++ b/src/crypto/crypto_internal-cipher.c @@ -24,7 +24,6 @@ struct crypto_cipher { } rc4; struct { u8 cbc[32]; - size_t block_size; void *ctx_enc; void *ctx_dec; } aes; @@ -63,10 +62,6 @@ struct crypto_cipher * crypto_cipher_init(enum crypto_cipher_alg alg, os_memcpy(ctx->u.rc4.key, key, key_len); break; case CRYPTO_CIPHER_ALG_AES: - if (key_len > sizeof(ctx->u.aes.cbc)) { - os_free(ctx); - return NULL; - } ctx->u.aes.ctx_enc = aes_encrypt_init(key, key_len); if (ctx->u.aes.ctx_enc == NULL) { os_free(ctx); @@ -78,8 +73,7 @@ struct crypto_cipher * crypto_cipher_init(enum crypto_cipher_alg alg, os_free(ctx); return NULL; } - ctx->u.aes.block_size = key_len; - os_memcpy(ctx->u.aes.cbc, iv, ctx->u.aes.block_size); + os_memcpy(ctx->u.aes.cbc, iv, AES_BLOCK_SIZE); break; case CRYPTO_CIPHER_ALG_3DES: if (key_len != 24) { @@ -120,18 +114,17 @@ int crypto_cipher_encrypt(struct crypto_cipher *ctx, const u8 *plain, ctx->u.rc4.used_bytes += len; break; case CRYPTO_CIPHER_ALG_AES: - if (len % ctx->u.aes.block_size) + if (len % AES_BLOCK_SIZE) return -1; - blocks = len / ctx->u.aes.block_size; + blocks = len / AES_BLOCK_SIZE; for (i = 0; i < blocks; i++) { - for (j = 0; j < ctx->u.aes.block_size; j++) + for (j = 0; j < AES_BLOCK_SIZE; j++) ctx->u.aes.cbc[j] ^= plain[j]; aes_encrypt(ctx->u.aes.ctx_enc, ctx->u.aes.cbc, ctx->u.aes.cbc); - os_memcpy(crypt, ctx->u.aes.cbc, - ctx->u.aes.block_size); - plain += ctx->u.aes.block_size; - crypt += ctx->u.aes.block_size; + os_memcpy(crypt, ctx->u.aes.cbc, AES_BLOCK_SIZE); + plain += AES_BLOCK_SIZE; + crypt += AES_BLOCK_SIZE; } break; case CRYPTO_CIPHER_ALG_3DES: @@ -185,17 +178,17 @@ int crypto_cipher_decrypt(struct crypto_cipher *ctx, const u8 *crypt, ctx->u.rc4.used_bytes += len; break; case CRYPTO_CIPHER_ALG_AES: - if (len % ctx->u.aes.block_size) + if (len % AES_BLOCK_SIZE) return -1; - blocks = len / ctx->u.aes.block_size; + blocks = len / AES_BLOCK_SIZE; for (i = 0; i < blocks; i++) { - os_memcpy(tmp, crypt, ctx->u.aes.block_size); + os_memcpy(tmp, crypt, AES_BLOCK_SIZE); aes_decrypt(ctx->u.aes.ctx_dec, crypt, plain); - for (j = 0; j < ctx->u.aes.block_size; j++) + for (j = 0; j < AES_BLOCK_SIZE; j++) plain[j] ^= ctx->u.aes.cbc[j]; - os_memcpy(ctx->u.aes.cbc, tmp, ctx->u.aes.block_size); - plain += ctx->u.aes.block_size; - crypt += ctx->u.aes.block_size; + os_memcpy(ctx->u.aes.cbc, tmp, AES_BLOCK_SIZE); + plain += AES_BLOCK_SIZE; + crypt += AES_BLOCK_SIZE; } break; case CRYPTO_CIPHER_ALG_3DES: