Diffstat (limited to 'fs/smb/server/transport_rdma.c')
-rw-r--r--  fs/smb/server/transport_rdma.c  183
1 file changed, 125 insertions(+), 58 deletions(-)
diff --git a/fs/smb/server/transport_rdma.c b/fs/smb/server/transport_rdma.c
index 5466aa8c39b1..6550bd9f002c 100644
--- a/fs/smb/server/transport_rdma.c
+++ b/fs/smb/server/transport_rdma.c
@@ -554,7 +554,7 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
case SMB_DIRECT_MSG_DATA_TRANSFER: {
struct smb_direct_data_transfer *data_transfer =
(struct smb_direct_data_transfer *)recvmsg->packet;
- unsigned int data_length;
+ u32 remaining_data_length, data_offset, data_length;
int avail_recvmsg_count, receive_credits;
if (wc->byte_len <
@@ -564,15 +564,25 @@ static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
return;
}
+ remaining_data_length = le32_to_cpu(data_transfer->remaining_data_length);
data_length = le32_to_cpu(data_transfer->data_length);
- if (data_length) {
- if (wc->byte_len < sizeof(struct smb_direct_data_transfer) +
- (u64)data_length) {
- put_recvmsg(t, recvmsg);
- smb_direct_disconnect_rdma_connection(t);
- return;
- }
+ data_offset = le32_to_cpu(data_transfer->data_offset);
+ if (wc->byte_len < data_offset ||
+ wc->byte_len < (u64)data_offset + data_length) {
+ put_recvmsg(t, recvmsg);
+ smb_direct_disconnect_rdma_connection(t);
+ return;
+ }
+ if (remaining_data_length > t->max_fragmented_recv_size ||
+ data_length > t->max_fragmented_recv_size ||
+ (u64)remaining_data_length + (u64)data_length >
+ (u64)t->max_fragmented_recv_size) {
+ put_recvmsg(t, recvmsg);
+ smb_direct_disconnect_rdma_connection(t);
+ return;
+ }
+ if (data_length) {
if (t->full_packet_received)
recvmsg->first_segment = true;
@@ -1209,78 +1219,130 @@ static int smb_direct_writev(struct ksmbd_transport *t,
bool need_invalidate, unsigned int remote_key)
{
struct smb_direct_transport *st = smb_trans_direct_transfort(t);
- int remaining_data_length;
- int start, i, j;
- int max_iov_size = st->max_send_size -
+ size_t remaining_data_length;
+ size_t iov_idx;
+ size_t iov_ofs;
+ size_t max_iov_size = st->max_send_size -
sizeof(struct smb_direct_data_transfer);
int ret;
- struct kvec vec;
struct smb_direct_send_ctx send_ctx;
+ int error = 0;
if (st->status != SMB_DIRECT_CS_CONNECTED)
return -ENOTCONN;
//FIXME: skip RFC1002 header..
+ if (WARN_ON_ONCE(niovs <= 1 || iov[0].iov_len != 4))
+ return -EINVAL;
buflen -= 4;
+ iov_idx = 1;
+ iov_ofs = 0;
remaining_data_length = buflen;
ksmbd_debug(RDMA, "Sending smb (RDMA): smb_len=%u\n", buflen);
smb_direct_send_ctx_init(st, &send_ctx, need_invalidate, remote_key);
- start = i = 1;
- buflen = 0;
- while (true) {
- buflen += iov[i].iov_len;
- if (buflen > max_iov_size) {
- if (i > start) {
- remaining_data_length -=
- (buflen - iov[i].iov_len);
- ret = smb_direct_post_send_data(st, &send_ctx,
- &iov[start], i - start,
- remaining_data_length);
- if (ret)
+ while (remaining_data_length) {
+ struct kvec vecs[SMB_DIRECT_MAX_SEND_SGES - 1]; /* minus smbdirect hdr */
+ size_t possible_bytes = max_iov_size;
+ size_t possible_vecs;
+ size_t bytes = 0;
+ size_t nvecs = 0;
+
+ /*
+ * For the last message remaining_data_length should
+ * have been 0 already!
+ */
+ if (WARN_ON_ONCE(iov_idx >= niovs)) {
+ error = -EINVAL;
+ goto done;
+ }
+
+ /*
+ * We have 2 factors which limit the arguments we pass
+ * to smb_direct_post_send_data():
+ *
+ * 1. The number of supported SGEs for the send,
+ * of which one is reserved for the smbdirect header.
+ * And we currently need one SGE per page.
+ * 2. The number of negotiated payload bytes per send.
+ */
+ possible_vecs = min_t(size_t, ARRAY_SIZE(vecs), niovs - iov_idx);
+
+ while (iov_idx < niovs && possible_vecs && possible_bytes) {
+ struct kvec *v = &vecs[nvecs];
+ int page_count;
+
+ v->iov_base = ((u8 *)iov[iov_idx].iov_base) + iov_ofs;
+ v->iov_len = min_t(size_t,
+ iov[iov_idx].iov_len - iov_ofs,
+ possible_bytes);
+ page_count = get_buf_page_count(v->iov_base, v->iov_len);
+ if (page_count > possible_vecs) {
+ /*
+ * If the number of pages in the buffer
+ * is too large (because we currently require
+ * one SGE per page), we need to limit the
+ * length.
+ *
+ * We know possible_vecs is at least 1,
+ * so we always keep the first page.
+ *
+ * We need to calculate the number of extra
+ * pages (epages) we can also keep.
+ *
+ * We calculate the number of bytes in the
+ * first page (fplen); this should never be
+ * larger than v->iov_len because page_count is
+ * at least 2, but clamping it anyway is safer.
+ *
+ * Then we calculate the number of bytes (elen)
+ * we can keep for the extra pages.
+ */
+ size_t epages = possible_vecs - 1;
+ size_t fpofs = offset_in_page(v->iov_base);
+ size_t fplen = min_t(size_t, PAGE_SIZE - fpofs, v->iov_len);
+ size_t elen = min_t(size_t, v->iov_len - fplen, epages * PAGE_SIZE);
+
+ v->iov_len = fplen + elen;
+ page_count = get_buf_page_count(v->iov_base, v->iov_len);
+ if (WARN_ON_ONCE(page_count > possible_vecs)) {
+ /*
+ * Something went wrong in the above
+ * logic...
+ */
+ error = -EINVAL;
goto done;
- } else {
- /* iov[start] is too big, break it */
- int nvec = (buflen + max_iov_size - 1) /
- max_iov_size;
-
- for (j = 0; j < nvec; j++) {
- vec.iov_base =
- (char *)iov[start].iov_base +
- j * max_iov_size;
- vec.iov_len =
- min_t(int, max_iov_size,
- buflen - max_iov_size * j);
- remaining_data_length -= vec.iov_len;
- ret = smb_direct_post_send_data(st, &send_ctx, &vec, 1,
- remaining_data_length);
- if (ret)
- goto done;
}
- i++;
- if (i == niovs)
- break;
}
- start = i;
- buflen = 0;
- } else {
- i++;
- if (i == niovs) {
- /* send out all remaining vecs */
- remaining_data_length -= buflen;
- ret = smb_direct_post_send_data(st, &send_ctx,
- &iov[start], i - start,
- remaining_data_length);
- if (ret)
- goto done;
- break;
+ possible_vecs -= page_count;
+ nvecs += 1;
+ possible_bytes -= v->iov_len;
+ bytes += v->iov_len;
+
+ iov_ofs += v->iov_len;
+ if (iov_ofs >= iov[iov_idx].iov_len) {
+ iov_idx += 1;
+ iov_ofs = 0;
}
}
+
+ remaining_data_length -= bytes;
+
+ ret = smb_direct_post_send_data(st, &send_ctx,
+ vecs, nvecs,
+ remaining_data_length);
+ if (unlikely(ret)) {
+ error = ret;
+ goto done;
+ }
}
done:
ret = smb_direct_flush_send_list(st, &send_ctx, true);
+ if (unlikely(!ret && error))
+ ret = error;
/*
* As an optimization, we don't wait for individual I/O to finish
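
For reference, a minimal standalone sketch (plain C, not kernel code) of the epages/fplen/elen arithmetic the new comment describes. It works on a page offset instead of a real pointer so it runs anywhere; PAGE_SIZE, page_count() and clamp_to_pages() are illustrative stand-ins for the kernel's PAGE_SIZE, get_buf_page_count() and the inline clamping code.

#include <stddef.h>
#include <stdio.h>

#define PAGE_SIZE 4096UL

/* Pages covered by a buffer starting at offset fpofs within its first page. */
static size_t page_count(size_t fpofs, size_t len)
{
	return (fpofs + len + PAGE_SIZE - 1) / PAGE_SIZE;
}

/* Shrink len so the buffer covers at most max_pages pages (max_pages >= 1). */
static size_t clamp_to_pages(size_t fpofs, size_t len, size_t max_pages)
{
	size_t epages = max_pages - 1;          /* extra pages after the first */
	size_t fplen = PAGE_SIZE - fpofs;       /* bytes left in the first page */
	size_t elen;

	if (fplen > len)
		fplen = len;
	elen = len - fplen;
	if (elen > epages * PAGE_SIZE)
		elen = epages * PAGE_SIZE;
	return fplen + elen;
}

int main(void)
{
	size_t fpofs = 100, len = 5 * PAGE_SIZE;

	printf("pages before: %zu\n", page_count(fpofs, len));   /* 6 */
	len = clamp_to_pages(fpofs, len, 3);
	printf("pages after:  %zu (len=%zu)\n",
	       page_count(fpofs, len), len);                     /* 3 (len=12188) */
	return 0;
}

The clamp only ever shortens the current kvec, so the outer loop in the patch simply carries the (iov_idx, iov_ofs) cursor forward into the next send.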
@@ -1744,6 +1806,11 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
return -EINVAL;
}
+ if (device->attrs.max_send_sge < SMB_DIRECT_MAX_SEND_SGES) {
+ pr_err("warning: device max_send_sge = %d too small\n",
+ device->attrs.max_send_sge);
+ return -EINVAL;
+ }
if (device->attrs.max_recv_sge < SMB_DIRECT_MAX_RECV_SGES) {
pr_err("warning: device max_recv_sge = %d too small\n",
device->attrs.max_recv_sge);
@@ -1767,7 +1834,7 @@ static int smb_direct_init_params(struct smb_direct_transport *t,
cap->max_send_wr = max_send_wrs;
cap->max_recv_wr = t->recv_credit_max;
- cap->max_send_sge = max_sge_per_wr;
+ cap->max_send_sge = SMB_DIRECT_MAX_SEND_SGES;
cap->max_recv_sge = SMB_DIRECT_MAX_RECV_SGES;
cap->max_inline_data = 0;
cap->max_rdma_ctxs = t->max_rw_credits;
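
For reference, a minimal standalone sketch (plain C, not kernel code) of the packing loop the patch introduces in smb_direct_writev(): a cursor (iov_idx, iov_ofs) walks the payload kvecs and groups them into sends bounded by both a vector budget and a byte budget. MAX_VECS, MAX_BYTES and post_send() are illustrative stand-ins, struct iovec stands in for struct kvec, and the per-page SGE clamp shown earlier is omitted here.

#include <stddef.h>
#include <stdio.h>
#include <sys/uio.h>   /* struct iovec as a stand-in for struct kvec */

#define MAX_VECS  5    /* SGEs per send, minus the smbdirect header */
#define MAX_BYTES 1364 /* illustrative payload bytes per send */

static void post_send(const struct iovec *vecs, size_t nvecs, size_t remaining)
{
	size_t i, bytes = 0;

	for (i = 0; i < nvecs; i++)
		bytes += vecs[i].iov_len;
	printf("send: %zu vecs, %zu bytes, %zu bytes remaining after\n",
	       nvecs, bytes, remaining);
}

int main(void)
{
	static char a[2000], b[300], c[4000];
	struct iovec iov[] = {
		{ a, sizeof(a) }, { b, sizeof(b) }, { c, sizeof(c) },
	};
	size_t niovs = sizeof(iov) / sizeof(iov[0]);
	size_t remaining = sizeof(a) + sizeof(b) + sizeof(c);
	size_t iov_idx = 0, iov_ofs = 0;

	while (remaining) {
		struct iovec vecs[MAX_VECS];
		size_t possible_bytes = MAX_BYTES;
		size_t nvecs = 0, bytes = 0;

		while (iov_idx < niovs && nvecs < MAX_VECS && possible_bytes) {
			struct iovec *v = &vecs[nvecs++];

			/* Take as much of the current kvec as the byte budget allows. */
			v->iov_base = (char *)iov[iov_idx].iov_base + iov_ofs;
			v->iov_len = iov[iov_idx].iov_len - iov_ofs;
			if (v->iov_len > possible_bytes)
				v->iov_len = possible_bytes;

			possible_bytes -= v->iov_len;
			bytes += v->iov_len;
			iov_ofs += v->iov_len;
			if (iov_ofs >= iov[iov_idx].iov_len) {
				iov_idx++;
				iov_ofs = 0;
			}
		}

		remaining -= bytes;
		post_send(vecs, nvecs, remaining);
	}
	return 0;
}

Unlike the old code, a kvec larger than the byte budget is split in place by the same loop rather than by a separate "iov[start] is too big" branch, which is what lets the patch drop the start/i/j bookkeeping.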