Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull virtio updates from Michael Tsirkin:
 "Just fixes and cleanups this time around. The mapping cleanups are
  preparing the ground for new features, though"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  virtio-vdpa: Drop redundant conversion to bool
  vduse: Use fixed 4KB bounce pages for non-4KB page size
  vduse: switch to use virtio map API instead of DMA API
  vdpa: introduce map ops
  vdpa: support virtio_map
  virtio: introduce map ops in virtio core
  virtio_ring: rename dma_handle to map_handle
  virtio: introduce virtio_map container union
  virtio: rename dma helpers
  virtio_ring: switch to use dma_{map|unmap}_page()
  virtio_ring: constify virtqueue pointer for DMA helpers
  virtio_balloon: Remove redundant __GFP_NOWARN
  vhost: vringh: Fix copy_to_iter return value check
  vhost: vringh: Modify the return value check
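
For orientation before the diff: the mapping cleanups replace the bare dma_dev pointer on virtio and vDPA devices with a small container union that either names the device used for the DMA API or carries a backend-specific token (VDUSE stores its iova_domain there). A minimal sketch of that union, reconstructed from its uses in the hunks below rather than quoted from the headers:

	union virtio_map {
		/* device used when mapping goes through the DMA API */
		struct device *dma_dev;
		/* VDUSE-specific token, recovered inside its map ops */
		struct vduse_iova_domain *iova_domain;
	};

Devices expose it as vdev->vmap (virtio) and vdpa->vmap (vDPA), and every map op receives it back as its first argument.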
diff --git a/drivers/net/virtio_net.c b/drivers/net/virtio_net.c
index 7da5a37..a757cbc 100644
--- a/drivers/net/virtio_net.c
+++ b/drivers/net/virtio_net.c
@@ -962,7 +962,7 @@ static void virtnet_rq_unmap(struct receive_queue *rq, void *buf, u32 len)
 	if (dma->need_sync && len) {
 		offset = buf - (head + sizeof(*dma));
 
-		virtqueue_dma_sync_single_range_for_cpu(rq->vq, dma->addr,
+		virtqueue_map_sync_single_range_for_cpu(rq->vq, dma->addr,
 							offset, len,
 							DMA_FROM_DEVICE);
 	}
@@ -970,8 +970,8 @@ static void virtnet_rq_unmap(struct receive_queue *rq, void *buf, u32 len)
 	if (dma->ref)
 		return;
 
-	virtqueue_dma_unmap_single_attrs(rq->vq, dma->addr, dma->len,
-					 DMA_FROM_DEVICE, DMA_ATTR_SKIP_CPU_SYNC);
+	virtqueue_unmap_single_attrs(rq->vq, dma->addr, dma->len,
+				     DMA_FROM_DEVICE, DMA_ATTR_SKIP_CPU_SYNC);
 	put_page(page);
 }
 
@@ -1038,13 +1038,13 @@ static void *virtnet_rq_alloc(struct receive_queue *rq, u32 size, gfp_t gfp)
 
 		dma->len = alloc_frag->size - sizeof(*dma);
 
-		addr = virtqueue_dma_map_single_attrs(rq->vq, dma + 1,
-						      dma->len, DMA_FROM_DEVICE, 0);
-		if (virtqueue_dma_mapping_error(rq->vq, addr))
+		addr = virtqueue_map_single_attrs(rq->vq, dma + 1,
+						  dma->len, DMA_FROM_DEVICE, 0);
+		if (virtqueue_map_mapping_error(rq->vq, addr))
 			return NULL;
 
 		dma->addr = addr;
-		dma->need_sync = virtqueue_dma_need_sync(rq->vq, addr);
+		dma->need_sync = virtqueue_map_need_sync(rq->vq, addr);
 
 		/* Add a reference to dma to prevent the entire dma from
 		 * being released during error handling. This reference
@@ -5942,9 +5942,9 @@ static int virtnet_xsk_pool_enable(struct net_device *dev,
 	if (!rq->xsk_buffs)
 		return -ENOMEM;
 
-	hdr_dma = virtqueue_dma_map_single_attrs(sq->vq, &xsk_hdr, vi->hdr_len,
-						 DMA_TO_DEVICE, 0);
-	if (virtqueue_dma_mapping_error(sq->vq, hdr_dma)) {
+	hdr_dma = virtqueue_map_single_attrs(sq->vq, &xsk_hdr, vi->hdr_len,
+					     DMA_TO_DEVICE, 0);
+	if (virtqueue_map_mapping_error(sq->vq, hdr_dma)) {
 		err = -ENOMEM;
 		goto err_free_buffs;
 	}
@@ -5973,8 +5973,8 @@ static int virtnet_xsk_pool_enable(struct net_device *dev,
 err_rq:
 	xsk_pool_dma_unmap(pool, 0);
 err_xsk_map:
-	virtqueue_dma_unmap_single_attrs(rq->vq, hdr_dma, vi->hdr_len,
-					 DMA_TO_DEVICE, 0);
+	virtqueue_unmap_single_attrs(rq->vq, hdr_dma, vi->hdr_len,
+				     DMA_TO_DEVICE, 0);
 err_free_buffs:
 	kvfree(rq->xsk_buffs);
 	return err;
@@ -6001,8 +6001,8 @@ static int virtnet_xsk_pool_disable(struct net_device *dev, u16 qid)
 
 	xsk_pool_dma_unmap(pool, 0);
 
-	virtqueue_dma_unmap_single_attrs(sq->vq, sq->xsk_hdr_dma_addr,
-					 vi->hdr_len, DMA_TO_DEVICE, 0);
+	virtqueue_unmap_single_attrs(sq->vq, sq->xsk_hdr_dma_addr,
+				     vi->hdr_len, DMA_TO_DEVICE, 0);
 	kvfree(rq->xsk_buffs);
 
 	return err;
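
The virtio-net conversion above is mechanical: the helpers keep their semantics and only lose the dma_ prefix, reflecting that the backend may now be a vDPA map op instead of the DMA API. The premapped-buffer pattern under the new names looks roughly like this (a condensed sketch; buf and len are placeholders, error paths elided):

	dma_addr_t addr;

	addr = virtqueue_map_single_attrs(vq, buf, len, DMA_TO_DEVICE, 0);
	if (virtqueue_map_mapping_error(vq, addr))
		return -ENOMEM;
	/* ... post addr to the ring in premapped mode ... */
	virtqueue_unmap_single_attrs(vq, addr, len, DMA_TO_DEVICE, 0);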
diff --git a/drivers/vdpa/Kconfig b/drivers/vdpa/Kconfig
index 559fb9d..857cf28 100644
--- a/drivers/vdpa/Kconfig
+++ b/drivers/vdpa/Kconfig
@@ -34,13 +34,7 @@
 
 config VDPA_USER
 	tristate "VDUSE (vDPA Device in Userspace) support"
-	depends on EVENTFD && MMU && HAS_DMA
-	#
-	# This driver incorrectly tries to override the dma_ops.  It should
-	# never have done that, but for now keep it working on architectures
-	# that use dma ops
-	#
-	depends on ARCH_HAS_DMA_OPS
+	depends on EVENTFD && MMU
 	select VHOST_IOTLB
 	select IOMMU_IOVA
 	help
diff --git a/drivers/vdpa/alibaba/eni_vdpa.c b/drivers/vdpa/alibaba/eni_vdpa.c
index ad7f344..e476504 100644
--- a/drivers/vdpa/alibaba/eni_vdpa.c
+++ b/drivers/vdpa/alibaba/eni_vdpa.c
@@ -478,7 +478,8 @@ static int eni_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 		return ret;
 
 	eni_vdpa = vdpa_alloc_device(struct eni_vdpa, vdpa,
-				     dev, &eni_vdpa_ops, 1, 1, NULL, false);
+				     dev, &eni_vdpa_ops, NULL,
+				     1, 1, NULL, false);
 	if (IS_ERR(eni_vdpa)) {
 		ENI_ERR(pdev, "failed to allocate vDPA structure\n");
 		return PTR_ERR(eni_vdpa);
@@ -496,7 +497,7 @@ static int eni_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 	pci_set_master(pdev);
 	pci_set_drvdata(pdev, eni_vdpa);
 
-	eni_vdpa->vdpa.dma_dev = &pdev->dev;
+	eni_vdpa->vdpa.vmap.dma_dev = &pdev->dev;
 	eni_vdpa->queues = eni_vdpa_get_num_queues(eni_vdpa);
 
 	eni_vdpa->vring = devm_kcalloc(&pdev->dev, eni_vdpa->queues,
diff --git a/drivers/vdpa/ifcvf/ifcvf_main.c b/drivers/vdpa/ifcvf/ifcvf_main.c
index ccf64d7..6658dc7 100644
--- a/drivers/vdpa/ifcvf/ifcvf_main.c
+++ b/drivers/vdpa/ifcvf/ifcvf_main.c
@@ -705,7 +705,8 @@ static int ifcvf_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
 	vf = &ifcvf_mgmt_dev->vf;
 	pdev = vf->pdev;
 	adapter = vdpa_alloc_device(struct ifcvf_adapter, vdpa,
-				    &pdev->dev, &ifc_vdpa_ops, 1, 1, NULL, false);
+				    &pdev->dev, &ifc_vdpa_ops,
+				    NULL, 1, 1, NULL, false);
 	if (IS_ERR(adapter)) {
 		IFCVF_ERR(pdev, "Failed to allocate vDPA structure");
 		return PTR_ERR(adapter);
@@ -713,7 +714,7 @@ static int ifcvf_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
 
 	ifcvf_mgmt_dev->adapter = adapter;
 	adapter->pdev = pdev;
-	adapter->vdpa.dma_dev = &pdev->dev;
+	adapter->vdpa.vmap.dma_dev = &pdev->dev;
 	adapter->vdpa.mdev = mdev;
 	adapter->vf = vf;
 	vdpa_dev = &adapter->vdpa;
diff --git a/drivers/vdpa/mlx5/core/mr.c b/drivers/vdpa/mlx5/core/mr.c
index c7a2027..8870a71 100644
--- a/drivers/vdpa/mlx5/core/mr.c
+++ b/drivers/vdpa/mlx5/core/mr.c
@@ -378,7 +378,7 @@ static int map_direct_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_direct_mr
 	u64 pa, offset;
 	u64 paend;
 	struct scatterlist *sg;
-	struct device *dma = mvdev->vdev.dma_dev;
+	struct device *dma = mvdev->vdev.vmap.dma_dev;
 
 	for (map = vhost_iotlb_itree_first(iotlb, mr->start, mr->end - 1);
 	     map; map = vhost_iotlb_itree_next(map, mr->start, mr->end - 1)) {
@@ -432,7 +432,7 @@ static int map_direct_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_direct_mr
 
 static void unmap_direct_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_direct_mr *mr)
 {
-	struct device *dma = mvdev->vdev.dma_dev;
+	struct device *dma = mvdev->vdev.vmap.dma_dev;
 
 	destroy_direct_mr(mvdev, mr);
 	dma_unmap_sg_attrs(dma, mr->sg_head.sgl, mr->nsg, DMA_BIDIRECTIONAL, 0);
diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c
index 0ed2fc2..82034ef 100644
--- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
+++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
@@ -3395,14 +3395,17 @@ static int mlx5_vdpa_reset_map(struct vdpa_device *vdev, unsigned int asid)
 	return err;
 }
 
-static struct device *mlx5_get_vq_dma_dev(struct vdpa_device *vdev, u16 idx)
+static union virtio_map mlx5_get_vq_map(struct vdpa_device *vdev, u16 idx)
 {
 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
+	union virtio_map map;
 
 	if (is_ctrl_vq_idx(mvdev, idx))
-		return &vdev->dev;
+		map.dma_dev = &vdev->dev;
+	else
+		map.dma_dev = mvdev->vdev.vmap.dma_dev;
 
-	return mvdev->vdev.dma_dev;
+	return map;
 }
 
 static void free_irqs(struct mlx5_vdpa_net *ndev)
@@ -3686,7 +3689,7 @@ static const struct vdpa_config_ops mlx5_vdpa_ops = {
 	.set_map = mlx5_vdpa_set_map,
 	.reset_map = mlx5_vdpa_reset_map,
 	.set_group_asid = mlx5_set_group_asid,
-	.get_vq_dma_dev = mlx5_get_vq_dma_dev,
+	.get_vq_map = mlx5_get_vq_map,
 	.free = mlx5_vdpa_free,
 	.suspend = mlx5_vdpa_suspend,
 	.resume = mlx5_vdpa_resume, /* Op disabled if not supported. */
@@ -3879,7 +3882,7 @@ static int mlx5_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name,
 	}
 
 	ndev = vdpa_alloc_device(struct mlx5_vdpa_net, mvdev.vdev, mdev->device, &mgtdev->vdpa_ops,
-				 MLX5_VDPA_NUMVQ_GROUPS, MLX5_VDPA_NUM_AS, name, false);
+				 NULL, MLX5_VDPA_NUMVQ_GROUPS, MLX5_VDPA_NUM_AS, name, false);
 	if (IS_ERR(ndev))
 		return PTR_ERR(ndev);
 
@@ -3965,7 +3968,7 @@ static int mlx5_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name,
 	}
 
 	ndev->mvdev.mlx_features = device_features;
-	mvdev->vdev.dma_dev = &mdev->pdev->dev;
+	mvdev->vdev.vmap.dma_dev = &mdev->pdev->dev;
 	err = mlx5_vdpa_alloc_resources(&ndev->mvdev);
 	if (err)
 		goto err_alloc;
diff --git a/drivers/vdpa/octeon_ep/octep_vdpa_main.c b/drivers/vdpa/octeon_ep/octep_vdpa_main.c
index 9b49efd..9e8d070 100644
--- a/drivers/vdpa/octeon_ep/octep_vdpa_main.c
+++ b/drivers/vdpa/octeon_ep/octep_vdpa_main.c
@@ -508,15 +508,15 @@ static int octep_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
 	u64 device_features;
 	int ret;
 
-	oct_vdpa = vdpa_alloc_device(struct octep_vdpa, vdpa, &pdev->dev, &octep_vdpa_ops, 1, 1,
-				     NULL, false);
+	oct_vdpa = vdpa_alloc_device(struct octep_vdpa, vdpa, &pdev->dev, &octep_vdpa_ops,
+				     NULL, 1, 1, NULL, false);
 	if (IS_ERR(oct_vdpa)) {
 		dev_err(&pdev->dev, "Failed to allocate vDPA structure for octep vdpa device");
 		return PTR_ERR(oct_vdpa);
 	}
 
 	oct_vdpa->pdev = pdev;
-	oct_vdpa->vdpa.dma_dev = &pdev->dev;
+	oct_vdpa->vdpa.vmap.dma_dev = &pdev->dev;
 	oct_vdpa->vdpa.mdev = mdev;
 	oct_vdpa->oct_hw = oct_hw;
 	vdpa_dev = &oct_vdpa->vdpa;
diff --git a/drivers/vdpa/pds/vdpa_dev.c b/drivers/vdpa/pds/vdpa_dev.c
index 301d95e..36f61cc 100644
--- a/drivers/vdpa/pds/vdpa_dev.c
+++ b/drivers/vdpa/pds/vdpa_dev.c
@@ -632,7 +632,8 @@ static int pds_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
 	}
 
 	pdsv = vdpa_alloc_device(struct pds_vdpa_device, vdpa_dev,
-				 dev, &pds_vdpa_ops, 1, 1, name, false);
+				 dev, &pds_vdpa_ops, NULL,
+				 1, 1, name, false);
 	if (IS_ERR(pdsv)) {
 		dev_err(dev, "Failed to allocate vDPA structure: %pe\n", pdsv);
 		return PTR_ERR(pdsv);
@@ -643,7 +644,7 @@ static int pds_vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
 
 	pdev = vdpa_aux->padev->vf_pdev;
 	dma_dev = &pdev->dev;
-	pdsv->vdpa_dev.dma_dev = dma_dev;
+	pdsv->vdpa_dev.vmap.dma_dev = dma_dev;
 
 	status = pds_vdpa_get_status(&pdsv->vdpa_dev);
 	if (status == 0xff) {
diff --git a/drivers/vdpa/solidrun/snet_main.c b/drivers/vdpa/solidrun/snet_main.c
index 55ec51c..4588211 100644
--- a/drivers/vdpa/solidrun/snet_main.c
+++ b/drivers/vdpa/solidrun/snet_main.c
@@ -1008,8 +1008,8 @@ static int snet_vdpa_probe_vf(struct pci_dev *pdev)
 	}
 
 	/* Allocate vdpa device */
-	snet = vdpa_alloc_device(struct snet, vdpa, &pdev->dev, &snet_config_ops, 1, 1, NULL,
-				 false);
+	snet = vdpa_alloc_device(struct snet, vdpa, &pdev->dev, &snet_config_ops,
+				 NULL, 1, 1, NULL, false);
 	if (!snet) {
 		SNET_ERR(pdev, "Failed to allocate a vdpa device\n");
 		ret = -ENOMEM;
@@ -1052,8 +1052,8 @@ static int snet_vdpa_probe_vf(struct pci_dev *pdev)
 	 */
 	snet_reserve_irq_idx(pf_irqs ? pdev_pf : pdev, snet);
 
-	/*set DMA device*/
-	snet->vdpa.dma_dev = &pdev->dev;
+	/* set map metadata */
+	snet->vdpa.vmap.dma_dev = &pdev->dev;
 
 	/* Register VDPA device */
 	ret = vdpa_register_device(&snet->vdpa, snet->cfg->vq_num);
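
The same two-line change repeats across every parent driver above: vdpa_alloc_device() grew a map-ops argument right after the config ops, and the dma_dev assignment moved into the vmap union. For a DMA-API-backed parent the pattern is roughly this (foo is a placeholder driver):

	adapter = vdpa_alloc_device(struct foo_adapter, vdpa, &pdev->dev,
				    &foo_vdpa_ops, NULL, /* no custom map ops */
				    1, 1, NULL, false);
	adapter->vdpa.vmap.dma_dev = &pdev->dev;

Only VDUSE passes a non-NULL ops table; see its hunks further down.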
diff --git a/drivers/vdpa/vdpa.c b/drivers/vdpa/vdpa.c
index 8a372b5..34874be 100644
--- a/drivers/vdpa/vdpa.c
+++ b/drivers/vdpa/vdpa.c
@@ -142,6 +142,7 @@ static void vdpa_release_dev(struct device *d)
  * initialized but before registered.
  * @parent: the parent device
  * @config: the bus operations that are supported by this device
 + * @map: the map operations that are supported by this device
  * @ngroups: number of groups supported by this device
  * @nas: number of address spaces supported by this device
  * @size: size of the parent structure that contains private data
@@ -151,11 +152,12 @@ static void vdpa_release_dev(struct device *d)
  * Driver should use vdpa_alloc_device() wrapper macro instead of
  * using this directly.
  *
- * Return: Returns an error when parent/config/dma_dev is not set or fail to get
+ * Return: Returns an error when parent/config/map is not set or we fail to get
  *	   ida.
  */
 struct vdpa_device *__vdpa_alloc_device(struct device *parent,
 					const struct vdpa_config_ops *config,
+					const struct virtio_map_ops *map,
 					unsigned int ngroups, unsigned int nas,
 					size_t size, const char *name,
 					bool use_va)
@@ -187,6 +189,7 @@ struct vdpa_device *__vdpa_alloc_device(struct device *parent,
 	vdev->dev.release = vdpa_release_dev;
 	vdev->index = err;
 	vdev->config = config;
+	vdev->map = map;
 	vdev->features_valid = false;
 	vdev->use_va = use_va;
 	vdev->ngroups = ngroups;
diff --git a/drivers/vdpa/vdpa_sim/vdpa_sim.c b/drivers/vdpa/vdpa_sim/vdpa_sim.c
index c204fc8..c1c6431 100644
--- a/drivers/vdpa/vdpa_sim/vdpa_sim.c
+++ b/drivers/vdpa/vdpa_sim/vdpa_sim.c
@@ -215,7 +215,7 @@ struct vdpasim *vdpasim_create(struct vdpasim_dev_attr *dev_attr,
 	else
 		ops = &vdpasim_config_ops;
 
-	vdpa = __vdpa_alloc_device(NULL, ops,
+	vdpa = __vdpa_alloc_device(NULL, ops, NULL,
 				   dev_attr->ngroups, dev_attr->nas,
 				   dev_attr->alloc_size,
 				   dev_attr->name, use_va);
@@ -272,7 +272,7 @@ struct vdpasim *vdpasim_create(struct vdpasim_dev_attr *dev_attr,
 		vringh_set_iotlb(&vdpasim->vqs[i].vring, &vdpasim->iommu[0],
 				 &vdpasim->iommu_lock);
 
-	vdpasim->vdpa.dma_dev = dev;
+	vdpasim->vdpa.vmap.dma_dev = dev;
 
 	return vdpasim;
 
diff --git a/drivers/vdpa/vdpa_user/iova_domain.c b/drivers/vdpa/vdpa_user/iova_domain.c
index 58116f8..4352b5c 100644
--- a/drivers/vdpa/vdpa_user/iova_domain.c
+++ b/drivers/vdpa/vdpa_user/iova_domain.c
@@ -103,19 +103,38 @@ void vduse_domain_clear_map(struct vduse_iova_domain *domain,
 static int vduse_domain_map_bounce_page(struct vduse_iova_domain *domain,
 					 u64 iova, u64 size, u64 paddr)
 {
-	struct vduse_bounce_map *map;
+	struct vduse_bounce_map *map, *head_map;
+	struct page *tmp_page;
 	u64 last = iova + size - 1;
 
 	while (iova <= last) {
-		map = &domain->bounce_maps[iova >> PAGE_SHIFT];
+		/*
+		 * When PAGE_SIZE is larger than 4KB, multiple adjacent bounce_maps will
+		 * point to the same memory page of PAGE_SIZE. Since bounce_maps originate
+		 * from IO requests, we may not be able to guarantee that the orig_phys
+		 * values of all IO requests within the same 64KB memory page are contiguous.
+		 * Therefore, we need to store them separately.
+		 *
+		 * Bounce pages are allocated on demand. As a result, it may occur that
+		 * multiple bounce pages corresponding to the same 64KB memory page attempt
+		 * to allocate memory simultaneously, so we use cmpxchg to handle this
+		 * concurrency.
+		 */
+		map = &domain->bounce_maps[iova >> BOUNCE_MAP_SHIFT];
 		if (!map->bounce_page) {
-			map->bounce_page = alloc_page(GFP_ATOMIC);
-			if (!map->bounce_page)
-				return -ENOMEM;
+			head_map = &domain->bounce_maps[(iova & PAGE_MASK) >> BOUNCE_MAP_SHIFT];
+			if (!head_map->bounce_page) {
+				tmp_page = alloc_page(GFP_ATOMIC);
+				if (!tmp_page)
+					return -ENOMEM;
+				if (cmpxchg(&head_map->bounce_page, NULL, tmp_page))
+					__free_page(tmp_page);
+			}
+			map->bounce_page = head_map->bounce_page;
 		}
 		map->orig_phys = paddr;
-		paddr += PAGE_SIZE;
-		iova += PAGE_SIZE;
+		paddr += BOUNCE_MAP_SIZE;
+		iova += BOUNCE_MAP_SIZE;
 	}
 	return 0;
 }
@@ -127,12 +146,17 @@ static void vduse_domain_unmap_bounce_page(struct vduse_iova_domain *domain,
 	u64 last = iova + size - 1;
 
 	while (iova <= last) {
-		map = &domain->bounce_maps[iova >> PAGE_SHIFT];
+		map = &domain->bounce_maps[iova >> BOUNCE_MAP_SHIFT];
 		map->orig_phys = INVALID_PHYS_ADDR;
-		iova += PAGE_SIZE;
+		iova += BOUNCE_MAP_SIZE;
 	}
 }
 
+static unsigned int offset_in_bounce_page(dma_addr_t addr)
+{
+	return (addr & ~BOUNCE_MAP_MASK);
+}
+
 static void do_bounce(phys_addr_t orig, void *addr, size_t size,
 		      enum dma_data_direction dir)
 {
@@ -163,7 +187,7 @@ static void vduse_domain_bounce(struct vduse_iova_domain *domain,
 {
 	struct vduse_bounce_map *map;
 	struct page *page;
-	unsigned int offset;
+	unsigned int offset, head_offset;
 	void *addr;
 	size_t sz;
 
@@ -171,9 +195,10 @@ static void vduse_domain_bounce(struct vduse_iova_domain *domain,
 		return;
 
 	while (size) {
-		map = &domain->bounce_maps[iova >> PAGE_SHIFT];
-		offset = offset_in_page(iova);
-		sz = min_t(size_t, PAGE_SIZE - offset, size);
+		map = &domain->bounce_maps[iova >> BOUNCE_MAP_SHIFT];
+		head_offset = offset_in_page(iova);
+		offset = offset_in_bounce_page(iova);
+		sz = min_t(size_t, BOUNCE_MAP_SIZE - offset, size);
 
 		if (WARN_ON(!map->bounce_page ||
 			    map->orig_phys == INVALID_PHYS_ADDR))
@@ -183,7 +208,7 @@ static void vduse_domain_bounce(struct vduse_iova_domain *domain,
 		       map->user_bounce_page : map->bounce_page;
 
 		addr = kmap_local_page(page);
-		do_bounce(map->orig_phys + offset, addr + offset, sz, dir);
+		do_bounce(map->orig_phys + offset, addr + head_offset, sz, dir);
 		kunmap_local(addr);
 		size -= sz;
 		iova += sz;
@@ -218,7 +243,7 @@ vduse_domain_get_bounce_page(struct vduse_iova_domain *domain, u64 iova)
 	struct page *page = NULL;
 
 	read_lock(&domain->bounce_lock);
-	map = &domain->bounce_maps[iova >> PAGE_SHIFT];
+	map = &domain->bounce_maps[iova >> BOUNCE_MAP_SHIFT];
 	if (domain->user_bounce_pages || !map->bounce_page)
 		goto out;
 
@@ -236,7 +261,7 @@ vduse_domain_free_kernel_bounce_pages(struct vduse_iova_domain *domain)
 	struct vduse_bounce_map *map;
 	unsigned long pfn, bounce_pfns;
 
-	bounce_pfns = domain->bounce_size >> PAGE_SHIFT;
+	bounce_pfns = domain->bounce_size >> BOUNCE_MAP_SHIFT;
 
 	for (pfn = 0; pfn < bounce_pfns; pfn++) {
 		map = &domain->bounce_maps[pfn];
@@ -246,7 +271,8 @@ vduse_domain_free_kernel_bounce_pages(struct vduse_iova_domain *domain)
 		if (!map->bounce_page)
 			continue;
 
-		__free_page(map->bounce_page);
+		if (!((pfn << BOUNCE_MAP_SHIFT) & ~PAGE_MASK))
+			__free_page(map->bounce_page);
 		map->bounce_page = NULL;
 	}
 }
@@ -254,8 +280,12 @@ vduse_domain_free_kernel_bounce_pages(struct vduse_iova_domain *domain)
 int vduse_domain_add_user_bounce_pages(struct vduse_iova_domain *domain,
 				       struct page **pages, int count)
 {
-	struct vduse_bounce_map *map;
-	int i, ret;
+	struct vduse_bounce_map *map, *head_map;
+	int i, j, ret;
+	int inner_pages = PAGE_SIZE / BOUNCE_MAP_SIZE;
+	int bounce_pfns = domain->bounce_size >> BOUNCE_MAP_SHIFT;
+	struct page *head_page = NULL;
+	bool need_copy;
 
 	/* Now we don't support partial mapping */
 	if (count != (domain->bounce_size >> PAGE_SHIFT))
@@ -267,16 +297,23 @@ int vduse_domain_add_user_bounce_pages(struct vduse_iova_domain *domain,
 		goto out;
 
 	for (i = 0; i < count; i++) {
-		map = &domain->bounce_maps[i];
-		if (map->bounce_page) {
+		need_copy = false;
+		head_map = &domain->bounce_maps[(i * inner_pages)];
+		head_page = head_map->bounce_page;
+		for (j = 0; j < inner_pages; j++) {
+			if ((i * inner_pages + j) >= bounce_pfns)
+				break;
+			map = &domain->bounce_maps[(i * inner_pages + j)];
 			/* Copy kernel page to user page if it's in use */
-			if (map->orig_phys != INVALID_PHYS_ADDR)
-				memcpy_to_page(pages[i], 0,
-					       page_address(map->bounce_page),
-					       PAGE_SIZE);
+			if ((head_page) && (map->orig_phys != INVALID_PHYS_ADDR))
+				need_copy = true;
+			map->user_bounce_page = pages[i];
 		}
-		map->user_bounce_page = pages[i];
 		get_page(pages[i]);
+		if ((head_page) && (need_copy))
+			memcpy_to_page(pages[i], 0,
+				       page_address(head_page),
+				       PAGE_SIZE);
 	}
 	domain->user_bounce_pages = true;
 	ret = 0;
@@ -288,8 +325,12 @@ int vduse_domain_add_user_bounce_pages(struct vduse_iova_domain *domain,
 
 void vduse_domain_remove_user_bounce_pages(struct vduse_iova_domain *domain)
 {
-	struct vduse_bounce_map *map;
-	unsigned long i, count;
+	struct vduse_bounce_map *map, *head_map;
+	unsigned long i, j, count;
+	int inner_pages = PAGE_SIZE / BOUNCE_MAP_SIZE;
+	int bounce_pfns = domain->bounce_size >> BOUNCE_MAP_SHIFT;
+	struct page *head_page = NULL;
+	bool need_copy;
 
 	write_lock(&domain->bounce_lock);
 	if (!domain->user_bounce_pages)
@@ -297,20 +338,27 @@ void vduse_domain_remove_user_bounce_pages(struct vduse_iova_domain *domain)
 
 	count = domain->bounce_size >> PAGE_SHIFT;
 	for (i = 0; i < count; i++) {
-		struct page *page = NULL;
-
-		map = &domain->bounce_maps[i];
-		if (WARN_ON(!map->user_bounce_page))
+		need_copy = false;
+		head_map = &domain->bounce_maps[(i * inner_pages)];
+		if (WARN_ON(!head_map->user_bounce_page))
 			continue;
+		head_page = head_map->user_bounce_page;
 
-		/* Copy user page to kernel page if it's in use */
-		if (map->orig_phys != INVALID_PHYS_ADDR) {
-			page = map->bounce_page;
-			memcpy_from_page(page_address(page),
-					 map->user_bounce_page, 0, PAGE_SIZE);
+		for (j = 0; j < inner_pages; j++) {
+			if ((i * inner_pages + j) >= bounce_pfns)
+				break;
+			map = &domain->bounce_maps[(i * inner_pages + j)];
+			if (WARN_ON(!map->user_bounce_page))
+				continue;
+			/* Copy user page to kernel page if it's in use */
+			if ((map->orig_phys != INVALID_PHYS_ADDR) && (head_map->bounce_page))
+				need_copy = true;
+			map->user_bounce_page = NULL;
 		}
-		put_page(map->user_bounce_page);
-		map->user_bounce_page = NULL;
+		if (need_copy)
+			memcpy_from_page(page_address(head_map->bounce_page),
+					 head_page, 0, PAGE_SIZE);
+		put_page(head_page);
 	}
 	domain->user_bounce_pages = false;
 out:
@@ -447,7 +495,7 @@ void vduse_domain_unmap_page(struct vduse_iova_domain *domain,
 
 void *vduse_domain_alloc_coherent(struct vduse_iova_domain *domain,
 				  size_t size, dma_addr_t *dma_addr,
-				  gfp_t flag, unsigned long attrs)
+				  gfp_t flag)
 {
 	struct iova_domain *iovad = &domain->consistent_iovad;
 	unsigned long limit = domain->iova_limit;
@@ -581,7 +629,7 @@ vduse_domain_create(unsigned long iova_limit, size_t bounce_size)
 	unsigned long pfn, bounce_pfns;
 	int ret;
 
-	bounce_pfns = PAGE_ALIGN(bounce_size) >> PAGE_SHIFT;
+	bounce_pfns = PAGE_ALIGN(bounce_size) >> BOUNCE_MAP_SHIFT;
 	if (iova_limit <= bounce_size)
 		return NULL;
 
@@ -613,7 +661,7 @@ vduse_domain_create(unsigned long iova_limit, size_t bounce_size)
 	rwlock_init(&domain->bounce_lock);
 	spin_lock_init(&domain->iotlb_lock);
 	init_iova_domain(&domain->stream_iovad,
-			PAGE_SIZE, IOVA_START_PFN);
+			BOUNCE_MAP_SIZE, IOVA_START_PFN);
 	ret = iova_domain_init_rcaches(&domain->stream_iovad);
 	if (ret)
 		goto err_iovad_stream;
diff --git a/drivers/vdpa/vdpa_user/iova_domain.h b/drivers/vdpa/vdpa_user/iova_domain.h
index 7f3f092..775cad523 100644
--- a/drivers/vdpa/vdpa_user/iova_domain.h
+++ b/drivers/vdpa/vdpa_user/iova_domain.h
@@ -19,6 +19,11 @@
 
 #define INVALID_PHYS_ADDR (~(phys_addr_t)0)
 
+#define BOUNCE_MAP_SHIFT	12
+#define BOUNCE_MAP_SIZE	(1 << BOUNCE_MAP_SHIFT)
+#define BOUNCE_MAP_MASK	(~(BOUNCE_MAP_SIZE - 1))
+#define BOUNCE_MAP_ALIGN(addr)	(((addr) + BOUNCE_MAP_SIZE - 1) & ~(BOUNCE_MAP_SIZE - 1))
+
 struct vduse_bounce_map {
 	struct page *bounce_page;
 	struct page *user_bounce_page;
@@ -64,7 +69,7 @@ void vduse_domain_unmap_page(struct vduse_iova_domain *domain,
 
 void *vduse_domain_alloc_coherent(struct vduse_iova_domain *domain,
 				  size_t size, dma_addr_t *dma_addr,
-				  gfp_t flag, unsigned long attrs);
+				  gfp_t flag);
 
 void vduse_domain_free_coherent(struct vduse_iova_domain *domain, size_t size,
 				void *vaddr, dma_addr_t dma_addr,
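
With the bounce granularity pinned at 4KB, the index math no longer depends on PAGE_SIZE. On a 64KB-page kernel, for instance, one struct page backs PAGE_SIZE / BOUNCE_MAP_SIZE = 16 bounce maps, and an IOVA decomposes as in this worked sketch (values are illustrative):

	u64 iova = 0x12345;
	u64 idx  = iova >> BOUNCE_MAP_SHIFT;	/* 0x12: which 4KB bounce map */
	u64 off  = iova & ~BOUNCE_MAP_MASK;	/* 0x345: offset inside it */
	/* first bounce map sharing the same 64KB backing page: */
	u64 head = (iova & PAGE_MASK) >> BOUNCE_MAP_SHIFT;	/* 0x10 */

That head index is what vduse_domain_map_bounce_page() uses to allocate the shared page exactly once, with cmpxchg() resolving concurrent allocators.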
diff --git a/drivers/vdpa/vdpa_user/vduse_dev.c b/drivers/vdpa/vdpa_user/vduse_dev.c
index 04620bb..e7bced0 100644
--- a/drivers/vdpa/vdpa_user/vduse_dev.c
+++ b/drivers/vdpa/vdpa_user/vduse_dev.c
@@ -814,59 +814,53 @@ static const struct vdpa_config_ops vduse_vdpa_config_ops = {
 	.free			= vduse_vdpa_free,
 };
 
-static void vduse_dev_sync_single_for_device(struct device *dev,
+static void vduse_dev_sync_single_for_device(union virtio_map token,
 					     dma_addr_t dma_addr, size_t size,
 					     enum dma_data_direction dir)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
 
 	vduse_domain_sync_single_for_device(domain, dma_addr, size, dir);
 }
 
-static void vduse_dev_sync_single_for_cpu(struct device *dev,
+static void vduse_dev_sync_single_for_cpu(union virtio_map token,
 					     dma_addr_t dma_addr, size_t size,
 					     enum dma_data_direction dir)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
 
 	vduse_domain_sync_single_for_cpu(domain, dma_addr, size, dir);
 }
 
-static dma_addr_t vduse_dev_map_page(struct device *dev, struct page *page,
+static dma_addr_t vduse_dev_map_page(union virtio_map token, struct page *page,
 				     unsigned long offset, size_t size,
 				     enum dma_data_direction dir,
 				     unsigned long attrs)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
 
 	return vduse_domain_map_page(domain, page, offset, size, dir, attrs);
 }
 
-static void vduse_dev_unmap_page(struct device *dev, dma_addr_t dma_addr,
-				size_t size, enum dma_data_direction dir,
-				unsigned long attrs)
+static void vduse_dev_unmap_page(union virtio_map token, dma_addr_t dma_addr,
+				 size_t size, enum dma_data_direction dir,
+				 unsigned long attrs)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
 
 	return vduse_domain_unmap_page(domain, dma_addr, size, dir, attrs);
 }
 
-static void *vduse_dev_alloc_coherent(struct device *dev, size_t size,
-					dma_addr_t *dma_addr, gfp_t flag,
-					unsigned long attrs)
+static void *vduse_dev_alloc_coherent(union virtio_map token, size_t size,
+				      dma_addr_t *dma_addr, gfp_t flag)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
 	unsigned long iova;
 	void *addr;
 
 	*dma_addr = DMA_MAPPING_ERROR;
 	addr = vduse_domain_alloc_coherent(domain, size,
-				(dma_addr_t *)&iova, flag, attrs);
+					   (dma_addr_t *)&iova, flag);
 	if (!addr)
 		return NULL;
 
@@ -875,31 +869,45 @@ static void *vduse_dev_alloc_coherent(struct device *dev, size_t size,
 	return addr;
 }
 
-static void vduse_dev_free_coherent(struct device *dev, size_t size,
-					void *vaddr, dma_addr_t dma_addr,
-					unsigned long attrs)
+static void vduse_dev_free_coherent(union virtio_map token, size_t size,
+				    void *vaddr, dma_addr_t dma_addr,
+				    unsigned long attrs)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
 
 	vduse_domain_free_coherent(domain, size, vaddr, dma_addr, attrs);
 }
 
-static size_t vduse_dev_max_mapping_size(struct device *dev)
+static bool vduse_dev_need_sync(union virtio_map token, dma_addr_t dma_addr)
 {
-	struct vduse_dev *vdev = dev_to_vduse(dev);
-	struct vduse_iova_domain *domain = vdev->domain;
+	struct vduse_iova_domain *domain = token.iova_domain;
+
+	return dma_addr < domain->bounce_size;
+}
+
+static int vduse_dev_mapping_error(union virtio_map token, dma_addr_t dma_addr)
+{
+	if (unlikely(dma_addr == DMA_MAPPING_ERROR))
+		return -ENOMEM;
+	return 0;
+}
+
+static size_t vduse_dev_max_mapping_size(union virtio_map token)
+{
+	struct vduse_iova_domain *domain = token.iova_domain;
 
 	return domain->bounce_size;
 }
 
-static const struct dma_map_ops vduse_dev_dma_ops = {
+static const struct virtio_map_ops vduse_map_ops = {
 	.sync_single_for_device = vduse_dev_sync_single_for_device,
 	.sync_single_for_cpu = vduse_dev_sync_single_for_cpu,
 	.map_page = vduse_dev_map_page,
 	.unmap_page = vduse_dev_unmap_page,
 	.alloc = vduse_dev_alloc_coherent,
 	.free = vduse_dev_free_coherent,
+	.need_sync = vduse_dev_need_sync,
+	.mapping_error = vduse_dev_mapping_error,
 	.max_mapping_size = vduse_dev_max_mapping_size,
 };
 
@@ -2003,26 +2011,18 @@ static struct vduse_mgmt_dev *vduse_mgmt;
 static int vduse_dev_init_vdpa(struct vduse_dev *dev, const char *name)
 {
 	struct vduse_vdpa *vdev;
-	int ret;
 
 	if (dev->vdev)
 		return -EEXIST;
 
 	vdev = vdpa_alloc_device(struct vduse_vdpa, vdpa, dev->dev,
-				 &vduse_vdpa_config_ops, 1, 1, name, true);
+				 &vduse_vdpa_config_ops, &vduse_map_ops,
+				 1, 1, name, true);
 	if (IS_ERR(vdev))
 		return PTR_ERR(vdev);
 
 	dev->vdev = vdev;
 	vdev->dev = dev;
-	vdev->vdpa.dev.dma_mask = &vdev->vdpa.dev.coherent_dma_mask;
-	ret = dma_set_mask_and_coherent(&vdev->vdpa.dev, DMA_BIT_MASK(64));
-	if (ret) {
-		put_device(&vdev->vdpa.dev);
-		return ret;
-	}
-	set_dma_ops(&vdev->vdpa.dev, &vduse_dev_dma_ops);
-	vdev->vdpa.dma_dev = &vdev->vdpa.dev;
 	vdev->vdpa.mdev = &vduse_mgmt->mgmt_dev;
 
 	return 0;
@@ -2055,6 +2055,7 @@ static int vdpa_dev_add(struct vdpa_mgmt_dev *mdev, const char *name,
 		return -ENOMEM;
 	}
 
+	dev->vdev->vdpa.vmap.iova_domain = dev->domain;
 	ret = _vdpa_register_device(&dev->vdev->vdpa, dev->vq_num);
 	if (ret) {
 		put_device(&dev->vdev->vdpa.dev);
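
The net effect of the VDUSE conversion: instead of overriding dev->dma_ops, the driver publishes a virtio_map_ops table and stashes its iova_domain in the vmap token, which the core hands back on every callback. The round trip, condensed from the hunks above and below:

	/* core (virtio_ring.c) dispatch, sketched */
	if (vdev->map)
		addr = vdev->map->map_page(vq->map, page, off, size, dir, attrs);
	else
		addr = dma_map_page_attrs(vq->map.dma_dev, page, off, size, dir, attrs);

	/* VDUSE side: no struct device needed any more */
	struct vduse_iova_domain *domain = token.iova_domain;

This is also why the Kconfig hunk could drop the ARCH_HAS_DMA_OPS dependency.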
diff --git a/drivers/vdpa/virtio_pci/vp_vdpa.c b/drivers/vdpa/virtio_pci/vp_vdpa.c
index 8787407f..17a19a7 100644
--- a/drivers/vdpa/virtio_pci/vp_vdpa.c
+++ b/drivers/vdpa/virtio_pci/vp_vdpa.c
@@ -511,7 +511,8 @@ static int vp_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name,
 	int ret, i;
 
 	vp_vdpa = vdpa_alloc_device(struct vp_vdpa, vdpa,
-				    dev, &vp_vdpa_ops, 1, 1, name, false);
+				    dev, &vp_vdpa_ops, NULL,
+				    1, 1, name, false);
 
 	if (IS_ERR(vp_vdpa)) {
 		dev_err(dev, "vp_vdpa: Failed to allocate vDPA structure\n");
@@ -520,7 +521,7 @@ static int vp_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name,
 
 	vp_vdpa_mgtdev->vp_vdpa = vp_vdpa;
 
-	vp_vdpa->vdpa.dma_dev = &pdev->dev;
+	vp_vdpa->vdpa.vmap.dma_dev = &pdev->dev;
 	vp_vdpa->queues = vp_modern_get_num_queues(mdev);
 	vp_vdpa->mdev = mdev;
 
diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c
index af1e1fd..05a481e 100644
--- a/drivers/vhost/vdpa.c
+++ b/drivers/vhost/vdpa.c
@@ -1318,7 +1318,8 @@ static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
 {
 	struct vdpa_device *vdpa = v->vdpa;
 	const struct vdpa_config_ops *ops = vdpa->config;
-	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
+	union virtio_map map = vdpa_get_map(vdpa);
+	struct device *dma_dev = map.dma_dev;
 	int ret;
 
 	/* Device want to do DMA by itself */
@@ -1353,7 +1354,8 @@ static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
 {
 	struct vdpa_device *vdpa = v->vdpa;
-	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
+	union virtio_map map = vdpa_get_map(vdpa);
+	struct device *dma_dev = map.dma_dev;
 
 	if (v->domain) {
 		iommu_detach_device(v->domain, dma_dev);
diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c
index 1778eff..925858c 100644
--- a/drivers/vhost/vringh.c
+++ b/drivers/vhost/vringh.c
@@ -1115,6 +1115,7 @@ static inline int copy_from_iotlb(const struct vringh *vrh, void *dst,
 		struct iov_iter iter;
 		u64 translated;
 		int ret;
+		size_t size;
 
 		ret = iotlb_translate(vrh, (u64)(uintptr_t)src,
 				      len - total_translated, &translated,
@@ -1132,9 +1133,9 @@ static inline int copy_from_iotlb(const struct vringh *vrh, void *dst,
 				      translated);
 		}
 
-		ret = copy_from_iter(dst, translated, &iter);
-		if (ret < 0)
-			return ret;
+		size = copy_from_iter(dst, translated, &iter);
+		if (size != translated)
+			return -EFAULT;
 
 		src += translated;
 		dst += translated;
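
The vringh fix corrects a broken error check: copy_from_iter() returns the number of bytes copied as a size_t and can never be negative, so the old ret < 0 test was dead code. A short copy has to be detected by comparing against the requested length, the same pattern the sibling fix applies to copy_to_iter():

	size_t copied;

	copied = copy_from_iter(dst, translated, &iter);
	if (copied != translated)
		return -EFAULT;	/* partial copy, e.g. a faulting user page */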
diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index 7f3fd72..1b93d8c 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -205,7 +205,7 @@ static int virtballoon_free_page_report(struct page_reporting_dev_info *pr_dev_i
 	unsigned int unused, err;
 
 	/* We should always be able to add these buffers to an empty queue. */
-	err = virtqueue_add_inbuf(vq, sg, nents, vb, GFP_NOWAIT | __GFP_NOWARN);
+	err = virtqueue_add_inbuf(vq, sg, nents, vb, GFP_NOWAIT);
 
 	/*
 	 * In the extremely unlikely case that something has occurred and we
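
Context for the balloon one-liner: GFP_NOWAIT already carries __GFP_NOWARN in current kernels, roughly:

	/* include/linux/gfp_types.h, abridged */
	#define GFP_NOWAIT	(__GFP_KSWAPD_RECLAIM | __GFP_NOWARN)

so the explicit flag was redundant and the caller's own error handling suffices.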
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index c147145..7b62052 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -166,7 +166,7 @@ struct vring_virtqueue {
 	bool packed_ring;
 
 	/* Is DMA API used? */
-	bool use_dma_api;
+	bool use_map_api;
 
 	/* Can we use weak barriers? */
 	bool weak_barriers;
@@ -210,8 +210,7 @@ struct vring_virtqueue {
 	/* DMA, allocation, and size information */
 	bool we_own_ring;
 
-	/* Device used for doing DMA */
-	struct device *dma_dev;
+	union virtio_map map;
 
 #ifdef DEBUG
 	/* They're supposed to lock for us. */
@@ -268,7 +267,7 @@ static bool virtqueue_use_indirect(const struct vring_virtqueue *vq,
  * unconditionally on data path.
  */
 
-static bool vring_use_dma_api(const struct virtio_device *vdev)
+static bool vring_use_map_api(const struct virtio_device *vdev)
 {
 	if (!virtio_has_dma_quirk(vdev))
 		return true;
@@ -291,33 +290,39 @@ static bool vring_use_dma_api(const struct virtio_device *vdev)
 static bool vring_need_unmap_buffer(const struct vring_virtqueue *vring,
 				    const struct vring_desc_extra *extra)
 {
-	return vring->use_dma_api && (extra->addr != DMA_MAPPING_ERROR);
+	return vring->use_map_api && (extra->addr != DMA_MAPPING_ERROR);
 }
 
 size_t virtio_max_dma_size(const struct virtio_device *vdev)
 {
 	size_t max_segment_size = SIZE_MAX;
 
-	if (vring_use_dma_api(vdev))
-		max_segment_size = dma_max_mapping_size(vdev->dev.parent);
+	if (vring_use_map_api(vdev)) {
+		if (vdev->map) {
+			max_segment_size =
+				vdev->map->max_mapping_size(vdev->vmap);
+		} else
+			max_segment_size =
+				dma_max_mapping_size(vdev->dev.parent);
+	}
 
 	return max_segment_size;
 }
 EXPORT_SYMBOL_GPL(virtio_max_dma_size);
 
 static void *vring_alloc_queue(struct virtio_device *vdev, size_t size,
-			       dma_addr_t *dma_handle, gfp_t flag,
-			       struct device *dma_dev)
+			       dma_addr_t *map_handle, gfp_t flag,
+			       union virtio_map map)
 {
-	if (vring_use_dma_api(vdev)) {
-		return dma_alloc_coherent(dma_dev, size,
-					  dma_handle, flag);
+	if (vring_use_map_api(vdev)) {
+		return virtqueue_map_alloc_coherent(vdev, map, size,
+						    map_handle, flag);
 	} else {
 		void *queue = alloc_pages_exact(PAGE_ALIGN(size), flag);
 
 		if (queue) {
 			phys_addr_t phys_addr = virt_to_phys(queue);
-			*dma_handle = (dma_addr_t)phys_addr;
+			*map_handle = (dma_addr_t)phys_addr;
 
 			/*
 			 * Sanity check: make sure we didn't truncate
@@ -330,7 +335,7 @@ static void *vring_alloc_queue(struct virtio_device *vdev, size_t size,
 			 * warning and abort if we end up with an
 			 * unrepresentable address.
 			 */
-			if (WARN_ON_ONCE(*dma_handle != phys_addr)) {
+			if (WARN_ON_ONCE(*map_handle != phys_addr)) {
 				free_pages_exact(queue, PAGE_ALIGN(size));
 				return NULL;
 			}
@@ -340,11 +345,12 @@ static void *vring_alloc_queue(struct virtio_device *vdev, size_t size,
 }
 
 static void vring_free_queue(struct virtio_device *vdev, size_t size,
-			     void *queue, dma_addr_t dma_handle,
-			     struct device *dma_dev)
+			     void *queue, dma_addr_t map_handle,
+			     union virtio_map map)
 {
-	if (vring_use_dma_api(vdev))
-		dma_free_coherent(dma_dev, size, queue, dma_handle);
+	if (vring_use_map_api(vdev))
+		virtqueue_map_free_coherent(vdev, map, size,
+					    queue, map_handle);
 	else
 		free_pages_exact(queue, PAGE_ALIGN(size));
 }
@@ -356,7 +362,21 @@ static void vring_free_queue(struct virtio_device *vdev, size_t size,
  */
 static struct device *vring_dma_dev(const struct vring_virtqueue *vq)
 {
-	return vq->dma_dev;
+	return vq->map.dma_dev;
+}
+
+static int vring_mapping_error(const struct vring_virtqueue *vq,
+			       dma_addr_t addr)
+{
+	struct virtio_device *vdev = vq->vq.vdev;
+
+	if (!vq->use_map_api)
+		return 0;
+
+	if (vdev->map)
+		return vdev->map->mapping_error(vq->map, addr);
+	else
+		return dma_mapping_error(vring_dma_dev(vq), addr);
 }
 
 /* Map one sg entry. */
@@ -372,7 +392,7 @@ static int vring_map_one_sg(const struct vring_virtqueue *vq, struct scatterlist
 
 	*len = sg->length;
 
-	if (!vq->use_dma_api) {
+	if (!vq->use_map_api) {
 		/*
 		 * If DMA is not used, KMSAN doesn't know that the scatterlist
 		 * is initialized by the hardware. Explicitly check/unpoison it
@@ -388,11 +408,11 @@ static int vring_map_one_sg(const struct vring_virtqueue *vq, struct scatterlist
 	 * the way it expects (we don't guarantee that the scatterlist
 	 * will exist for the lifetime of the mapping).
 	 */
-	*addr = dma_map_page(vring_dma_dev(vq),
-			    sg_page(sg), sg->offset, sg->length,
-			    direction);
+	*addr = virtqueue_map_page_attrs(&vq->vq, sg_page(sg),
+					 sg->offset, sg->length,
+					 direction, 0);
 
-	if (dma_mapping_error(vring_dma_dev(vq), *addr))
+	if (vring_mapping_error(vq, *addr))
 		return -ENOMEM;
 
 	return 0;
@@ -402,20 +422,11 @@ static dma_addr_t vring_map_single(const struct vring_virtqueue *vq,
 				   void *cpu_addr, size_t size,
 				   enum dma_data_direction direction)
 {
-	if (!vq->use_dma_api)
+	if (!vq->use_map_api)
 		return (dma_addr_t)virt_to_phys(cpu_addr);
 
-	return dma_map_single(vring_dma_dev(vq),
-			      cpu_addr, size, direction);
-}
-
-static int vring_mapping_error(const struct vring_virtqueue *vq,
-			       dma_addr_t addr)
-{
-	if (!vq->use_dma_api)
-		return 0;
-
-	return dma_mapping_error(vring_dma_dev(vq), addr);
+	return virtqueue_map_single_attrs(&vq->vq, cpu_addr,
+					  size, direction, 0);
 }
 
 static void virtqueue_init(struct vring_virtqueue *vq, u32 num)
@@ -449,24 +460,17 @@ static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq,
 	flags = extra->flags;
 
 	if (flags & VRING_DESC_F_INDIRECT) {
-		if (!vq->use_dma_api)
+		if (!vq->use_map_api)
 			goto out;
+	} else if (!vring_need_unmap_buffer(vq, extra))
+		goto out;
 
-		dma_unmap_single(vring_dma_dev(vq),
-				 extra->addr,
-				 extra->len,
-				 (flags & VRING_DESC_F_WRITE) ?
-				 DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	} else {
-		if (!vring_need_unmap_buffer(vq, extra))
-			goto out;
-
-		dma_unmap_page(vring_dma_dev(vq),
-			       extra->addr,
-			       extra->len,
-			       (flags & VRING_DESC_F_WRITE) ?
-			       DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	}
+	virtqueue_unmap_page_attrs(&vq->vq,
+				   extra->addr,
+				   extra->len,
+				   (flags & VRING_DESC_F_WRITE) ?
+				   DMA_FROM_DEVICE : DMA_TO_DEVICE,
+				   0);
 
 out:
 	return extra->next;
@@ -790,7 +794,7 @@ static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
 
 		extra = (struct vring_desc_extra *)&indir_desc[num];
 
-		if (vq->use_dma_api) {
+		if (vq->use_map_api) {
 			for (j = 0; j < num; j++)
 				vring_unmap_one_split(vq, &extra[j]);
 		}
@@ -1064,12 +1068,13 @@ static int vring_alloc_state_extra_split(struct vring_virtqueue_split *vring_spl
 }
 
 static void vring_free_split(struct vring_virtqueue_split *vring_split,
-			     struct virtio_device *vdev, struct device *dma_dev)
+			     struct virtio_device *vdev,
+			     union virtio_map map)
 {
 	vring_free_queue(vdev, vring_split->queue_size_in_bytes,
 			 vring_split->vring.desc,
 			 vring_split->queue_dma_addr,
-			 dma_dev);
+			 map);
 
 	kfree(vring_split->desc_state);
 	kfree(vring_split->desc_extra);
@@ -1080,7 +1085,7 @@ static int vring_alloc_queue_split(struct vring_virtqueue_split *vring_split,
 				   u32 num,
 				   unsigned int vring_align,
 				   bool may_reduce_num,
-				   struct device *dma_dev)
+				   union virtio_map map)
 {
 	void *queue = NULL;
 	dma_addr_t dma_addr;
@@ -1096,7 +1101,7 @@ static int vring_alloc_queue_split(struct vring_virtqueue_split *vring_split,
 		queue = vring_alloc_queue(vdev, vring_size(num, vring_align),
 					  &dma_addr,
 					  GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO,
-					  dma_dev);
+					  map);
 		if (queue)
 			break;
 		if (!may_reduce_num)
@@ -1110,7 +1115,7 @@ static int vring_alloc_queue_split(struct vring_virtqueue_split *vring_split,
 		/* Try to get a single page. You are my only hope! */
 		queue = vring_alloc_queue(vdev, vring_size(num, vring_align),
 					  &dma_addr, GFP_KERNEL | __GFP_ZERO,
-					  dma_dev);
+					  map);
 	}
 	if (!queue)
 		return -ENOMEM;
@@ -1134,7 +1139,7 @@ static struct virtqueue *__vring_new_virtqueue_split(unsigned int index,
 					       bool (*notify)(struct virtqueue *),
 					       void (*callback)(struct virtqueue *),
 					       const char *name,
-					       struct device *dma_dev)
+					       union virtio_map map)
 {
 	struct vring_virtqueue *vq;
 	int err;
@@ -1157,8 +1162,8 @@ static struct virtqueue *__vring_new_virtqueue_split(unsigned int index,
 #else
 	vq->broken = false;
 #endif
-	vq->dma_dev = dma_dev;
-	vq->use_dma_api = vring_use_dma_api(vdev);
+	vq->map = map;
+	vq->use_map_api = vring_use_map_api(vdev);
 
 	vq->indirect = virtio_has_feature(vdev, VIRTIO_RING_F_INDIRECT_DESC) &&
 		!context;
@@ -1195,21 +1200,21 @@ static struct virtqueue *vring_create_virtqueue_split(
 	bool (*notify)(struct virtqueue *),
 	void (*callback)(struct virtqueue *),
 	const char *name,
-	struct device *dma_dev)
+	union virtio_map map)
 {
 	struct vring_virtqueue_split vring_split = {};
 	struct virtqueue *vq;
 	int err;
 
 	err = vring_alloc_queue_split(&vring_split, vdev, num, vring_align,
-				      may_reduce_num, dma_dev);
+				      may_reduce_num, map);
 	if (err)
 		return NULL;
 
 	vq = __vring_new_virtqueue_split(index, &vring_split, vdev, weak_barriers,
-				   context, notify, callback, name, dma_dev);
+				   context, notify, callback, name, map);
 	if (!vq) {
-		vring_free_split(&vring_split, vdev, dma_dev);
+		vring_free_split(&vring_split, vdev, map);
 		return NULL;
 	}
 
@@ -1228,7 +1233,7 @@ static int virtqueue_resize_split(struct virtqueue *_vq, u32 num)
 	err = vring_alloc_queue_split(&vring_split, vdev, num,
 				      vq->split.vring_align,
 				      vq->split.may_reduce_num,
-				      vring_dma_dev(vq));
+				      vq->map);
 	if (err)
 		goto err;
 
@@ -1246,7 +1251,7 @@ static int virtqueue_resize_split(struct virtqueue *_vq, u32 num)
 	return 0;
 
 err_state_extra:
-	vring_free_split(&vring_split, vdev, vring_dma_dev(vq));
+	vring_free_split(&vring_split, vdev, vq->map);
 err:
 	virtqueue_reinit_split(vq);
 	return -ENOMEM;
@@ -1274,22 +1279,16 @@ static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
 	flags = extra->flags;
 
 	if (flags & VRING_DESC_F_INDIRECT) {
-		if (!vq->use_dma_api)
+		if (!vq->use_map_api)
 			return;
+	} else if (!vring_need_unmap_buffer(vq, extra))
+		return;
 
-		dma_unmap_single(vring_dma_dev(vq),
-				 extra->addr, extra->len,
-				 (flags & VRING_DESC_F_WRITE) ?
-				 DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	} else {
-		if (!vring_need_unmap_buffer(vq, extra))
-			return;
-
-		dma_unmap_page(vring_dma_dev(vq),
-			       extra->addr, extra->len,
-			       (flags & VRING_DESC_F_WRITE) ?
-			       DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	}
+	virtqueue_unmap_page_attrs(&vq->vq,
+				   extra->addr, extra->len,
+				   (flags & VRING_DESC_F_WRITE) ?
+				   DMA_FROM_DEVICE : DMA_TO_DEVICE,
+				   0);
 }
 
 static struct vring_packed_desc *alloc_indirect_packed(unsigned int total_sg,
@@ -1366,7 +1365,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
 			desc[i].addr = cpu_to_le64(addr);
 			desc[i].len = cpu_to_le32(len);
 
-			if (unlikely(vq->use_dma_api)) {
+			if (unlikely(vq->use_map_api)) {
 				extra[i].addr = premapped ? DMA_MAPPING_ERROR : addr;
 				extra[i].len = len;
 				extra[i].flags = n < out_sgs ?  0 : VRING_DESC_F_WRITE;
@@ -1388,7 +1387,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
 				sizeof(struct vring_packed_desc));
 	vq->packed.vring.desc[head].id = cpu_to_le16(id);
 
-	if (vq->use_dma_api) {
+	if (vq->use_map_api) {
 		vq->packed.desc_extra[id].addr = addr;
 		vq->packed.desc_extra[id].len = total_sg *
 				sizeof(struct vring_packed_desc);
@@ -1530,7 +1529,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
 			desc[i].len = cpu_to_le32(len);
 			desc[i].id = cpu_to_le16(id);
 
-			if (unlikely(vq->use_dma_api)) {
+			if (unlikely(vq->use_map_api)) {
 				vq->packed.desc_extra[curr].addr = premapped ?
 					DMA_MAPPING_ERROR : addr;
 				vq->packed.desc_extra[curr].len = len;
@@ -1665,7 +1664,7 @@ static void detach_buf_packed(struct vring_virtqueue *vq,
 	vq->free_head = id;
 	vq->vq.num_free += state->num;
 
-	if (unlikely(vq->use_dma_api)) {
+	if (unlikely(vq->use_map_api)) {
 		curr = id;
 		for (i = 0; i < state->num; i++) {
 			vring_unmap_extra_packed(vq,
@@ -1683,7 +1682,7 @@ static void detach_buf_packed(struct vring_virtqueue *vq,
 		if (!desc)
 			return;
 
-		if (vq->use_dma_api) {
+		if (vq->use_map_api) {
 			len = vq->packed.desc_extra[id].len;
 			num = len / sizeof(struct vring_packed_desc);
 
@@ -1962,25 +1961,25 @@ static struct vring_desc_extra *vring_alloc_desc_extra(unsigned int num)
 
 static void vring_free_packed(struct vring_virtqueue_packed *vring_packed,
 			      struct virtio_device *vdev,
-			      struct device *dma_dev)
+			      union virtio_map map)
 {
 	if (vring_packed->vring.desc)
 		vring_free_queue(vdev, vring_packed->ring_size_in_bytes,
 				 vring_packed->vring.desc,
 				 vring_packed->ring_dma_addr,
-				 dma_dev);
+				 map);
 
 	if (vring_packed->vring.driver)
 		vring_free_queue(vdev, vring_packed->event_size_in_bytes,
 				 vring_packed->vring.driver,
 				 vring_packed->driver_event_dma_addr,
-				 dma_dev);
+				 map);
 
 	if (vring_packed->vring.device)
 		vring_free_queue(vdev, vring_packed->event_size_in_bytes,
 				 vring_packed->vring.device,
 				 vring_packed->device_event_dma_addr,
-				 dma_dev);
+				 map);
 
 	kfree(vring_packed->desc_state);
 	kfree(vring_packed->desc_extra);
@@ -1988,7 +1987,7 @@ static void vring_free_packed(struct vring_virtqueue_packed *vring_packed,
 
 static int vring_alloc_queue_packed(struct vring_virtqueue_packed *vring_packed,
 				    struct virtio_device *vdev,
-				    u32 num, struct device *dma_dev)
+				    u32 num, union virtio_map map)
 {
 	struct vring_packed_desc *ring;
 	struct vring_packed_desc_event *driver, *device;
@@ -2000,7 +1999,7 @@ static int vring_alloc_queue_packed(struct vring_virtqueue_packed *vring_packed,
 	ring = vring_alloc_queue(vdev, ring_size_in_bytes,
 				 &ring_dma_addr,
 				 GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO,
-				 dma_dev);
+				 map);
 	if (!ring)
 		goto err;
 
@@ -2013,7 +2012,7 @@ static int vring_alloc_queue_packed(struct vring_virtqueue_packed *vring_packed,
 	driver = vring_alloc_queue(vdev, event_size_in_bytes,
 				   &driver_event_dma_addr,
 				   GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO,
-				   dma_dev);
+				   map);
 	if (!driver)
 		goto err;
 
@@ -2024,7 +2023,7 @@ static int vring_alloc_queue_packed(struct vring_virtqueue_packed *vring_packed,
 	device = vring_alloc_queue(vdev, event_size_in_bytes,
 				   &device_event_dma_addr,
 				   GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO,
-				   dma_dev);
+				   map);
 	if (!device)
 		goto err;
 
@@ -2036,7 +2035,7 @@ static int vring_alloc_queue_packed(struct vring_virtqueue_packed *vring_packed,
 	return 0;
 
 err:
-	vring_free_packed(vring_packed, vdev, dma_dev);
+	vring_free_packed(vring_packed, vdev, map);
 	return -ENOMEM;
 }
 
@@ -2112,7 +2111,7 @@ static struct virtqueue *__vring_new_virtqueue_packed(unsigned int index,
 					       bool (*notify)(struct virtqueue *),
 					       void (*callback)(struct virtqueue *),
 					       const char *name,
-					       struct device *dma_dev)
+					       union virtio_map map)
 {
 	struct vring_virtqueue *vq;
 	int err;
@@ -2135,8 +2134,8 @@ static struct virtqueue *__vring_new_virtqueue_packed(unsigned int index,
 	vq->broken = false;
 #endif
 	vq->packed_ring = true;
-	vq->dma_dev = dma_dev;
-	vq->use_dma_api = vring_use_dma_api(vdev);
+	vq->map = map;
+	vq->use_map_api = vring_use_map_api(vdev);
 
 	vq->indirect = virtio_has_feature(vdev, VIRTIO_RING_F_INDIRECT_DESC) &&
 		!context;
@@ -2173,18 +2172,18 @@ static struct virtqueue *vring_create_virtqueue_packed(
 	bool (*notify)(struct virtqueue *),
 	void (*callback)(struct virtqueue *),
 	const char *name,
-	struct device *dma_dev)
+	union virtio_map map)
 {
 	struct vring_virtqueue_packed vring_packed = {};
 	struct virtqueue *vq;
 
-	if (vring_alloc_queue_packed(&vring_packed, vdev, num, dma_dev))
+	if (vring_alloc_queue_packed(&vring_packed, vdev, num, map))
 		return NULL;
 
 	vq = __vring_new_virtqueue_packed(index, &vring_packed, vdev, weak_barriers,
-					context, notify, callback, name, dma_dev);
+					context, notify, callback, name, map);
 	if (!vq) {
-		vring_free_packed(&vring_packed, vdev, dma_dev);
+		vring_free_packed(&vring_packed, vdev, map);
 		return NULL;
 	}
 
@@ -2200,7 +2199,7 @@ static int virtqueue_resize_packed(struct virtqueue *_vq, u32 num)
 	struct virtio_device *vdev = _vq->vdev;
 	int err;
 
-	if (vring_alloc_queue_packed(&vring_packed, vdev, num, vring_dma_dev(vq)))
+	if (vring_alloc_queue_packed(&vring_packed, vdev, num, vq->map))
 		goto err_ring;
 
 	err = vring_alloc_state_extra_packed(&vring_packed);
@@ -2217,7 +2216,7 @@ static int virtqueue_resize_packed(struct virtqueue *_vq, u32 num)
 	return 0;
 
 err_state_extra:
-	vring_free_packed(&vring_packed, vdev, vring_dma_dev(vq));
+	vring_free_packed(&vring_packed, vdev, vq->map);
 err_ring:
 	virtqueue_reinit_packed(vq);
 	return -ENOMEM;
@@ -2448,8 +2447,8 @@ struct device *virtqueue_dma_dev(struct virtqueue *_vq)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 
-	if (vq->use_dma_api)
-		return vring_dma_dev(vq);
+	if (vq->use_map_api && !_vq->vdev->map)
+		return vq->map.dma_dev;
 	else
 		return NULL;
 }
@@ -2734,19 +2733,20 @@ struct virtqueue *vring_create_virtqueue(
 	void (*callback)(struct virtqueue *),
 	const char *name)
 {
+	union virtio_map map = {.dma_dev = vdev->dev.parent};
 
 	if (virtio_has_feature(vdev, VIRTIO_F_RING_PACKED))
 		return vring_create_virtqueue_packed(index, num, vring_align,
 				vdev, weak_barriers, may_reduce_num,
-				context, notify, callback, name, vdev->dev.parent);
+				context, notify, callback, name, map);
 
 	return vring_create_virtqueue_split(index, num, vring_align,
 			vdev, weak_barriers, may_reduce_num,
-			context, notify, callback, name, vdev->dev.parent);
+			context, notify, callback, name, map);
 }
 EXPORT_SYMBOL_GPL(vring_create_virtqueue);
 
-struct virtqueue *vring_create_virtqueue_dma(
+struct virtqueue *vring_create_virtqueue_map(
 	unsigned int index,
 	unsigned int num,
 	unsigned int vring_align,
@@ -2757,19 +2757,19 @@ struct virtqueue *vring_create_virtqueue_dma(
 	bool (*notify)(struct virtqueue *),
 	void (*callback)(struct virtqueue *),
 	const char *name,
-	struct device *dma_dev)
+	union virtio_map map)
 {
 
 	if (virtio_has_feature(vdev, VIRTIO_F_RING_PACKED))
 		return vring_create_virtqueue_packed(index, num, vring_align,
 				vdev, weak_barriers, may_reduce_num,
-				context, notify, callback, name, dma_dev);
+				context, notify, callback, name, map);
 
 	return vring_create_virtqueue_split(index, num, vring_align,
 			vdev, weak_barriers, may_reduce_num,
-			context, notify, callback, name, dma_dev);
+			context, notify, callback, name, map);
 }
-EXPORT_SYMBOL_GPL(vring_create_virtqueue_dma);
+EXPORT_SYMBOL_GPL(vring_create_virtqueue_map);
 
 /**
  * virtqueue_resize - resize the vring of vq
@@ -2880,6 +2880,7 @@ struct virtqueue *vring_new_virtqueue(unsigned int index,
 				      const char *name)
 {
 	struct vring_virtqueue_split vring_split = {};
+	union virtio_map map = {.dma_dev = vdev->dev.parent};
 
 	if (virtio_has_feature(vdev, VIRTIO_F_RING_PACKED)) {
 		struct vring_virtqueue_packed vring_packed = {};
@@ -2889,13 +2890,13 @@ struct virtqueue *vring_new_virtqueue(unsigned int index,
 		return __vring_new_virtqueue_packed(index, &vring_packed,
 						    vdev, weak_barriers,
 						    context, notify, callback,
-						    name, vdev->dev.parent);
+						    name, map);
 	}
 
 	vring_init(&vring_split.vring, num, pages, vring_align);
 	return __vring_new_virtqueue_split(index, &vring_split, vdev, weak_barriers,
 				     context, notify, callback, name,
-				     vdev->dev.parent);
+				     map);
 }
 EXPORT_SYMBOL_GPL(vring_new_virtqueue);
 
@@ -2909,19 +2910,19 @@ static void vring_free(struct virtqueue *_vq)
 					 vq->packed.ring_size_in_bytes,
 					 vq->packed.vring.desc,
 					 vq->packed.ring_dma_addr,
-					 vring_dma_dev(vq));
+					 vq->map);
 
 			vring_free_queue(vq->vq.vdev,
 					 vq->packed.event_size_in_bytes,
 					 vq->packed.vring.driver,
 					 vq->packed.driver_event_dma_addr,
-					 vring_dma_dev(vq));
+					 vq->map);
 
 			vring_free_queue(vq->vq.vdev,
 					 vq->packed.event_size_in_bytes,
 					 vq->packed.vring.device,
 					 vq->packed.device_event_dma_addr,
-					 vring_dma_dev(vq));
+					 vq->map);
 
 			kfree(vq->packed.desc_state);
 			kfree(vq->packed.desc_extra);
@@ -2930,7 +2931,7 @@ static void vring_free(struct virtqueue *_vq)
 					 vq->split.queue_size_in_bytes,
 					 vq->split.vring.desc,
 					 vq->split.queue_dma_addr,
-					 vring_dma_dev(vq));
+					 vq->map);
 		}
 	}
 	if (!vq->packed_ring) {
@@ -3137,7 +3138,108 @@ const struct vring *virtqueue_get_vring(const struct virtqueue *vq)
 EXPORT_SYMBOL_GPL(virtqueue_get_vring);
 
 /**
- * virtqueue_dma_map_single_attrs - map DMA for _vq
+ * virtqueue_map_alloc_coherent - alloc coherent mapping
+ * @vdev: the virtio device we are talking to
+ * @map: metadata for performing mapping
+ * @size: the size of the buffer
+ * @map_handle: the pointer to the mapped address
+ * @gfp: allocation flag (GFP_XXX)
+ *
+ * return virtual address or NULL on error
+ */
+void *virtqueue_map_alloc_coherent(struct virtio_device *vdev,
+				   union virtio_map map,
+				   size_t size, dma_addr_t *map_handle,
+				   gfp_t gfp)
+{
+	if (vdev->map)
+		return vdev->map->alloc(map, size,
+					map_handle, gfp);
+	else
+		return dma_alloc_coherent(map.dma_dev, size,
+					  map_handle, gfp);
+}
+EXPORT_SYMBOL_GPL(virtqueue_map_alloc_coherent);
+
+/**
+ * virtqueue_map_free_coherent - free coherent mapping
+ * @vdev: the virtio device we are talking to
+ * @map: metadata for performing mapping
+ * @size: the size of the buffer
+ * @vaddr: the virtual address of the buffer
+ * @map_handle: the mapped address that needs to be freed
+ *
+ */
+void virtqueue_map_free_coherent(struct virtio_device *vdev,
+				 union virtio_map map, size_t size, void *vaddr,
+				 dma_addr_t map_handle)
+{
+	if (vdev->map)
+		vdev->map->free(map, size, vaddr,
+				map_handle, 0);
+	else
+		dma_free_coherent(map.dma_dev, size, vaddr, map_handle);
+}
+EXPORT_SYMBOL_GPL(virtqueue_map_free_coherent);
+
+/**
+ * virtqueue_map_page_attrs - map a page to the device
+ * @_vq: the virtqueue we are talking to
+ * @page: the page that will be mapped by the device
+ * @offset: the offset in the page for a buffer
+ * @size: the buffer size
+ * @dir: mapping direction
+ * @attrs: mapping attributes
+ *
+ * Returns the mapped address. The caller should check it with virtqueue_map_mapping_error().
+ */
+dma_addr_t virtqueue_map_page_attrs(const struct virtqueue *_vq,
+				    struct page *page,
+				    unsigned long offset,
+				    size_t size,
+				    enum dma_data_direction dir,
+				    unsigned long attrs)
+{
+	const struct vring_virtqueue *vq = to_vvq(_vq);
+	struct virtio_device *vdev = _vq->vdev;
+
+	if (vdev->map)
+		return vdev->map->map_page(vq->map,
+					   page, offset, size,
+					   dir, attrs);
+
+	return dma_map_page_attrs(vring_dma_dev(vq),
+				  page, offset, size,
+				  dir, attrs);
+}
+EXPORT_SYMBOL_GPL(virtqueue_map_page_attrs);
+
+/**
+ * virtqueue_unmap_page_attrs - unmap a page from the device
+ * @_vq: the virtqueue we are talking to
+ * @map_handle: the mapped address
+ * @size: the buffer size
+ * @dir: mapping direction
+ * @attrs: unmapping attributes
+ */
+void virtqueue_unmap_page_attrs(const struct virtqueue *_vq,
+				dma_addr_t map_handle,
+				size_t size, enum dma_data_direction dir,
+				unsigned long attrs)
+{
+	const struct vring_virtqueue *vq = to_vvq(_vq);
+	struct virtio_device *vdev = _vq->vdev;
+
+	if (vdev->map)
+		vdev->map->unmap_page(vq->map,
+				      map_handle, size, dir, attrs);
+	else
+		dma_unmap_page_attrs(vring_dma_dev(vq), map_handle,
+				     size, dir, attrs);
+}
+EXPORT_SYMBOL_GPL(virtqueue_unmap_page_attrs);
+
+/**
+ * virtqueue_map_single_attrs - map DMA for _vq
  * @_vq: the struct virtqueue we're talking about.
  * @ptr: the pointer of the buffer to do dma
  * @size: the size of the buffer to do dma
@@ -3147,139 +3249,158 @@ EXPORT_SYMBOL_GPL(virtqueue_get_vring);
  * The caller calls this to do dma mapping in advance. The DMA address can be
  * passed to this _vq when it is in pre-mapped mode.
  *
- * return DMA address. Caller should check that by virtqueue_dma_mapping_error().
+ * Returns the mapped address. Caller should check it with virtqueue_map_mapping_error().
  */
-dma_addr_t virtqueue_dma_map_single_attrs(struct virtqueue *_vq, void *ptr,
-					  size_t size,
-					  enum dma_data_direction dir,
-					  unsigned long attrs)
+dma_addr_t virtqueue_map_single_attrs(const struct virtqueue *_vq, void *ptr,
+				      size_t size,
+				      enum dma_data_direction dir,
+				      unsigned long attrs)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
+	const struct vring_virtqueue *vq = to_vvq(_vq);
 
-	if (!vq->use_dma_api) {
+	if (!vq->use_map_api) {
 		kmsan_handle_dma(virt_to_phys(ptr), size, dir);
 		return (dma_addr_t)virt_to_phys(ptr);
 	}
 
-	return dma_map_single_attrs(vring_dma_dev(vq), ptr, size, dir, attrs);
+	/* DMA must never operate on areas that might be remapped. */
+	if (dev_WARN_ONCE(&_vq->vdev->dev, is_vmalloc_addr(ptr),
+			  "rejecting DMA map of vmalloc memory\n"))
+		return DMA_MAPPING_ERROR;
+
+	return virtqueue_map_page_attrs(&vq->vq, virt_to_page(ptr),
+					offset_in_page(ptr), size, dir, attrs);
 }
-EXPORT_SYMBOL_GPL(virtqueue_dma_map_single_attrs);
+EXPORT_SYMBOL_GPL(virtqueue_map_single_attrs);
 
 /**
- * virtqueue_dma_unmap_single_attrs - unmap DMA for _vq
+ * virtqueue_unmap_single_attrs - unmap a mapping for _vq
  * @_vq: the struct virtqueue we're talking about.
  * @addr: the dma address to unmap
  * @size: the size of the buffer
  * @dir: DMA direction
  * @attrs: DMA Attrs
  *
- * Unmap the address that is mapped by the virtqueue_dma_map_* APIs.
+ * Unmap the address that is mapped by the virtqueue_map_* APIs.
  *
  */
-void virtqueue_dma_unmap_single_attrs(struct virtqueue *_vq, dma_addr_t addr,
-				      size_t size, enum dma_data_direction dir,
-				      unsigned long attrs)
+void virtqueue_unmap_single_attrs(const struct virtqueue *_vq,
+				  dma_addr_t addr,
+				  size_t size, enum dma_data_direction dir,
+				  unsigned long attrs)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
+	const struct vring_virtqueue *vq = to_vvq(_vq);
 
-	if (!vq->use_dma_api)
+	if (!vq->use_map_api)
 		return;
 
-	dma_unmap_single_attrs(vring_dma_dev(vq), addr, size, dir, attrs);
+	virtqueue_unmap_page_attrs(_vq, addr, size, dir, attrs);
 }
-EXPORT_SYMBOL_GPL(virtqueue_dma_unmap_single_attrs);
+EXPORT_SYMBOL_GPL(virtqueue_unmap_single_attrs);
 
 /**
- * virtqueue_dma_mapping_error - check dma address
+ * virtqueue_map_mapping_error - check the mapped address
  * @_vq: the struct virtqueue we're talking about.
  * @addr: DMA address
  *
 * Returns 0 if the mapped address is valid, non-zero otherwise.
  */
-int virtqueue_dma_mapping_error(struct virtqueue *_vq, dma_addr_t addr)
+int virtqueue_map_mapping_error(const struct virtqueue *_vq, dma_addr_t addr)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
+	const struct vring_virtqueue *vq = to_vvq(_vq);
 
-	if (!vq->use_dma_api)
-		return 0;
-
-	return dma_mapping_error(vring_dma_dev(vq), addr);
+	return vring_mapping_error(vq, addr);
 }
-EXPORT_SYMBOL_GPL(virtqueue_dma_mapping_error);
+EXPORT_SYMBOL_GPL(virtqueue_map_mapping_error);
 
 /**
- * virtqueue_dma_need_sync - check a dma address needs sync
+ * virtqueue_map_need_sync - check if a mapped address needs sync
  * @_vq: the struct virtqueue we're talking about.
  * @addr: DMA address
  *
- * Check if the dma address mapped by the virtqueue_dma_map_* APIs needs to be
+ * Check if the dma address mapped by the virtqueue_map_* APIs needs to be
  * synchronized
  *
  * return bool
  */
-bool virtqueue_dma_need_sync(struct virtqueue *_vq, dma_addr_t addr)
+bool virtqueue_map_need_sync(const struct virtqueue *_vq, dma_addr_t addr)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
+	const struct vring_virtqueue *vq = to_vvq(_vq);
+	struct virtio_device *vdev = _vq->vdev;
 
-	if (!vq->use_dma_api)
+	if (!vq->use_map_api)
 		return false;
 
-	return dma_need_sync(vring_dma_dev(vq), addr);
+	if (vdev->map)
+		return vdev->map->need_sync(vq->map, addr);
+	else
+		return dma_need_sync(vring_dma_dev(vq), addr);
 }
-EXPORT_SYMBOL_GPL(virtqueue_dma_need_sync);
+EXPORT_SYMBOL_GPL(virtqueue_map_need_sync);
 
 /**
- * virtqueue_dma_sync_single_range_for_cpu - dma sync for cpu
+ * virtqueue_map_sync_single_range_for_cpu - sync a mapped range for the CPU
  * @_vq: the struct virtqueue we're talking about.
  * @addr: DMA address
  * @offset: DMA address offset
  * @size: buf size for sync
  * @dir: DMA direction
  *
- * Before calling this function, use virtqueue_dma_need_sync() to confirm that
+ * Before calling this function, use virtqueue_map_need_sync() to confirm that
  * the DMA address really needs to be synchronized
  *
  */
-void virtqueue_dma_sync_single_range_for_cpu(struct virtqueue *_vq,
+void virtqueue_map_sync_single_range_for_cpu(const struct virtqueue *_vq,
 					     dma_addr_t addr,
 					     unsigned long offset, size_t size,
 					     enum dma_data_direction dir)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
-	struct device *dev = vring_dma_dev(vq);
+	const struct vring_virtqueue *vq = to_vvq(_vq);
+	struct virtio_device *vdev = _vq->vdev;
 
-	if (!vq->use_dma_api)
+	if (!vq->use_map_api)
 		return;
 
-	dma_sync_single_range_for_cpu(dev, addr, offset, size, dir);
+	if (vdev->map)
+		vdev->map->sync_single_for_cpu(vq->map,
+					       addr + offset, size, dir);
+	else
+		dma_sync_single_range_for_cpu(vring_dma_dev(vq),
+					      addr, offset, size, dir);
 }
-EXPORT_SYMBOL_GPL(virtqueue_dma_sync_single_range_for_cpu);
+EXPORT_SYMBOL_GPL(virtqueue_map_sync_single_range_for_cpu);
 
 /**
- * virtqueue_dma_sync_single_range_for_device - dma sync for device
+ * virtqueue_map_sync_single_range_for_device - sync a mapped range for the device
  * @_vq: the struct virtqueue we're talking about.
  * @addr: DMA address
  * @offset: DMA address offset
  * @size: buf size for sync
  * @dir: DMA direction
  *
- * Before calling this function, use virtqueue_dma_need_sync() to confirm that
+ * Before calling this function, use virtqueue_map_need_sync() to confirm that
  * the DMA address really needs to be synchronized
  */
-void virtqueue_dma_sync_single_range_for_device(struct virtqueue *_vq,
+void virtqueue_map_sync_single_range_for_device(const struct virtqueue *_vq,
 						dma_addr_t addr,
 						unsigned long offset, size_t size,
 						enum dma_data_direction dir)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
-	struct device *dev = vring_dma_dev(vq);
+	const struct vring_virtqueue *vq = to_vvq(_vq);
+	struct virtio_device *vdev = _vq->vdev;
 
-	if (!vq->use_dma_api)
+	if (!vq->use_map_api)
 		return;
 
-	dma_sync_single_range_for_device(dev, addr, offset, size, dir);
+	if (vdev->map)
+		vdev->map->sync_single_for_device(vq->map,
+						  addr + offset,
+						  size, dir);
+	else
+		dma_sync_single_range_for_device(vring_dma_dev(vq), addr,
+						 offset, size, dir);
 }
-EXPORT_SYMBOL_GPL(virtqueue_dma_sync_single_range_for_device);
+EXPORT_SYMBOL_GPL(virtqueue_map_sync_single_range_for_device);
 
 MODULE_DESCRIPTION("Virtio ring implementation");
 MODULE_LICENSE("GPL");
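
For reference, here is a minimal sketch of how a driver could consume the
renamed single-buffer helpers; the two wrappers below and their names are
illustrative only, and assume a virtqueue operating in pre-mapped mode:

  #include <linux/virtio.h>
  #include <linux/dma-mapping.h>

  /* Illustrative: map a driver-owned receive buffer for a premapped vq. */
  static dma_addr_t example_map_buf(struct virtqueue *vq, void *buf,
                                    size_t len, bool *need_sync)
  {
          dma_addr_t addr;

          addr = virtqueue_map_single_attrs(vq, buf, len, DMA_FROM_DEVICE, 0);
          if (virtqueue_map_mapping_error(vq, addr))
                  return DMA_MAPPING_ERROR;

          *need_sync = virtqueue_map_need_sync(vq, addr);
          return addr;
  }

  /* Illustrative: sync (if needed) and unmap the buffer again. */
  static void example_unmap_buf(struct virtqueue *vq, dma_addr_t addr,
                                size_t len, bool need_sync)
  {
          if (need_sync)
                  virtqueue_map_sync_single_range_for_cpu(vq, addr, 0, len,
                                                          DMA_FROM_DEVICE);
          virtqueue_unmap_single_attrs(vq, addr, len, DMA_FROM_DEVICE, 0);
  }
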
diff --git a/drivers/virtio/virtio_vdpa.c b/drivers/virtio/virtio_vdpa.c
index 657b07a..f9a2904 100644
--- a/drivers/virtio/virtio_vdpa.c
+++ b/drivers/virtio/virtio_vdpa.c
@@ -133,12 +133,12 @@ virtio_vdpa_setup_vq(struct virtio_device *vdev, unsigned int index,
 		     const char *name, bool ctx)
 {
 	struct vdpa_device *vdpa = vd_get_vdpa(vdev);
-	struct device *dma_dev;
 	const struct vdpa_config_ops *ops = vdpa->config;
 	bool (*notify)(struct virtqueue *vq) = virtio_vdpa_notify;
 	struct vdpa_callback cb;
 	struct virtqueue *vq;
 	u64 desc_addr, driver_addr, device_addr;
+	union virtio_map map = {0};
 	/* Assume split virtqueue, switch to packed if necessary */
 	struct vdpa_vq_state state = {0};
 	u32 align, max_num, min_num = 1;
@@ -176,23 +176,27 @@ virtio_vdpa_setup_vq(struct virtio_device *vdev, unsigned int index,
 	if (ops->get_vq_num_min)
 		min_num = ops->get_vq_num_min(vdpa);
 
-	may_reduce_num = (max_num == min_num) ? false : true;
+	may_reduce_num = (max_num != min_num);
 
 	/* Create the vring */
 	align = ops->get_vq_align(vdpa);
 
-	if (ops->get_vq_dma_dev)
-		dma_dev = ops->get_vq_dma_dev(vdpa, index);
+	if (ops->get_vq_map)
+		map = ops->get_vq_map(vdpa, index);
 	else
-		dma_dev = vdpa_get_dma_dev(vdpa);
-	vq = vring_create_virtqueue_dma(index, max_num, align, vdev,
+		map = vdpa_get_map(vdpa);
+
+	vq = vring_create_virtqueue_map(index, max_num, align, vdev,
 					true, may_reduce_num, ctx,
-					notify, callback, name, dma_dev);
+					notify, callback, name, map);
 	if (!vq) {
 		err = -ENOMEM;
 		goto error_new_virtqueue;
 	}
 
+	if (index == 0)
+		vdev->vmap = map;
+
 	vq->num_max = max_num;
 
 	/* Setup virtqueue callback */
@@ -462,9 +466,11 @@ static int virtio_vdpa_probe(struct vdpa_device *vdpa)
 	if (!vd_dev)
 		return -ENOMEM;
 
-	vd_dev->vdev.dev.parent = vdpa_get_dma_dev(vdpa);
+	vd_dev->vdev.dev.parent = vdpa->map ? &vdpa->dev :
+				  vdpa_get_map(vdpa).dma_dev;
 	vd_dev->vdev.dev.release = virtio_vdpa_release_dev;
 	vd_dev->vdev.config = &virtio_vdpa_config_ops;
+	vd_dev->vdev.map = vdpa->map;
 	vd_dev->vdpa = vdpa;
 
 	vd_dev->vdev.id.device = ops->get_device_id(vdpa);
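
A hypothetical vdpa parent driver wiring up the new optional callback might
look like the sketch below; my_vdpa_get_vq_map and my_vdpa_config_ops are
made-up names, and a real driver must fill in the remaining mandatory ops:

  #include <linux/vdpa.h>

  /* Hypothetical parent: every vq performs DMA through the parent device. */
  static union virtio_map my_vdpa_get_vq_map(struct vdpa_device *vdev, u16 idx)
  {
          union virtio_map map = {0};

          map.dma_dev = vdev->dev.parent;
          return map;
  }

  static const struct vdpa_config_ops my_vdpa_config_ops = {
          /* ... mandatory config ops elided for brevity ... */
          .get_vq_map = my_vdpa_get_vq_map,
  };
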
diff --git a/include/linux/vdpa.h b/include/linux/vdpa.h
index 2e7a30f..4cf21d6 100644
--- a/include/linux/vdpa.h
+++ b/include/linux/vdpa.h
@@ -5,6 +5,7 @@
 #include <linux/kernel.h>
 #include <linux/device.h>
 #include <linux/interrupt.h>
+#include <linux/virtio.h>
 #include <linux/vhost_iotlb.h>
 #include <linux/virtio_net.h>
 #include <linux/virtio_blk.h>
@@ -70,11 +71,12 @@ struct vdpa_mgmt_dev;
 /**
  * struct vdpa_device - representation of a vDPA device
  * @dev: underlying device
- * @dma_dev: the actual device that is performing DMA
+ * @vmap: the metadata passed to upper layer to be used for mapping
  * @driver_override: driver name to force a match; do not set directly,
  *                   because core frees it; use driver_set_override() to
  *                   set or clear it.
  * @config: the configuration ops for this device.
+ * @map: the map ops for this device
  * @cf_lock: Protects get and set access to configuration layout.
  * @index: device index
  * @features_valid: were features initialized? for legacy guests
@@ -87,9 +89,10 @@ struct vdpa_mgmt_dev;
  */
 struct vdpa_device {
 	struct device dev;
-	struct device *dma_dev;
+	union virtio_map vmap;
 	const char *driver_override;
 	const struct vdpa_config_ops *config;
+	const struct virtio_map_ops *map;
 	struct rw_semaphore cf_lock; /* Protects get/set config */
 	unsigned int index;
 	bool features_valid;
@@ -352,11 +355,11 @@ struct vdpa_map_file {
  *				@vdev: vdpa device
  *				@asid: address space identifier
  *				Returns integer: success (0) or error (< 0)
- * @get_vq_dma_dev:		Get the dma device for a specific
+ * @get_vq_map:		Get the map metadata for a specific
  *				virtqueue (optional)
  *				@vdev: vdpa device
  *				@idx: virtqueue index
- *				Returns pointer to structure device or error (NULL)
+ *				Returns the map metadata (union virtio_map) or error (NULL)
  * @bind_mm:			Bind the device to a specific address space
  *				so the vDPA framework can use VA when this
  *				callback is implemented. (optional)
@@ -436,7 +439,7 @@ struct vdpa_config_ops {
 	int (*reset_map)(struct vdpa_device *vdev, unsigned int asid);
 	int (*set_group_asid)(struct vdpa_device *vdev, unsigned int group,
 			      unsigned int asid);
-	struct device *(*get_vq_dma_dev)(struct vdpa_device *vdev, u16 idx);
+	union virtio_map (*get_vq_map)(struct vdpa_device *vdev, u16 idx);
 	int (*bind_mm)(struct vdpa_device *vdev, struct mm_struct *mm);
 	void (*unbind_mm)(struct vdpa_device *vdev);
 
@@ -446,6 +449,7 @@ struct vdpa_config_ops {
 
 struct vdpa_device *__vdpa_alloc_device(struct device *parent,
 					const struct vdpa_config_ops *config,
+					const struct virtio_map_ops *map,
 					unsigned int ngroups, unsigned int nas,
 					size_t size, const char *name,
 					bool use_va);
@@ -457,6 +461,7 @@ struct vdpa_device *__vdpa_alloc_device(struct device *parent,
  * @member: the name of struct vdpa_device within the @dev_struct
  * @parent: the parent device
  * @config: the bus operations that is supported by this device
+ * @map: the map operations that are supported by this device
  * @ngroups: the number of virtqueue groups supported by this device
  * @nas: the number of address spaces
  * @name: name of the vdpa device
@@ -464,10 +469,10 @@ struct vdpa_device *__vdpa_alloc_device(struct device *parent,
  *
  * Return allocated data structure or ERR_PTR upon error
  */
-#define vdpa_alloc_device(dev_struct, member, parent, config, ngroups, nas, \
-			  name, use_va) \
+#define vdpa_alloc_device(dev_struct, member, parent, config, map, \
+			  ngroups, nas, name, use_va)		   \
 			  container_of((__vdpa_alloc_device( \
-				       parent, config, ngroups, nas, \
+				       parent, config, map, ngroups, nas, \
 				       (sizeof(dev_struct) + \
 				       BUILD_BUG_ON_ZERO(offsetof( \
 				       dev_struct, member))), name, use_va)), \
@@ -520,9 +525,9 @@ static inline void vdpa_set_drvdata(struct vdpa_device *vdev, void *data)
 	dev_set_drvdata(&vdev->dev, data);
 }
 
-static inline struct device *vdpa_get_dma_dev(struct vdpa_device *vdev)
+static inline union virtio_map vdpa_get_map(struct vdpa_device *vdev)
 {
-	return vdev->dma_dev;
+	return vdev->vmap;
 }
 
 static inline int vdpa_reset(struct vdpa_device *vdev, u32 flags)
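
To show the extra macro argument in context, here is a sketch of an
allocation call reusing the hypothetical my_vdpa_config_ops from above;
passing NULL map ops keeps the DMA-API fallback, and all names are
placeholders:

  #include <linux/vdpa.h>
  #include <linux/err.h>

  struct my_vdpa {
          struct vdpa_device vdpa;        /* must be the first member */
          /* driver-private state ... */
  };

  static int my_vdpa_probe_sketch(struct device *parent)
  {
          struct my_vdpa *my;

          /* NULL map ops: the virtio core falls back to the DMA API. */
          my = vdpa_alloc_device(struct my_vdpa, vdpa, parent,
                                 &my_vdpa_config_ops, NULL,
                                 1, 1, "my_vdpa", false);
          if (IS_ERR(my))
                  return PTR_ERR(my);

          return 0;
  }
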
diff --git a/include/linux/virtio.h b/include/linux/virtio.h
index db31fc6..96c6612 100644
--- a/include/linux/virtio.h
+++ b/include/linux/virtio.h
@@ -41,6 +41,15 @@ struct virtqueue {
 	void *priv;
 };
 
+struct vduse_iova_domain;
+
+union virtio_map {
+	/* Device that performs DMA */
+	struct device *dma_dev;
+	/* VDUSE specific mapping data */
+	struct vduse_iova_domain *iova_domain;
+};
+
 int virtqueue_add_outbuf(struct virtqueue *vq,
 			 struct scatterlist sg[], unsigned int num,
 			 void *data,
@@ -161,9 +170,11 @@ struct virtio_device {
 	struct virtio_device_id id;
 	const struct virtio_config_ops *config;
 	const struct vringh_config_ops *vringh_config;
+	const struct virtio_map_ops *map;
 	struct list_head vqs;
 	VIRTIO_DECLARE_FEATURES(features);
 	void *priv;
+	union virtio_map vmap;
 #ifdef CONFIG_VIRTIO_DEBUG
 	struct dentry *debugfs_dir;
 	u64 debugfs_filter_features[VIRTIO_FEATURES_DWORDS];
@@ -262,18 +273,41 @@ void unregister_virtio_driver(struct virtio_driver *drv);
 	module_driver(__virtio_driver, register_virtio_driver, \
 			unregister_virtio_driver)
 
-dma_addr_t virtqueue_dma_map_single_attrs(struct virtqueue *_vq, void *ptr, size_t size,
+
+void *virtqueue_map_alloc_coherent(struct virtio_device *vdev,
+				   union virtio_map mapping_token,
+				   size_t size, dma_addr_t *dma_handle,
+				   gfp_t gfp);
+
+void virtqueue_map_free_coherent(struct virtio_device *vdev,
+				 union virtio_map mapping_token,
+				 size_t size, void *vaddr,
+				 dma_addr_t dma_handle);
+
+dma_addr_t virtqueue_map_page_attrs(const struct virtqueue *_vq,
+				    struct page *page,
+				    unsigned long offset,
+				    size_t size,
+				    enum dma_data_direction dir,
+				    unsigned long attrs);
+
+void virtqueue_unmap_page_attrs(const struct virtqueue *_vq,
+				dma_addr_t dma_handle,
+				size_t size, enum dma_data_direction dir,
+				unsigned long attrs);
+
+dma_addr_t virtqueue_map_single_attrs(const struct virtqueue *_vq, void *ptr, size_t size,
 					  enum dma_data_direction dir, unsigned long attrs);
-void virtqueue_dma_unmap_single_attrs(struct virtqueue *_vq, dma_addr_t addr,
+void virtqueue_unmap_single_attrs(const struct virtqueue *_vq, dma_addr_t addr,
 				      size_t size, enum dma_data_direction dir,
 				      unsigned long attrs);
-int virtqueue_dma_mapping_error(struct virtqueue *_vq, dma_addr_t addr);
+int virtqueue_map_mapping_error(const struct virtqueue *_vq, dma_addr_t addr);
 
-bool virtqueue_dma_need_sync(struct virtqueue *_vq, dma_addr_t addr);
-void virtqueue_dma_sync_single_range_for_cpu(struct virtqueue *_vq, dma_addr_t addr,
+bool virtqueue_map_need_sync(const struct virtqueue *_vq, dma_addr_t addr);
+void virtqueue_map_sync_single_range_for_cpu(const struct virtqueue *_vq, dma_addr_t addr,
 					     unsigned long offset, size_t size,
 					     enum dma_data_direction dir);
-void virtqueue_dma_sync_single_range_for_device(struct virtqueue *_vq, dma_addr_t addr,
+void virtqueue_map_sync_single_range_for_device(const struct virtqueue *_vq, dma_addr_t addr,
 						unsigned long offset, size_t size,
 						enum dma_data_direction dir);
 
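
As a sketch of the new container union in action, the helper below allocates
and frees a coherent region through the map abstraction; the function name
and its caller-supplied arguments are assumptions for illustration:

  #include <linux/virtio.h>

  /* Illustrative: the calls dispatch to vdev->map->alloc()/free() when the
   * transport provides map ops, otherwise to dma_{alloc,free}_coherent()
   * on map.dma_dev.
   */
  static int example_coherent_cycle(struct virtio_device *vdev,
                                    union virtio_map map, size_t size)
  {
          dma_addr_t handle;
          void *queue;

          queue = virtqueue_map_alloc_coherent(vdev, map, size, &handle,
                                               GFP_KERNEL);
          if (!queue)
                  return -ENOMEM;

          /* ... hand (queue, handle) to the device ... */

          virtqueue_map_free_coherent(vdev, map, size, queue, handle);
          return 0;
  }
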
diff --git a/include/linux/virtio_config.h b/include/linux/virtio_config.h
index 7427b79..16001e9 100644
--- a/include/linux/virtio_config.h
+++ b/include/linux/virtio_config.h
@@ -139,6 +139,78 @@ struct virtio_config_ops {
 	int (*enable_vq_after_reset)(struct virtqueue *vq);
 };
 
+/**
+ * struct virtio_map_ops - operations for mapping buffers for a virtio device
+ * Note: a transport that has its own mapping logic must
+ * implement all of the operations
+ * @map_page: map a buffer to the device
+ *      map: metadata for performing mapping
+ *      page: the page that will be mapped by the device
+ *      offset: the offset in the page for a buffer
+ *      size: the buffer size
+ *      dir: mapping direction
+ *      attrs: mapping attributes
+ *      Returns: the mapped address
+ * @unmap_page: unmap a buffer from the device
+ *      map: device specific mapping map
+ *      map_handle: the mapped address
+ *      size: the buffer size
+ *      dir: mapping direction
+ *      attrs: unmapping attributes
+ * @sync_single_for_cpu: sync a single buffer from device to cpu
+ *      map: metadata for performing mapping
+ *      map_handle: the mapping address to sync
+ *      size: the size of the buffer
+ *      dir: synchronization direction
+ * @sync_single_for_device: sync a single buffer from cpu to device
+ *      map: metadata for performing mapping
+ *      map_handle: the mapping address to sync
+ *      size: the size of the buffer
+ *      dir: synchronization direction
+ * @alloc: alloc a coherent buffer mapping
+ *      map: metadata for performing mapping
+ *      size: the size of the buffer
+ *      map_handle: the pointer to the mapped address
+ *      gfp: allocation flag (GFP_XXX)
+ *      Returns: virtual address of the allocated buffer
+ * @free: free a coherent buffer mapping
+ *      map: metadata for performing mapping
+ *      size: the size of the buffer
+ *      vaddr: virtual address of the buffer
+ *      map_handle: the mapped address to free
+ *      attrs: unmapping attributes
+ * @need_sync: check whether the buffer needs synchronization
+ *      map: metadata for performing mapping
+ *      map_handle: the mapped address
+ *      Returns: whether the buffer needs synchronization
+ * @mapping_error: check whether the mapped address is an error
+ *      map: metadata for performing mapping
+ *      map_handle: the mapped address to check
+ *      Returns: 0 if the mapped address is valid, non-zero otherwise
+ * @max_mapping_size: get the maximum buffer size that can be mapped
+ *      map: metadata for performing mapping
+ *      Returns: the maximum buffer size that can be mapped
+ */
+struct virtio_map_ops {
+	dma_addr_t (*map_page)(union virtio_map map, struct page *page,
+			       unsigned long offset, size_t size,
+			       enum dma_data_direction dir, unsigned long attrs);
+	void (*unmap_page)(union virtio_map map, dma_addr_t map_handle,
+			   size_t size, enum dma_data_direction dir,
+			   unsigned long attrs);
+	void (*sync_single_for_cpu)(union virtio_map map, dma_addr_t map_handle,
+				    size_t size, enum dma_data_direction dir);
+	void (*sync_single_for_device)(union virtio_map map,
+				       dma_addr_t map_handle, size_t size,
+				       enum dma_data_direction dir);
+	void *(*alloc)(union virtio_map map, size_t size,
+		       dma_addr_t *map_handle, gfp_t gfp);
+	void (*free)(union virtio_map map, size_t size, void *vaddr,
+		     dma_addr_t map_handle, unsigned long attrs);
+	bool (*need_sync)(union virtio_map map, dma_addr_t map_handle);
+	int (*mapping_error)(union virtio_map map, dma_addr_t map_handle);
+	size_t (*max_mapping_size)(union virtio_map map);
+};
+
 /* If driver didn't advertise the feature, it will never appear. */
 void virtio_check_driver_offered_feature(const struct virtio_device *vdev,
 					 unsigned int fbit);
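
To make the contract concrete, here is a minimal virtio_map_ops sketch that
simply forwards to the DMA API via map.dma_dev; VDUSE would instead route
these through its iova_domain. The ex_* names are illustrative, and the ops
left out must be implemented by any real user:

  #include <linux/virtio_config.h>
  #include <linux/dma-mapping.h>

  static dma_addr_t ex_map_page(union virtio_map map, struct page *page,
                                unsigned long offset, size_t size,
                                enum dma_data_direction dir,
                                unsigned long attrs)
  {
          return dma_map_page_attrs(map.dma_dev, page, offset, size,
                                    dir, attrs);
  }

  static void ex_unmap_page(union virtio_map map, dma_addr_t map_handle,
                            size_t size, enum dma_data_direction dir,
                            unsigned long attrs)
  {
          dma_unmap_page_attrs(map.dma_dev, map_handle, size, dir, attrs);
  }

  static int ex_mapping_error(union virtio_map map, dma_addr_t map_handle)
  {
          return dma_mapping_error(map.dma_dev, map_handle);
  }

  static const struct virtio_map_ops ex_map_ops = {
          .map_page      = ex_map_page,
          .unmap_page    = ex_unmap_page,
          .mapping_error = ex_mapping_error,
          /* sync_single_for_{cpu,device}, alloc, free, need_sync and
           * max_mapping_size omitted here; a transport with its own
           * mapping logic must implement them all.
           */
  };
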
diff --git a/include/linux/virtio_ring.h b/include/linux/virtio_ring.h
index 9b33df7..c97a12c 100644
--- a/include/linux/virtio_ring.h
+++ b/include/linux/virtio_ring.h
@@ -3,6 +3,7 @@
 #define _LINUX_VIRTIO_RING_H
 
 #include <asm/barrier.h>
+#include <linux/virtio.h>
 #include <linux/irqreturn.h>
 #include <uapi/linux/virtio_ring.h>
 
@@ -79,9 +80,9 @@ struct virtqueue *vring_create_virtqueue(unsigned int index,
 
 /*
  * Creates a virtqueue and allocates the descriptor ring with per
- * virtqueue DMA device.
+ * virtqueue mapping operations.
  */
-struct virtqueue *vring_create_virtqueue_dma(unsigned int index,
+struct virtqueue *vring_create_virtqueue_map(unsigned int index,
 					     unsigned int num,
 					     unsigned int vring_align,
 					     struct virtio_device *vdev,
@@ -91,7 +92,7 @@ struct virtqueue *vring_create_virtqueue_dma(unsigned int index,
 					     bool (*notify)(struct virtqueue *vq),
 					     void (*callback)(struct virtqueue *vq),
 					     const char *name,
-					     struct device *dma_dev);
+					     union virtio_map map);
 
 /*
  * Creates a virtqueue with a standard layout but a caller-allocated