From 163ab96b52ae2bb2d8f188cd29f0b570610f9007 Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Sun, 6 Oct 2019 21:09:27 -0700 Subject: net: sockmap: use bitmap for copy info Don't use bool array in struct sk_msg_sg, save 12 bytes. Signed-off-by: Jakub Kicinski Reviewed-by: Dirk van der Merwe Signed-off-by: David S. Miller --- include/linux/skmsg.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'include/linux/skmsg.h') diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index e4b3fb4bb77c..fe80d537945d 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -28,13 +28,14 @@ struct sk_msg_sg { u32 end; u32 size; u32 copybreak; - bool copy[MAX_MSG_FRAGS]; + unsigned long copy; /* The extra element is used for chaining the front and sections when * the list becomes partitioned (e.g. end < start). The crypto APIs * require the chaining. */ struct scatterlist data[MAX_MSG_FRAGS + 1]; }; +static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS); /* UAPI in filter.c depends on struct sk_msg_sg being first element. */ struct sk_msg { @@ -227,7 +228,7 @@ static inline void sk_msg_compute_data_pointers(struct sk_msg *msg) { struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start); - if (msg->sg.copy[msg->sg.start]) { + if (test_bit(msg->sg.start, &msg->sg.copy)) { msg->data = NULL; msg->data_end = NULL; } else { @@ -246,7 +247,7 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page, sg_set_page(sge, page, len, offset); sg_unmark_end(sge); - msg->sg.copy[msg->sg.end] = true; + __set_bit(msg->sg.end, &msg->sg.copy); msg->sg.size += len; sk_msg_iter_next(msg, end); } @@ -254,7 +255,10 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page, static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state) { do { - msg->sg.copy[i] = copy_state; + if (copy_state) + __set_bit(i, &msg->sg.copy); + else + __clear_bit(i, &msg->sg.copy); sk_msg_iter_var_next(i); if (i == msg->sg.end) break; -- cgit v1.2.3 From 683916f6a84023407761d843048f1aea486b2612 Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Mon, 4 Nov 2019 15:36:57 -0800 Subject: net/tls: fix sk_msg trim on fallback to copy mode sk_msg_trim() tries to only update curr pointer if it falls into the trimmed region. The logic, however, does not take into the account pointer wrapping that sk_msg_iter_var_prev() does nor (as John points out) the fact that msg->sg is a ring buffer. This means that when the message was trimmed completely, the new curr pointer would have the value of MAX_MSG_FRAGS - 1, which is neither smaller than any other value, nor would it actually be correct. Special case the trimming to 0 length a little bit and rework the comparison between curr and end to take into account wrapping. This bug caused the TLS code to not copy all of the message, if zero copy filled in fewer sg entries than memcopy would need. Big thanks to Alexander Potapenko for the non-KMSAN reproducer. v2: - take into account that msg->sg is a ring buffer (John). Link: https://lore.kernel.org/netdev/20191030160542.30295-1-jakub.kicinski@netronome.com/ (v1) Fixes: d829e9c4112b ("tls: convert to generic sk_msg interface") Reported-by: syzbot+f8495bff23a879a6d0bd@syzkaller.appspotmail.com Reported-by: syzbot+6f50c99e8f6194bf363f@syzkaller.appspotmail.com Co-developed-by: John Fastabend Signed-off-by: Jakub Kicinski Signed-off-by: John Fastabend Signed-off-by: David S. Miller --- include/linux/skmsg.h | 9 ++++++--- net/core/skmsg.c | 20 +++++++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) (limited to 'include/linux/skmsg.h') diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index e4b3fb4bb77c..ce7055259877 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -139,6 +139,11 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes) } } +static inline u32 sk_msg_iter_dist(u32 start, u32 end) +{ + return end >= start ? end - start : end + (MAX_MSG_FRAGS - start); +} + #define sk_msg_iter_var_prev(var) \ do { \ if (var == 0) \ @@ -198,9 +203,7 @@ static inline u32 sk_msg_elem_used(const struct sk_msg *msg) if (sk_msg_full(msg)) return MAX_MSG_FRAGS; - return msg->sg.end >= msg->sg.start ? - msg->sg.end - msg->sg.start : - msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start); + return sk_msg_iter_dist(msg->sg.start, msg->sg.end); } static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which) diff --git a/net/core/skmsg.c b/net/core/skmsg.c index cf390e0aa73d..ad31e4e53d0a 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -270,18 +270,28 @@ void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) msg->sg.data[i].length -= trim; sk_mem_uncharge(sk, trim); + /* Adjust copybreak if it falls into the trimmed part of last buf */ + if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length) + msg->sg.copybreak = msg->sg.data[i].length; out: - /* If we trim data before curr pointer update copybreak and current - * so that any future copy operations start at new copy location. + sk_msg_iter_var_next(i); + msg->sg.end = i; + + /* If we trim data a full sg elem before curr pointer update + * copybreak and current so that any future copy operations + * start at new copy location. * However trimed data that has not yet been used in a copy op * does not require an update. */ - if (msg->sg.curr >= i) { + if (!msg->sg.size) { + msg->sg.curr = msg->sg.start; + msg->sg.copybreak = 0; + } else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >= + sk_msg_iter_dist(msg->sg.start, msg->sg.end)) { + sk_msg_iter_var_prev(i); msg->sg.curr = i; msg->sg.copybreak = msg->sg.data[i].length; } - sk_msg_iter_var_next(i); - msg->sg.end = i; } EXPORT_SYMBOL_GPL(sk_msg_trim); -- cgit v1.2.3 From 031097d9e079e40dce401031d1012e83d80eaf01 Mon Sep 17 00:00:00 2001 From: Jakub Kicinski Date: Wed, 27 Nov 2019 12:16:41 -0800 Subject: net: skmsg: fix TLS 1.3 crash with full sk_msg TLS 1.3 started using the entry at the end of the SG array for chaining-in the single byte content type entry. This mostly works: [ E E E E E E . . ] ^ ^ start end E < content type / [ E E E E E E C . ] ^ ^ start end (Where E denotes a populated SG entry; C denotes a chaining entry.) If the array is full, however, the end will point to the start: [ E E E E E E E E ] ^ start end And we end up overwriting the start: E < content type / [ C E E E E E E E ] ^ start end The sg array is supposed to be a circular buffer with start and end markers pointing anywhere. In case where start > end (i.e. the circular buffer has "wrapped") there is an extra entry reserved at the end to chain the two halves together. [ E E E E E E . . l ] (Where l is the reserved entry for "looping" back to front. As suggested by John, let's reserve another entry for chaining SG entries after the main circular buffer. Note that this entry has to be pointed to by the end entry so its position is not fixed. Examples of full messages: [ E E E E E E E E . l ] ^ ^ start end <---------------. [ E E . E E E E E E l ] ^ ^ end start Now the end will always point to an unused entry, so TLS 1.3 can always use it. Fixes: 130b392c6cd6 ("net: tls: Add tls 1.3 support") Signed-off-by: Jakub Kicinski Reviewed-by: Simon Horman Signed-off-by: David S. Miller --- include/linux/skmsg.h | 28 ++++++++++++++-------------- net/core/filter.c | 8 ++++---- net/core/skmsg.c | 2 +- net/ipv4/tcp_bpf.c | 2 +- 4 files changed, 20 insertions(+), 20 deletions(-) (limited to 'include/linux/skmsg.h') diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 6cb077b646a5..ef7031f8a304 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -14,6 +14,7 @@ #include #define MAX_MSG_FRAGS MAX_SKB_FRAGS +#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1) enum __sk_action { __SK_DROP = 0, @@ -29,13 +30,15 @@ struct sk_msg_sg { u32 size; u32 copybreak; unsigned long copy; - /* The extra element is used for chaining the front and sections when - * the list becomes partitioned (e.g. end < start). The crypto APIs - * require the chaining. + /* The extra two elements: + * 1) used for chaining the front and sections when the list becomes + * partitioned (e.g. end < start). The crypto APIs require the + * chaining; + * 2) to chain tailer SG entries after the message. */ - struct scatterlist data[MAX_MSG_FRAGS + 1]; + struct scatterlist data[MAX_MSG_FRAGS + 2]; }; -static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS); +static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS); /* UAPI in filter.c depends on struct sk_msg_sg being first element. */ struct sk_msg { @@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes) static inline u32 sk_msg_iter_dist(u32 start, u32 end) { - return end >= start ? end - start : end + (MAX_MSG_FRAGS - start); + return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start); } #define sk_msg_iter_var_prev(var) \ do { \ if (var == 0) \ - var = MAX_MSG_FRAGS - 1; \ + var = NR_MSG_FRAG_IDS - 1; \ else \ var--; \ } while (0) @@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end) #define sk_msg_iter_var_next(var) \ do { \ var++; \ - if (var == MAX_MSG_FRAGS) \ + if (var == NR_MSG_FRAG_IDS) \ var = 0; \ } while (0) @@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg) static inline void sk_msg_init(struct sk_msg *msg) { - BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS); + BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS); memset(msg, 0, sizeof(*msg)); - sg_init_marker(msg->sg.data, MAX_MSG_FRAGS); + sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS); } static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, @@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src) static inline bool sk_msg_full(const struct sk_msg *msg) { - return (msg->sg.end == msg->sg.start) && msg->sg.size; + return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS; } static inline u32 sk_msg_elem_used(const struct sk_msg *msg) { - if (sk_msg_full(msg)) - return MAX_MSG_FRAGS; - return sk_msg_iter_dist(msg->sg.start, msg->sg.end); } diff --git a/net/core/filter.c b/net/core/filter.c index b0ed048585ba..f1e703eed3d2 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, WARN_ON_ONCE(last_sge == first_sge); shift = last_sge > first_sge ? last_sge - first_sge - 1 : - MAX_SKB_FRAGS - first_sge + last_sge - 1; + NR_MSG_FRAG_IDS - first_sge + last_sge - 1; if (!shift) goto out; @@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, do { u32 move_from; - if (i + shift >= MAX_MSG_FRAGS) - move_from = i + shift - MAX_MSG_FRAGS; + if (i + shift >= NR_MSG_FRAG_IDS) + move_from = i + shift - NR_MSG_FRAG_IDS; else move_from = i + shift; if (move_from == msg->sg.end) @@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, } while (1); msg->sg.end = msg->sg.end - shift > msg->sg.end ? - msg->sg.end - shift + MAX_MSG_FRAGS : + msg->sg.end - shift + NR_MSG_FRAG_IDS : msg->sg.end - shift; out: msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset; diff --git a/net/core/skmsg.c b/net/core/skmsg.c index a469d2124f3f..ded2d5227678 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) copied = skb->len; msg->sg.start = 0; msg->sg.size = copied; - msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; + msg->sg.end = num_sge; msg->skb = skb; sk_psock_queue_msg(psock, msg); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 8a56e09cfb0e..e38705165ac9 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, struct sk_msg *msg, int *copied, int flags) { - bool cork = false, enospc = msg->sg.start == msg->sg.end; + bool cork = false, enospc = sk_msg_full(msg); struct sock *sk_redir; u32 tosend, delta = 0; int ret; -- cgit v1.2.3