summaryrefslogtreecommitdiff
path: root/drivers/nvme/common/auth.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/nvme/common/auth.c')
-rw-r--r--drivers/nvme/common/auth.c86
1 files changed, 66 insertions, 20 deletions
diff --git a/drivers/nvme/common/auth.c b/drivers/nvme/common/auth.c
index 91e273b89fea..1f51fbebd9fa 100644
--- a/drivers/nvme/common/auth.c
+++ b/drivers/nvme/common/auth.c
@@ -684,6 +684,59 @@ out_free_enc:
EXPORT_SYMBOL_GPL(nvme_auth_generate_digest);
/**
+ * hkdf_expand_label - HKDF-Expand-Label (RFC 8846 section 7.1)
+ * @hmac_tfm: hash context keyed with pseudorandom key
+ * @label: ASCII label without "tls13 " prefix
+ * @labellen: length of @label
+ * @context: context bytes
+ * @contextlen: length of @context
+ * @okm: output keying material
+ * @okmlen: length of @okm
+ *
+ * Build the TLS 1.3 HkdfLabel structure and invoke hkdf_expand().
+ *
+ * Returns 0 on success with output keying material stored in @okm,
+ * or a negative errno value otherwise.
+ */
+static int hkdf_expand_label(struct crypto_shash *hmac_tfm,
+ const u8 *label, unsigned int labellen,
+ const u8 *context, unsigned int contextlen,
+ u8 *okm, unsigned int okmlen)
+{
+ int err;
+ u8 *info;
+ unsigned int infolen;
+ const char *tls13_prefix = "tls13 ";
+ unsigned int prefixlen = strlen(tls13_prefix);
+
+ if (WARN_ON(labellen > (255 - prefixlen)))
+ return -EINVAL;
+ if (WARN_ON(contextlen > 255))
+ return -EINVAL;
+
+ infolen = 2 + (1 + prefixlen + labellen) + (1 + contextlen);
+ info = kzalloc(infolen, GFP_KERNEL);
+ if (!info)
+ return -ENOMEM;
+
+ /* HkdfLabel.Length */
+ put_unaligned_be16(okmlen, info);
+
+ /* HkdfLabel.Label */
+ info[2] = prefixlen + labellen;
+ memcpy(info + 3, tls13_prefix, prefixlen);
+ memcpy(info + 3 + prefixlen, label, labellen);
+
+ /* HkdfLabel.Context */
+ info[3 + prefixlen + labellen] = contextlen;
+ memcpy(info + 4 + prefixlen + labellen, context, contextlen);
+
+ err = hkdf_expand(hmac_tfm, info, infolen, okm, okmlen);
+ kfree_sensitive(info);
+ return err;
+}
+
+/**
* nvme_auth_derive_tls_psk - Derive TLS PSK
* @hmac_id: Hash function identifier
* @psk: generated input PSK
@@ -715,10 +768,10 @@ int nvme_auth_derive_tls_psk(int hmac_id, u8 *psk, size_t psk_len,
{
struct crypto_shash *hmac_tfm;
const char *hmac_name;
- const char *psk_prefix = "tls13 nvme-tls-psk";
+ const char *label = "nvme-tls-psk";
static const char default_salt[HKDF_MAX_HASHLEN];
- size_t info_len, prk_len;
- char *info;
+ size_t prk_len;
+ const char *ctx;
unsigned char *prk, *tls_key;
int ret;
@@ -758,36 +811,29 @@ int nvme_auth_derive_tls_psk(int hmac_id, u8 *psk, size_t psk_len,
if (ret)
goto out_free_prk;
- /*
- * 2 additional bytes for the length field from HDKF-Expand-Label,
- * 2 additional bytes for the HMAC ID, and one byte for the space
- * separator.
- */
- info_len = strlen(psk_digest) + strlen(psk_prefix) + 5;
- info = kzalloc(info_len + 1, GFP_KERNEL);
- if (!info) {
+ ctx = kasprintf(GFP_KERNEL, "%02d %s", hmac_id, psk_digest);
+ if (!ctx) {
ret = -ENOMEM;
goto out_free_prk;
}
- put_unaligned_be16(psk_len, info);
- memcpy(info + 2, psk_prefix, strlen(psk_prefix));
- sprintf(info + 2 + strlen(psk_prefix), "%02d %s", hmac_id, psk_digest);
-
tls_key = kzalloc(psk_len, GFP_KERNEL);
if (!tls_key) {
ret = -ENOMEM;
- goto out_free_info;
+ goto out_free_ctx;
}
- ret = hkdf_expand(hmac_tfm, info, info_len, tls_key, psk_len);
+ ret = hkdf_expand_label(hmac_tfm,
+ label, strlen(label),
+ ctx, strlen(ctx),
+ tls_key, psk_len);
if (ret) {
kfree(tls_key);
- goto out_free_info;
+ goto out_free_ctx;
}
*ret_psk = tls_key;
-out_free_info:
- kfree(info);
+out_free_ctx:
+ kfree(ctx);
out_free_prk:
kfree(prk);
out_free_shash: