smb: client: make use of smbdirect_socket.send_io.lcredits.*

This makes the logic to prevent on overflow of
the send submission queue with ib_post_send() easier.

As we first get a local credit and then a remote credit
before we mark us as pending.

For now we'll keep the logic around smbdirect_socket.send_io.pending.*,
but that will likely change or be removed completely.

The server will get a similar logic soon, so
we'll be able to share the send code in future.

Cc: Steve French <smfrench@gmail.com>
Cc: Tom Talpey <tom@talpey.com>
Cc: Long Li <longli@microsoft.com>
Cc: Namjae Jeon <linkinjeon@kernel.org>
Cc: linux-cifs@vger.kernel.org
Cc: samba-technical@lists.samba.org
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Signed-off-by: Steve French <stfrench@microsoft.com>
diff --git a/fs/smb/client/smbdirect.c b/fs/smb/client/smbdirect.c
index 49e2df3..f2da694 100644
--- a/fs/smb/client/smbdirect.c
+++ b/fs/smb/client/smbdirect.c
@@ -172,6 +172,7 @@ static void smbd_disconnect_wake_up_all(struct smbdirect_socket *sc)
 	 * in order to notice the broken connection.
 	 */
 	wake_up_all(&sc->status_wait);
+	wake_up_all(&sc->send_io.lcredits.wait_queue);
 	wake_up_all(&sc->send_io.credits.wait_queue);
 	wake_up_all(&sc->send_io.pending.dec_wait_queue);
 	wake_up_all(&sc->send_io.pending.zero_wait_queue);
@@ -495,6 +496,7 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
 	struct smbdirect_send_io *request =
 		container_of(wc->wr_cqe, struct smbdirect_send_io, cqe);
 	struct smbdirect_socket *sc = request->socket;
+	int lcredits = 0;
 
 	log_rdma_send(INFO, "smbdirect_send_io 0x%p completed wc->status=%s\n",
 		request, ib_wc_status_msg(wc->status));
@@ -504,22 +506,24 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
 			request->sge[i].addr,
 			request->sge[i].length,
 			DMA_TO_DEVICE);
+	mempool_free(request, sc->send_io.mem.pool);
+	lcredits += 1;
 
 	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
 		if (wc->status != IB_WC_WR_FLUSH_ERR)
 			log_rdma_send(ERR, "wc->status=%s wc->opcode=%d\n",
 				ib_wc_status_msg(wc->status), wc->opcode);
-		mempool_free(request, sc->send_io.mem.pool);
 		smbd_disconnect_rdma_connection(sc);
 		return;
 	}
 
+	atomic_add(lcredits, &sc->send_io.lcredits.count);
+	wake_up(&sc->send_io.lcredits.wait_queue);
+
 	if (atomic_dec_and_test(&sc->send_io.pending.count))
 		wake_up(&sc->send_io.pending.zero_wait_queue);
 
 	wake_up(&sc->send_io.pending.dec_wait_queue);
-
-	mempool_free(request, sc->send_io.mem.pool);
 }
 
 static void dump_smbdirect_negotiate_resp(struct smbdirect_negotiate_resp *resp)
@@ -567,6 +571,7 @@ static bool process_negotiation_response(
 		log_rdma_event(ERR, "error: credits_granted==0\n");
 		return false;
 	}
+	atomic_set(&sc->send_io.lcredits.count, sp->send_credit_target);
 	atomic_set(&sc->send_io.credits.count, le16_to_cpu(packet->credits_granted));
 
 	if (le32_to_cpu(packet->preferred_send_size) > sp->max_recv_size) {
@@ -1114,6 +1119,24 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
 	struct smbdirect_data_transfer *packet;
 	int new_credits = 0;
 
+wait_lcredit:
+	/* Wait for local send credits */
+	rc = wait_event_interruptible(sc->send_io.lcredits.wait_queue,
+		atomic_read(&sc->send_io.lcredits.count) > 0 ||
+		sc->status != SMBDIRECT_SOCKET_CONNECTED);
+	if (rc)
+		goto err_wait_lcredit;
+
+	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
+		log_outgoing(ERR, "disconnected not sending on wait_credit\n");
+		rc = -EAGAIN;
+		goto err_wait_lcredit;
+	}
+	if (unlikely(atomic_dec_return(&sc->send_io.lcredits.count) < 0)) {
+		atomic_inc(&sc->send_io.lcredits.count);
+		goto wait_lcredit;
+	}
+
 wait_credit:
 	/* Wait for send credits. A SMBD packet needs one credit */
 	rc = wait_event_interruptible(sc->send_io.credits.wait_queue,
@@ -1132,23 +1155,6 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
 		goto wait_credit;
 	}
 
-wait_send_queue:
-	wait_event(sc->send_io.pending.dec_wait_queue,
-		atomic_read(&sc->send_io.pending.count) < sp->send_credit_target ||
-		sc->status != SMBDIRECT_SOCKET_CONNECTED);
-
-	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
-		log_outgoing(ERR, "disconnected not sending on wait_send_queue\n");
-		rc = -EAGAIN;
-		goto err_wait_send_queue;
-	}
-
-	if (unlikely(atomic_inc_return(&sc->send_io.pending.count) >
-				sp->send_credit_target)) {
-		atomic_dec(&sc->send_io.pending.count);
-		goto wait_send_queue;
-	}
-
 	request = mempool_alloc(sc->send_io.mem.pool, GFP_KERNEL);
 	if (!request) {
 		rc = -ENOMEM;
@@ -1229,10 +1235,21 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
 		     le32_to_cpu(packet->data_length),
 		     le32_to_cpu(packet->remaining_data_length));
 
+	/*
+	 * Now that we got a local and a remote credit
+	 * we add us as pending
+	 */
+	atomic_inc(&sc->send_io.pending.count);
+
 	rc = smbd_post_send(sc, request);
 	if (!rc)
 		return 0;
 
+	if (atomic_dec_and_test(&sc->send_io.pending.count))
+		wake_up(&sc->send_io.pending.zero_wait_queue);
+
+	wake_up(&sc->send_io.pending.dec_wait_queue);
+
 err_dma:
 	for (i = 0; i < request->num_sge; i++)
 		if (request->sge[i].addr)
@@ -1246,14 +1263,14 @@ static int smbd_post_send_iter(struct smbdirect_socket *sc,
 	atomic_sub(new_credits, &sc->recv_io.credits.count);
 
 err_alloc:
-	if (atomic_dec_and_test(&sc->send_io.pending.count))
-		wake_up(&sc->send_io.pending.zero_wait_queue);
-
-err_wait_send_queue:
-	/* roll back send credits and pending */
 	atomic_inc(&sc->send_io.credits.count);
+	wake_up(&sc->send_io.credits.wait_queue);
 
 err_wait_credit:
+	atomic_inc(&sc->send_io.lcredits.count);
+	wake_up(&sc->send_io.lcredits.wait_queue);
+
+err_wait_lcredit:
 	return rc;
 }