rxrpc: Extract useful fields from a received ACK to skb priv data

Extract useful fields from a received ACK packet into the skb private data
early on in the process of parsing incoming packets.  This makes the ACK
fields available even before we've matched the ACK up to a call and will
allow us to deal with path MTU discovery probe responses even after the
relevant call has been completed.

Signed-off-by: David Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: "David S. Miller" <davem@davemloft.net>
cc: Eric Dumazet <edumazet@google.com>
cc: Jakub Kicinski <kuba@kernel.org>
cc: Paolo Abeni <pabeni@redhat.com>
cc: linux-afs@lists.infradead.org
cc: netdev@vger.kernel.org
diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h
index 21ecac2..08c0a32 100644
--- a/net/rxrpc/ar-internal.h
+++ b/net/rxrpc/ar-internal.h
@@ -198,8 +198,8 @@
  * - max 48 bytes (struct sk_buff::cb)
  */
 struct rxrpc_skb_priv {
-	struct rxrpc_connection *conn;	/* Connection referred to (poke packet) */
 	union {
+		struct rxrpc_connection *conn;	/* Connection referred to (poke packet) */
 		struct {
 			u16		offset;		/* Offset of data */
 			u16		len;		/* Length of data */
@@ -208,9 +208,12 @@
 		};
 		struct {
 			rxrpc_seq_t	first_ack;	/* First packet in acks table */
+			rxrpc_seq_t	prev_ack;	/* Highest seq seen */
+			rxrpc_serial_t	acked_serial;	/* Packet in response to (or 0) */
+			u8		reason;		/* Reason for ack */
 			u8		nr_acks;	/* Number of acks+nacks */
 			u8		nr_nacks;	/* Number of nacks */
-		};
+		} ack;
 	};
 	struct rxrpc_host_header hdr;	/* RxRPC packet header from this packet */
 };
diff --git a/net/rxrpc/call_event.c b/net/rxrpc/call_event.c
index 6c5e305..7bbb685 100644
--- a/net/rxrpc/call_event.c
+++ b/net/rxrpc/call_event.c
@@ -93,12 +93,12 @@
 		sp = rxrpc_skb(ack_skb);
 		ack = (void *)ack_skb->data + sizeof(struct rxrpc_wire_header);
 
-		for (i = 0; i < sp->nr_acks; i++) {
+		for (i = 0; i < sp->ack.nr_acks; i++) {
 			rxrpc_seq_t seq;
 
 			if (ack->acks[i] & 1)
 				continue;
-			seq = sp->first_ack + i;
+			seq = sp->ack.first_ack + i;
 			if (after(txb->seq, transmitted))
 				break;
 			if (after(txb->seq, seq))
diff --git a/net/rxrpc/input.c b/net/rxrpc/input.c
index 09cce1d..3dedb8c 100644
--- a/net/rxrpc/input.c
+++ b/net/rxrpc/input.c
@@ -710,20 +710,19 @@
 					      rxrpc_seq_t seq)
 {
 	struct sk_buff *skb = call->cong_last_nack;
-	struct rxrpc_ackpacket ack;
 	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
 	unsigned int i, new_acks = 0, retained_nacks = 0;
-	rxrpc_seq_t old_seq = sp->first_ack;
-	u8 *acks = skb->data + sizeof(struct rxrpc_wire_header) + sizeof(ack);
+	rxrpc_seq_t old_seq = sp->ack.first_ack;
+	u8 *acks = skb->data + sizeof(struct rxrpc_wire_header) + sizeof(struct rxrpc_ackpacket);
 
-	if (after_eq(seq, old_seq + sp->nr_acks)) {
-		summary->nr_new_acks += sp->nr_nacks;
-		summary->nr_new_acks += seq - (old_seq + sp->nr_acks);
+	if (after_eq(seq, old_seq + sp->ack.nr_acks)) {
+		summary->nr_new_acks += sp->ack.nr_nacks;
+		summary->nr_new_acks += seq - (old_seq + sp->ack.nr_acks);
 		summary->nr_retained_nacks = 0;
 	} else if (seq == old_seq) {
-		summary->nr_retained_nacks = sp->nr_nacks;
+		summary->nr_retained_nacks = sp->ack.nr_nacks;
 	} else {
-		for (i = 0; i < sp->nr_acks; i++) {
+		for (i = 0; i < sp->ack.nr_acks; i++) {
 			if (acks[i] == RXRPC_ACK_TYPE_NACK) {
 				if (before(old_seq + i, seq))
 					new_acks++;
@@ -736,7 +735,7 @@
 		summary->nr_retained_nacks = retained_nacks;
 	}
 
-	return old_seq + sp->nr_acks;
+	return old_seq + sp->ack.nr_acks;
 }
 
 /*
@@ -756,10 +755,10 @@
 {
 	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
 	unsigned int i, old_nacks = 0;
-	rxrpc_seq_t lowest_nak = seq + sp->nr_acks;
+	rxrpc_seq_t lowest_nak = seq + sp->ack.nr_acks;
 	u8 *acks = skb->data + sizeof(struct rxrpc_wire_header) + sizeof(struct rxrpc_ackpacket);
 
-	for (i = 0; i < sp->nr_acks; i++) {
+	for (i = 0; i < sp->ack.nr_acks; i++) {
 		if (acks[i] == RXRPC_ACK_TYPE_ACK) {
 			summary->nr_acks++;
 			if (after_eq(seq, since))
@@ -771,7 +770,7 @@
 				old_nacks++;
 			} else {
 				summary->nr_new_nacks++;
-				sp->nr_nacks++;
+				sp->ack.nr_nacks++;
 			}
 
 			if (before(seq, lowest_nak))
@@ -832,7 +831,6 @@
 static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb)
 {
 	struct rxrpc_ack_summary summary = { 0 };
-	struct rxrpc_ackpacket ack;
 	struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
 	struct rxrpc_acktrailer trailer;
 	rxrpc_serial_t ack_serial, acked_serial;
@@ -841,29 +839,24 @@
 
 	_enter("");
 
-	offset = sizeof(struct rxrpc_wire_header);
-	if (skb_copy_bits(skb, offset, &ack, sizeof(ack)) < 0)
-		return rxrpc_proto_abort(call, 0, rxrpc_badmsg_short_ack);
-	offset += sizeof(ack);
+	offset = sizeof(struct rxrpc_wire_header) + sizeof(struct rxrpc_ackpacket);
 
-	ack_serial = sp->hdr.serial;
-	acked_serial = ntohl(ack.serial);
-	first_soft_ack = ntohl(ack.firstPacket);
-	prev_pkt = ntohl(ack.previousPacket);
-	hard_ack = first_soft_ack - 1;
-	nr_acks = ack.nAcks;
-	sp->first_ack = first_soft_ack;
-	sp->nr_acks = nr_acks;
-	summary.ack_reason = (ack.reason < RXRPC_ACK__INVALID ?
-			      ack.reason : RXRPC_ACK__INVALID);
+	ack_serial	= sp->hdr.serial;
+	acked_serial	= sp->ack.acked_serial;
+	first_soft_ack	= sp->ack.first_ack;
+	prev_pkt	= sp->ack.prev_ack;
+	nr_acks		= sp->ack.nr_acks;
+	hard_ack	= first_soft_ack - 1;
+	summary.ack_reason = (sp->ack.reason < RXRPC_ACK__INVALID ?
+			      sp->ack.reason : RXRPC_ACK__INVALID);
 
 	trace_rxrpc_rx_ack(call, ack_serial, acked_serial,
 			   first_soft_ack, prev_pkt,
 			   summary.ack_reason, nr_acks);
-	rxrpc_inc_stat(call->rxnet, stat_rx_acks[ack.reason]);
+	rxrpc_inc_stat(call->rxnet, stat_rx_acks[summary.ack_reason]);
 
 	if (acked_serial != 0) {
-		switch (ack.reason) {
+		switch (summary.ack_reason) {
 		case RXRPC_ACK_PING_RESPONSE:
 			rxrpc_complete_rtt_probe(call, skb->tstamp, acked_serial, ack_serial,
 						 rxrpc_rtt_rx_ping_response);
@@ -883,7 +876,7 @@
 	 * indicates that the client address changed due to NAT.  The server
 	 * lost the call because it switched to a different peer.
 	 */
-	if (unlikely(ack.reason == RXRPC_ACK_EXCEEDS_WINDOW) &&
+	if (unlikely(summary.ack_reason == RXRPC_ACK_EXCEEDS_WINDOW) &&
 	    first_soft_ack == 1 &&
 	    prev_pkt == 0 &&
 	    rxrpc_is_client_call(call)) {
@@ -896,7 +889,7 @@
 	 * indicate a change of address.  However, we can retransmit the call
 	 * if we still have it buffered to the beginning.
 	 */
-	if (unlikely(ack.reason == RXRPC_ACK_OUT_OF_SEQUENCE) &&
+	if (unlikely(summary.ack_reason == RXRPC_ACK_OUT_OF_SEQUENCE) &&
 	    first_soft_ack == 1 &&
 	    prev_pkt == 0 &&
 	    call->acks_hard_ack == 0 &&
@@ -937,7 +930,7 @@
 	call->acks_first_seq = first_soft_ack;
 	call->acks_prev_seq = prev_pkt;
 
-	switch (ack.reason) {
+	switch (summary.ack_reason) {
 	case RXRPC_ACK_PING:
 		break;
 	default:
@@ -994,7 +987,7 @@
 	rxrpc_congestion_management(call, skb, &summary, acked_serial);
 
 send_response:
-	if (ack.reason == RXRPC_ACK_PING)
+	if (summary.ack_reason == RXRPC_ACK_PING)
 		rxrpc_send_ACK(call, RXRPC_ACK_PING_RESPONSE, ack_serial,
 			       rxrpc_propose_ack_respond_to_ping);
 	else if (sp->hdr.flags & RXRPC_REQUEST_ACK)
diff --git a/net/rxrpc/io_thread.c b/net/rxrpc/io_thread.c
index 4a3a08a..0300baa 100644
--- a/net/rxrpc/io_thread.c
+++ b/net/rxrpc/io_thread.c
@@ -124,6 +124,7 @@
 				 struct sk_buff *skb)
 {
 	struct rxrpc_wire_header whdr;
+	struct rxrpc_ackpacket ack;
 
 	/* dig out the RxRPC connection details */
 	if (skb_copy_bits(skb, 0, &whdr, sizeof(whdr)) < 0)
@@ -141,6 +142,16 @@
 	sp->hdr.securityIndex	= whdr.securityIndex;
 	sp->hdr._rsvd		= ntohs(whdr._rsvd);
 	sp->hdr.serviceId	= ntohs(whdr.serviceId);
+
+	if (sp->hdr.type == RXRPC_PACKET_TYPE_ACK) {
+		if (skb_copy_bits(skb, sizeof(whdr), &ack, sizeof(ack)) < 0)
+			return rxrpc_bad_message(skb, rxrpc_badmsg_short_ack);
+		sp->ack.first_ack	= ntohl(ack.firstPacket);
+		sp->ack.prev_ack	= ntohl(ack.previousPacket);
+		sp->ack.acked_serial	= ntohl(ack.serial);
+		sp->ack.reason		= ack.reason;
+		sp->ack.nr_acks		= ack.nAcks;
+	}
 	return true;
 }