diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls.h | 3 | ||||
-rw-r--r-- | net/tls/tls_strp.c | 25 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 13 |
3 files changed, 28 insertions, 13 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h index 774859b63f0d..e4c42731ce39 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -141,6 +141,7 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx); int wait_on_pending_writer(struct sock *sk, long *timeo); void tls_err_abort(struct sock *sk, int err); +void tls_strp_abort_strp(struct tls_strparser *strp, int err); int init_prot_info(struct tls_prot_info *prot, const struct tls_crypto_info *crypto_info, @@ -196,7 +197,7 @@ void tls_strp_msg_done(struct tls_strparser *strp); int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb); void tls_rx_msg_ready(struct tls_strparser *strp); -void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); +bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); int tls_strp_msg_cow(struct tls_sw_context_rx *ctx); struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx); int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst); diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index 095cf31bae0b..98e12f0ff57e 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -13,7 +13,7 @@ static struct workqueue_struct *tls_strp_wq; -static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +void tls_strp_abort_strp(struct tls_strparser *strp, int err) { if (strp->stopped) return; @@ -211,11 +211,17 @@ static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb, struct sk_buff *in_skb, unsigned int offset, size_t in_len) { + unsigned int nfrag = skb->len / PAGE_SIZE; size_t len, chunk; skb_frag_t *frag; int sz; - frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + if (unlikely(nfrag >= skb_shinfo(skb)->nr_frags)) { + DEBUG_NET_WARN_ON_ONCE(1); + return -EMSGSIZE; + } + + frag = &skb_shinfo(skb)->frags[nfrag]; len = in_len; /* First make sure we got the header */ @@ -475,7 +481,7 @@ static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) strp->stm.offset = offset; } -void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +bool tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) { struct strp_msg *rxm; struct tls_msg *tlm; @@ -484,8 +490,11 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); if (!strp->copy_mode && force_refresh) { - if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) - return; + if (unlikely(tcp_inq(strp->sk) < strp->stm.full_len)) { + WRITE_ONCE(strp->msg_ready, 0); + memset(&strp->stm, 0, sizeof(strp->stm)); + return false; + } tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); } @@ -495,6 +504,8 @@ void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) rxm->offset = strp->stm.offset; tlm = tls_msg(strp->anchor); tlm->control = strp->mark; + + return true; } /* Called with lock held on lower socket */ @@ -515,10 +526,8 @@ static int tls_strp_read_sock(struct tls_strparser *strp) tls_strp_load_anchor_with_queue(strp, inq); if (!strp->stm.full_len) { sz = tls_rx_msg_size(strp, strp->anchor); - if (sz < 0) { - tls_strp_abort_strp(strp, sz); + if (sz < 0) return sz; - } strp->stm.full_len = sz; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 549d1ea01a72..daac9fd4be7e 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1384,7 +1384,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, return sock_intr_errno(timeo); } - tls_strp_msg_load(&ctx->strp, released); + if (unlikely(!tls_strp_msg_load(&ctx->strp, released))) + return tls_rx_rec_wait(sk, psock, nonblock, false); return 1; } @@ -1807,6 +1808,9 @@ int decrypt_skb(struct sock *sk, struct scatterlist *sgout) return tls_decrypt_sg(sk, NULL, sgout, &darg); } +/* All records returned from a recvmsg() call must have the same type. + * 0 is not a valid content type. Use it as "no type reported, yet". + */ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, u8 *control) { @@ -2050,8 +2054,10 @@ int tls_sw_recvmsg(struct sock *sk, if (err < 0) goto end; + /* process_rx_list() will set @control if it processed any records */ copied = err; - if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more) + if (len <= copied || rx_more || + (control && control != TLS_RECORD_TYPE_DATA)) goto end; target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); @@ -2468,8 +2474,7 @@ int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) return data_len + TLS_HEADER_SIZE; read_failure: - tls_err_abort(strp->sk, ret); - + tls_strp_abort_strp(strp, ret); return ret; } |