rxrpc: Fix NULL pointer deref due to call->conn being cleared on disconnect

When a call is disconnected, the connection pointer from the call is
cleared to make sure it isn't used again and to prevent further attempted
transmission for the call.  Unfortunately, there might be a daemon trying
to use it at the same time to transmit a packet.

Fix this by keeping call->conn set, but setting a flag on the call to
indicate disconnection instead.

Remove also the bits in the transmission functions where the conn pointer is
checked and a ref taken under spinlock as this is now redundant.

Fixes: 8d94aa381dab ("rxrpc: Calls shouldn't hold socket refs")
Signed-off-by: David Howells <dhowells@redhat.com>
diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h
index 94441fe..7d730c4 100644
--- a/net/rxrpc/ar-internal.h
+++ b/net/rxrpc/ar-internal.h
@@ -490,6 +490,7 @@
 	RXRPC_CALL_RX_HEARD,		/* The peer responded at least once to this call */
 	RXRPC_CALL_RX_UNDERRUN,		/* Got data underrun */
 	RXRPC_CALL_IS_INTR,		/* The call is interruptible */
+	RXRPC_CALL_DISCONNECTED,	/* The call has been disconnected */
 };
 
 /*
diff --git a/net/rxrpc/call_object.c b/net/rxrpc/call_object.c
index a31c18c..dbdbc4f 100644
--- a/net/rxrpc/call_object.c
+++ b/net/rxrpc/call_object.c
@@ -493,7 +493,7 @@
 
 	_debug("RELEASE CALL %p (%d CONN %p)", call, call->debug_id, conn);
 
-	if (conn)
+	if (conn && !test_bit(RXRPC_CALL_DISCONNECTED, &call->flags))
 		rxrpc_disconnect_call(call);
 	if (call->security)
 		call->security->free_call_crypto(call);
@@ -569,6 +569,7 @@
 	struct rxrpc_call *call = container_of(rcu, struct rxrpc_call, rcu);
 	struct rxrpc_net *rxnet = call->rxnet;
 
+	rxrpc_put_connection(call->conn);
 	rxrpc_put_peer(call->peer);
 	kfree(call->rxtx_buffer);
 	kfree(call->rxtx_annotations);
@@ -590,7 +591,6 @@
 
 	ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE);
 	ASSERT(test_bit(RXRPC_CALL_RELEASED, &call->flags));
-	ASSERTCMP(call->conn, ==, NULL);
 
 	rxrpc_cleanup_ring(call);
 	rxrpc_free_skb(call->tx_pending, rxrpc_skb_cleaned);
diff --git a/net/rxrpc/conn_client.c b/net/rxrpc/conn_client.c
index 376370c..ea7d4c2 100644
--- a/net/rxrpc/conn_client.c
+++ b/net/rxrpc/conn_client.c
@@ -785,6 +785,7 @@
 	u32 cid;
 
 	spin_lock(&conn->channel_lock);
+	set_bit(RXRPC_CALL_DISCONNECTED, &call->flags);
 
 	cid = call->cid;
 	if (cid) {
@@ -792,7 +793,6 @@
 		chan = &conn->channels[channel];
 	}
 	trace_rxrpc_client(conn, channel, rxrpc_client_chan_disconnect);
-	call->conn = NULL;
 
 	/* Calls that have never actually been assigned a channel can simply be
 	 * discarded.  If the conn didn't get used either, it will follow
@@ -908,7 +908,6 @@
 	spin_unlock(&rxnet->client_conn_cache_lock);
 out_2:
 	spin_unlock(&conn->channel_lock);
-	rxrpc_put_connection(conn);
 	_leave("");
 	return;
 
diff --git a/net/rxrpc/conn_object.c b/net/rxrpc/conn_object.c
index 38d718e..c0b3154 100644
--- a/net/rxrpc/conn_object.c
+++ b/net/rxrpc/conn_object.c
@@ -171,6 +171,8 @@
 
 	_enter("%d,%x", conn->debug_id, call->cid);
 
+	set_bit(RXRPC_CALL_DISCONNECTED, &call->flags);
+
 	if (rcu_access_pointer(chan->call) == call) {
 		/* Save the result of the call so that we can repeat it if necessary
 		 * through the channel, whilst disposing of the actual call record.
@@ -223,9 +225,7 @@
 	__rxrpc_disconnect_call(conn, call);
 	spin_unlock(&conn->channel_lock);
 
-	call->conn = NULL;
 	conn->idle_timestamp = jiffies;
-	rxrpc_put_connection(conn);
 }
 
 /*
diff --git a/net/rxrpc/output.c b/net/rxrpc/output.c
index 935bb60..bad3d24 100644
--- a/net/rxrpc/output.c
+++ b/net/rxrpc/output.c
@@ -129,7 +129,7 @@
 int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping,
 			  rxrpc_serial_t *_serial)
 {
-	struct rxrpc_connection *conn = NULL;
+	struct rxrpc_connection *conn;
 	struct rxrpc_ack_buffer *pkt;
 	struct msghdr msg;
 	struct kvec iov[2];
@@ -139,18 +139,14 @@
 	int ret;
 	u8 reason;
 
-	spin_lock_bh(&call->lock);
-	if (call->conn)
-		conn = rxrpc_get_connection_maybe(call->conn);
-	spin_unlock_bh(&call->lock);
-	if (!conn)
+	if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags))
 		return -ECONNRESET;
 
 	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
-	if (!pkt) {
-		rxrpc_put_connection(conn);
+	if (!pkt)
 		return -ENOMEM;
-	}
+
+	conn = call->conn;
 
 	msg.msg_name	= &call->peer->srx.transport;
 	msg.msg_namelen	= call->peer->srx.transport_len;
@@ -244,7 +240,6 @@
 	}
 
 out:
-	rxrpc_put_connection(conn);
 	kfree(pkt);
 	return ret;
 }
@@ -254,7 +249,7 @@
  */
 int rxrpc_send_abort_packet(struct rxrpc_call *call)
 {
-	struct rxrpc_connection *conn = NULL;
+	struct rxrpc_connection *conn;
 	struct rxrpc_abort_buffer pkt;
 	struct msghdr msg;
 	struct kvec iov[1];
@@ -271,13 +266,11 @@
 	    test_bit(RXRPC_CALL_TX_LAST, &call->flags))
 		return 0;
 
-	spin_lock_bh(&call->lock);
-	if (call->conn)
-		conn = rxrpc_get_connection_maybe(call->conn);
-	spin_unlock_bh(&call->lock);
-	if (!conn)
+	if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags))
 		return -ECONNRESET;
 
+	conn = call->conn;
+
 	msg.msg_name	= &call->peer->srx.transport;
 	msg.msg_namelen	= call->peer->srx.transport_len;
 	msg.msg_control	= NULL;
@@ -312,8 +305,6 @@
 		trace_rxrpc_tx_packet(call->debug_id, &pkt.whdr,
 				      rxrpc_tx_point_call_abort);
 	rxrpc_tx_backoff(call, ret);
-
-	rxrpc_put_connection(conn);
 	return ret;
 }