netfs: Support encryption on Unbuffered/DIO write

Support unbuffered and direct I/O writes to an encrypted file.  This may
require performing a read-modify-write (RMW) cycle if the write is not
aligned to the crypto block boundaries.

Signed-off-by: David Howells <dhowells@redhat.com>
cc: Paulo Alcantara <pc@manguebit.org>
cc: netfs@lists.linux.dev
cc: linux-fsdevel@vger.kernel.org
cc: linux-mm@kvack.org
diff --git a/fs/netfs/direct_write.c b/fs/netfs/direct_write.c
index 3d108f5..4866304 100644
--- a/fs/netfs/direct_write.c
+++ b/fs/netfs/direct_write.c
@@ -10,26 +10,251 @@
 #include "internal.h"
 
 /*
+ * Read from the server into the buffer in @bq, issuing subrequests until
+ * stream->buffered reaches 0.  Returns 0 on success or -ENOMEM.
+ */
+static int netfs_rmw_read_one(struct netfs_io_request *rreq, struct bvecq *bq)
+{
+	struct netfs_io_stream *stream = &rreq->io_streams[0];
+	size_t len = 0;
+
+	/* The amount to read is the sum of the filled segments in this bvecq. */
+	for (int i = 0; i < bq->nr_slots; i++)
+		len += bq->bv[i].bv_len;
+
+	rreq->start		= bq->fpos;
+	rreq->len		= len;
+	stream->issue_from	= bq->fpos;
+	stream->buffered	= len;
+
+	/* Issue subrequests until the whole span has been handed off. */
+	do {
+		struct netfs_io_subrequest *subreq;
+
+		subreq = netfs_alloc_subrequest(rreq, NETFS_DOWNLOAD_FROM_SERVER);
+		if (!subreq)
+			return -ENOMEM;
+
+		/* NOTE(review): assumes ->issue_read() advances the stream. */
+		subreq->start	= stream->issue_from;
+		subreq->len	= stream->buffered;
+
+		spin_lock(&rreq->lock);
+		list_add_tail(&subreq->rreq_link, &stream->subrequests);
+		trace_netfs_sreq(subreq, netfs_sreq_trace_added);
+		spin_unlock(&rreq->lock);
+
+		netfs_stat(&netfs_n_rh_download);
+		rreq->netfs_ops->issue_read(subreq);
+
+		cond_resched();
+	} while (stream->buffered > 0);
+
+	return 0;
+}
+
+/*
+ * Perform the read side of an RMW write, reading directly into the supplied
+ * chain of one or two buffers.  A failure to issue a read is passed back.
+ */
+static ssize_t netfs_rmw_read(struct netfs_io_request *wreq, struct bvecq *bq)
+{
+	struct netfs_io_request *rreq;
+	struct netfs_io_stream *stream;
+	ssize_t ret;
+	int err;
+
+	_enter("RMW:R=%x %llx", wreq->debug_id, bq->fpos);
+
+	rreq = netfs_alloc_request(wreq->mapping, NULL, bq->fpos, 0, NETFS_RMW_READ);
+	if (IS_ERR(rreq))
+		return PTR_ERR(rreq);
+	stream = &rreq->io_streams[0];
+	stream->dispatch_cursor.bvecq = bvecq_get(bq);
+	stream->dispatch_cursor.slot = 0;
+	stream->dispatch_cursor.offset = 0;
+
+	bvecq_pos_set(&rreq->encrypt_cursor, &stream->dispatch_cursor);
+	bvecq_pos_set(&rreq->bounce_copy, &stream->dispatch_cursor);
+	bvecq_pos_set(&rreq->bounce_collect, &stream->dispatch_cursor);
+	__set_bit(NETFS_RREQ_CONTENT_ENCRYPTION, &rreq->flags);
+
+	err = netfs_rmw_read_one(rreq, bq);
+	if (!err && bq->next)
+		err = netfs_rmw_read_one(rreq, bq->next);
+
+	/* Always wait, even if issuing failed, as reads may be in flight. */
+	ret = netfs_wait_for_read(rreq);
+	netfs_put_request(rreq, netfs_rreq_trace_put_return);
+	return err ?: ret;
+}
+
+/*
+ * Read the gaps at either end of the bounce buffer that must be filled for an
+ * RMW cycle.  @to/@end are the block-rounded ends of the subreq and the write.
+ */
+static ssize_t netfs_unbuffered_rmw(struct netfs_io_request *wreq,
+				    struct netfs_io_subrequest *subreq,
+				    unsigned long long to,
+				    unsigned long long end)
+{
+	struct bvecq *before = NULL, *after = NULL;
+	size_t bsize = wreq->crypto_bsize;
+	ssize_t ret;
+
+	_enter("%llx,%llx", to, end);
+
+	/* Build a buffer chain to cover the gaps.  If we have two gaps, they
+	 * must be discontiguous and so we will need two separate bvecqs for
+	 * that; however, if the entire write spans at most two pages, just do
+	 * one read for both gaps plus the middle.
+	 */
+	if (subreq->start < wreq->start) {
+		before = bvecq_alloc_one(2, GFP_KERNEL);
+		if (!before)
+			return -ENOMEM;
+		before->fpos = subreq->start;
+		before->bv[0] = wreq->encrypt_cursor.bvecq->bv[wreq->encrypt_cursor.slot];
+		before->bv[0].bv_offset += wreq->encrypt_cursor.offset;
+		before->bv[0].bv_len = bsize;
+		bvecq_filled_to(before, 1);
+	}
+
+	if (to == end && subreq->start + subreq->len < to) {
+		size_t part = end - subreq->start;
+
+		if (before && part <= 2 * PAGE_SIZE) {
+			struct bvecq *bq;
+			size_t page0 = PAGE_SIZE - before->bv[0].bv_offset;
+			int slot;
+
+			if (part <= page0) {
+				before->bv[0].bv_len = part;
+				bvecq_filled_to(before, 1);
+				goto do_it;
+			}
+
+			bq = wreq->encrypt_cursor.bvecq;
+			slot = wreq->encrypt_cursor.slot + 1;
+			if (slot >= bq->nr_slots) {	/* Past the last filled slot */
+				bq = bq->next;
+				slot = 0;
+			}
+
+			before->bv[0].bv_len = page0;
+			before->bv[1] = bq->bv[slot];
+			before->bv[1].bv_len = part - page0;
+			bvecq_filled_to(before, 2);
+			goto do_it;
+		}
+
+		after = bvecq_alloc_one(1, GFP_KERNEL);
+		if (!after) {
+			ret = -ENOMEM;
+			goto out;
+		}
+		after->fpos = to - bsize;
+		after->bv[0] = wreq->bounce_alloc.bvecq->bv[wreq->bounce_alloc.slot];
+		after->bv[0].bv_offset = to & (PAGE_SIZE - 1);
+		after->bv[0].bv_len = bsize;
+		bvecq_filled_to(after, 1);
+	}
+
+	if (before && after) {
+		before->next = after;
+		after->prev = before;
+		after->discontig = true;
+	}
+
+do_it:
+	ret = netfs_rmw_read(wreq, before ?: after);
+
+out:
+	bvecq_put(before ?: after);
+	return ret;
+}
+
+/*
+ * Expand the bounce buffer, fill any RMW gaps, copy in the data and encrypt it.
+ */
+static int netfs_unbuffered_load_bounce(struct netfs_io_subrequest *subreq)
+{
+	struct netfs_io_request *wreq = subreq->rreq;
+	struct netfs_io_stream *stream = &wreq->io_streams[subreq->stream_nr];
+	unsigned long long to, end;
+	ssize_t got;
+	size_t amount = subreq->len;
+	int ret;
+
+	/* Expand the bounce buffer to the block-rounded end of this subreq. */
+	to = round_up(subreq->start + subreq->len, wreq->crypto_bsize);
+	end = round_up(wreq->start + wreq->len, wreq->crypto_bsize);
+
+	if (wreq->bounce_alloc_to < to) {
+		ret = bvecq_buffer_add_space(&wreq->bounce_alloc,
+					     &wreq->bounce_alloc_to,
+					     to, end, false, GFP_KERNEL);
+		if (ret < 0)
+			return ret;
+	}
+
+	/* Perform RMW if there is a gap before the data or in the final block. */
+	if (stream->issue_from < wreq->start ||
+	    (to == end && subreq->start + subreq->len < to)) {
+		ret = netfs_unbuffered_rmw(wreq, subreq, to, end);
+		if (ret < 0)
+			return ret;
+	}
+
+	/* Copy in the data.  We need to work around any RMW gaps. */
+	if (subreq->start < wreq->start + wreq->submitted)
+		amount -= wreq->submitted;
+	if (amount > wreq->len - wreq->submitted)
+		amount = wreq->len - wreq->submitted;
+
+	got = bvecq_copy_to_bvecq(&wreq->copy_cursor, &wreq->bounce_copy, amount);
+	if (got != amount)
+		return -EFAULT;	/* Didn't copy everything we asked for */
+
+	/* And then encrypt the data in-place. */
+	return netfs_encrypt(wreq, to, GFP_KERNEL);
+}
+
+/*
  * Prepare the buffer for an unbuffered/DIO write.
  */
 int netfs_prepare_unbuffered_write_buffer(struct netfs_io_subrequest *subreq,
 					  unsigned int max_segs, bool copy)
 {
-	struct netfs_io_stream *stream = &subreq->rreq->io_streams[subreq->stream_nr];
+	struct netfs_io_request *wreq = subreq->rreq;
+	struct netfs_io_stream *stream = &wreq->io_streams[subreq->stream_nr];
 	ssize_t got;
 	size_t len;
+	int ret;
+
+	len = subreq->len;
+	if (test_bit(NETFS_RREQ_CONTENT_ENCRYPTION, &wreq->flags) &&
+	    len >= wreq->crypto_bsize)
+		len = round_down(len, wreq->crypto_bsize);
+
+	if (test_bit(NETFS_RREQ_USE_BOUNCE_BUFFER, &wreq->flags)) {
+		ret = netfs_unbuffered_load_bounce(subreq);
+		if (ret < 0)
+			return ret;
+	}
 
 	bvecq_pos_set(&subreq->dispatch_pos, &stream->dispatch_cursor);
+
 	if (copy) {
-		got = bvecq_extract(&stream->dispatch_cursor, subreq->len, max_segs,
+		got = bvecq_extract(&stream->dispatch_cursor, len, max_segs,
 				    &subreq->content.bvecq);
 		if (got < 0)
 			return -ENOMEM;
 		len = got;
 	} else {
 		bvecq_pos_set(&subreq->content, &stream->dispatch_cursor);
-		len = bvecq_slice(&stream->dispatch_cursor, subreq->len, max_segs,
-				  &subreq->nr_segs);
+
+		len = bvecq_slice(&stream->dispatch_cursor, len, max_segs, &subreq->nr_segs);
 	}
 
 	if (len < subreq->len) {
@@ -143,12 +368,11 @@ static int netfs_unbuffered_write(struct netfs_io_request *wreq)
 
 	stream->issue_from = wreq->start;
 	stream->buffered = wreq->len;
-	bvecq_pos_set(&stream->dispatch_cursor, &wreq->load_cursor);
-	bvecq_pos_set(&wreq->collect_cursor, &stream->dispatch_cursor);
 
 	if (wreq->origin == NETFS_DIO_WRITE)
 		inode_dio_begin(wreq->inode);
 
+
 	for (;;) {
 		bool retry = false;
 
@@ -243,20 +467,15 @@ ssize_t netfs_unbuffered_write_iter_locked(struct kiocb *iocb, struct iov_iter *
 					   struct netfs_group *netfs_group)
 {
 	struct netfs_io_request *wreq;
+	struct netfs_io_stream *stream;
 	unsigned long long start = iocb->ki_pos;
 	unsigned long long end = start + iov_iter_count(iter);
-	ssize_t ret;
+	ssize_t ret, n;
 	size_t len = iov_iter_count(iter);
 	bool async = !is_sync_kiocb(iocb);
 
 	_enter("");
 
-	/* We're going to need a bounce buffer if what we transmit is going to
-	 * be different in some way to the source buffer, e.g. because it gets
-	 * encrypted/compressed or because it needs expanding to a block size.
-	 */
-	// TODO
-
 	_debug("uw %llx-%llx", start, end);
 
 	wreq = netfs_create_write_req(iocb->ki_filp->f_mapping, iocb->ki_filp, start,
@@ -266,33 +485,81 @@ ssize_t netfs_unbuffered_write_iter_locked(struct kiocb *iocb, struct iov_iter *
 		return PTR_ERR(wreq);
 
 	wreq->len = iov_iter_count(iter);
-	wreq->io_streams[0].avail = true;
+	wreq->submitted = 0;
+	stream = &wreq->io_streams[0];
+	stream->avail = true;
 	trace_netfs_write(wreq, (iocb->ki_flags & IOCB_DIRECT ?
 				 netfs_write_trace_dio_write :
 				 netfs_write_trace_unbuffered_write));
 
-	{
-		/* If this is an async op and we're not using a bounce buffer,
-		 * we have to save the source buffer as the iterator is only
-		 * good until we return.  In such a case, extract an iterator
-		 * to represent as much of the the output buffer as we can
-		 * manage.  Note that the extraction may shorten the request.
-		 */
-		ssize_t n = netfs_extract_iter(iter, len, INT_MAX, iocb->ki_pos,
-					       &wreq->load_cursor.bvecq, 0);
-
-		if (n < 0) {
-			ret = n;
-			goto error_put;
-		}
-		wreq->len = n;
-		_debug("dio-write %zx/%zx %u/%u",
-		       n, len, wreq->load_cursor.bvecq->nr_slots,
-		       wreq->load_cursor.bvecq->max_slots);
+	/* If we're going to do encryption or compression, we're going to need
+	 * a bounce buffer.
+	 */
+	if (test_bit(NETFS_RREQ_CONTENT_ENCRYPTION, &wreq->flags)) {
+		__set_bit(NETFS_RREQ_USE_BOUNCE_BUFFER, &wreq->flags);
+		__set_bit(NETFS_RREQ_CRYPT_IN_PLACE, &wreq->flags);
 	}
 
-	/* Copy the data into the bounce buffer and encrypt it. */
-	// TODO
+	/* Transcribe the source buffer into a bvecq chain.  We need this for
+	 * async writes because the source iterator is only good until we
+	 * return, but we also use it for unencrypted sync writes as it gets
+	 * passed to the filesystem in this form.
+	 *
+	 * We extract as much of the buffer as we can manage, but this may
+	 * shorten the request.
+	 */
+	n = netfs_extract_iter(iter, len, INT_MAX, iocb->ki_pos,
+			       &wreq->load_cursor.bvecq, 0);
+	if (n < 0) {
+		ret = n;
+		goto error_put;
+	}
+	wreq->len = n;
+	_debug("dio-write %zx/%zx %u/%u",
+	       n, len, wreq->load_cursor.bvecq->nr_slots,
+	       wreq->load_cursor.bvecq->max_slots);
+
+	/* Set up the bounce buffer if we need it.  Allow for padding the
+	 * request out to the crypto block size and allocate at least one bvecq
+	 * into it.
+	 */
+	if (test_bit(NETFS_RREQ_USE_BOUNCE_BUFFER, &wreq->flags)) {
+		size_t bsize = wreq->crypto_bsize;
+		size_t gap;
+
+		bvecq_pos_set(&wreq->copy_cursor, &wreq->load_cursor);
+
+		wreq->bounce_alloc_to = round_down(wreq->start, bsize);
+		atomic64_set(&wreq->encrypted_to, wreq->bounce_alloc_to);
+		gap = wreq->start - wreq->bounce_alloc_to;
+
+		stream->issue_from = wreq->bounce_alloc_to;
+		stream->buffered = round_up(wreq->len + gap, bsize);
+
+		ret = bvecq_buffer_init(&wreq->bounce_alloc, wreq->debug_id);
+		if (ret < 0)
+			goto error_put;
+
+		/*   0--->
+		 *  ~--+-------+-------+-------+-------+---~
+		 *     :       |       |       |       |
+		 *     :spent  |encrypt|copied |alloced|
+		 *     :       |-ed    |       |       |
+		 *  ~--+-------+-------+-------+-------+---~
+		 *                                     ^bounce_alloc
+		 *                             ^bounce_copy
+		 *                     ^encrypt_cursor
+		 *             ^dispatch_cursor
+		 */
+		bvecq_pos_set(&wreq->bounce_copy, &wreq->bounce_alloc);
+		bvecq_pos_set(&wreq->encrypt_cursor, &wreq->bounce_alloc);
+		bvecq_pos_set(&stream->dispatch_cursor, &wreq->bounce_alloc);
+	} else {
+		/* No bounce buffer: dispatch straight from the extracted source. */
+		stream->buffered = wreq->len;
+		stream->issue_from = wreq->start;
+		bvecq_pos_set(&stream->dispatch_cursor, &wreq->load_cursor);
+	}
 
 	/* Dispatch the write. */
 	__set_bit(NETFS_RREQ_UPLOAD_TO_SERVER, &wreq->flags);
diff --git a/fs/netfs/main.c b/fs/netfs/main.c
index a64bd28..73216a0 100644
--- a/fs/netfs/main.c
+++ b/fs/netfs/main.c
@@ -44,6 +44,7 @@ static const char *netfs_origins[nr__netfs_io_origin] = {
 	[NETFS_WRITEBACK]		= "WB",
 	[NETFS_WRITEBACK_SINGLE]	= "W1",
 	[NETFS_WRITETHROUGH]		= "WT",
+	[NETFS_RMW_READ]		= "RM",
 	[NETFS_UNBUFFERED_WRITE]	= "UW",
 	[NETFS_DIO_WRITE]		= "DW",
 	[NETFS_PGPRIV2_COPY_TO_CACHE]	= "2C",
diff --git a/fs/netfs/objects.c b/fs/netfs/objects.c
index f618581..d295580 100644
--- a/fs/netfs/objects.c
+++ b/fs/netfs/objects.c
@@ -138,6 +138,7 @@ static void netfs_deinit_request(struct netfs_io_request *rreq)
 	if (rreq->cache_resources.ops)
 		rreq->cache_resources.ops->end_operation(&rreq->cache_resources);
 	bvecq_pos_unset(&rreq->load_cursor);
+	bvecq_pos_unset(&rreq->copy_cursor);
 	bvecq_pos_unset(&rreq->collect_cursor);
 	bvecq_pos_unset(&rreq->bounce_alloc);
 	bvecq_pos_unset(&rreq->encrypt_cursor);
diff --git a/fs/netfs/read_collect.c b/fs/netfs/read_collect.c
index 739713c..7bda8cf 100644
--- a/fs/netfs/read_collect.c
+++ b/fs/netfs/read_collect.c
@@ -437,6 +437,7 @@ bool netfs_read_collection(struct netfs_io_request *rreq)
 	case NETFS_UNBUFFERED_READ:
 	case NETFS_DIO_READ:
 	case NETFS_READ_GAPS:
+	case NETFS_RMW_READ:
 		netfs_rreq_assess_dio(rreq);
 		break;
 	case NETFS_READ_SINGLE:
diff --git a/include/linux/netfs.h b/include/linux/netfs.h
index 06e2265..394cf3e 100644
--- a/include/linux/netfs.h
+++ b/include/linux/netfs.h
@@ -231,6 +231,7 @@ enum netfs_io_origin {
 	NETFS_WRITEBACK,		/* This write was triggered by writepages */
 	NETFS_WRITEBACK_SINGLE,		/* This monolithic write was triggered by writepages */
 	NETFS_WRITETHROUGH,		/* This write was made by netfs_perform_write() */
+	NETFS_RMW_READ,			/* This is an unbuffered read for RMW */
 	NETFS_UNBUFFERED_WRITE,		/* This is an unbuffered write */
 	NETFS_DIO_WRITE,		/* This is a direct I/O write */
 	NETFS_PGPRIV2_COPY_TO_CACHE,	/* [DEPRECATED] This is writing read data to the cache */
@@ -260,6 +261,7 @@ struct netfs_io_request {
 	struct netfs_group	*group;		/* Writeback group being written back */
 	struct bvecq		*spare;		/* Advance allocation of bvecq */
 	struct bvecq_pos	load_cursor;	/* Point at which new folios are loaded in */
+	struct bvecq_pos	copy_cursor;	/* Copy-out point from main buffer list */
 	struct bvecq_pos	collect_cursor;	/* Clear-up point of I/O buffer */
 	struct bvecq_pos	bounce_alloc;	/* Bounce buffer allocation point */
 	struct bvecq_pos	encrypt_cursor;	/* Encrypt dispatch point */
diff --git a/include/trace/events/netfs.h b/include/trace/events/netfs.h
index b143a0d..2415631 100644
--- a/include/trace/events/netfs.h
+++ b/include/trace/events/netfs.h
@@ -44,6 +44,7 @@
 	EM(NETFS_WRITEBACK,			"WB")		\
 	EM(NETFS_WRITEBACK_SINGLE,		"W1")		\
 	EM(NETFS_WRITETHROUGH,			"WT")		\
+	EM(NETFS_RMW_READ,			"RM")		\
 	EM(NETFS_UNBUFFERED_WRITE,		"UW")		\
 	EM(NETFS_DIO_WRITE,			"DW")		\
 	E_(NETFS_PGPRIV2_COPY_TO_CACHE,		"2C")