summaryrefslogtreecommitdiff
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls.h2
-rw-r--r--net/tls/tls_strp.c11
-rw-r--r--net/tls/tls_sw.c10
3 files changed, 17 insertions, 6 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h
index 774859b63f0d..4e077068e6d9 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -196,7 +196,7 @@ void tls_strp_msg_done(struct tls_strparser *strp);
int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb);
void tls_rx_msg_ready(struct tls_strparser *strp);
-void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh);
+bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh);
int tls_strp_msg_cow(struct tls_sw_context_rx *ctx);
struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx);
int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst);
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
index 095cf31bae0b..d71643b494a1 100644
--- a/net/tls/tls_strp.c
+++ b/net/tls/tls_strp.c
@@ -475,7 +475,7 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
strp->stm.offset = offset;
}
-void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
+bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
{
struct strp_msg *rxm;
struct tls_msg *tlm;
@@ -484,8 +484,11 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
if (!strp->copy_mode && force_refresh) {
- if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
- return;
+ if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) {
+ WRITE_ONCE(strp->msg_ready, 0);
+ memset(&strp->stm, 0, sizeof(strp->stm));
+ return false;
+ }
tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
}
@@ -495,6 +498,8 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
rxm->offset = strp->stm.offset;
tlm = tls_msg(strp->anchor);
tlm->control = strp->mark;
+
+ return true;
}
/* Called with lock held on lower socket */
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 549d1ea01a72..bac65d0d4e3e 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1384,7 +1384,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
return sock_intr_errno(timeo);
}
- tls_strp_msg_load(&ctx->strp, released);
+ if (unlikely(!tls_strp_msg_load(&ctx->strp, released)))
+ return tls_rx_rec_wait(sk, psock, nonblock, false);
return 1;
}
@@ -1807,6 +1808,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
return tls_decrypt_sg(sk, NULL, sgout, &darg);
}
+/* All records returned from a recvmsg() call must have the same type.
+ * 0 is not a valid content type. Use it as "no type reported, yet".
+ */
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
u8 *control)
{
@@ -2050,8 +2054,10 @@ int tls_sw_recvmsg(struct sock *sk,
if (err < 0)
goto end;
+ /* process_rx_list() will set @control if it processed any records */
copied = err;
- if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more)
+ if (len <= copied || rx_more ||
+ (control && control != TLS_RECORD_TYPE_DATA))
goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);