diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r-- | drivers/vhost/vhost.c | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 159c77a5746..38244f59cdd 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -15,6 +15,7 @@ #include <linux/vhost.h> #include <linux/virtio_net.h> #include <linux/mm.h> +#include <linux/mmu_context.h> #include <linux/miscdevice.h> #include <linux/mutex.h> #include <linux/rcupdate.h> @@ -29,8 +30,6 @@ #include <linux/if_packet.h> #include <linux/if_arp.h> -#include <net/sock.h> - #include "vhost.h" enum { @@ -157,7 +156,6 @@ static void vhost_vq_reset(struct vhost_dev *dev, vq->avail_idx = 0; vq->last_used_idx = 0; vq->used_flags = 0; - vq->used_flags = 0; vq->log_used = false; vq->log_addr = -1ull; vq->vhost_hlen = 0; @@ -178,6 +176,8 @@ static int vhost_worker(void *data) struct vhost_work *work = NULL; unsigned uninitialized_var(seq); + use_mm(dev->mm); + for (;;) { /* mb paired w/ kthread_stop */ set_current_state(TASK_INTERRUPTIBLE); @@ -192,7 +192,7 @@ static int vhost_worker(void *data) if (kthread_should_stop()) { spin_unlock_irq(&dev->work_lock); __set_current_state(TASK_RUNNING); - return 0; + break; } if (!list_empty(&dev->work_list)) { work = list_first_entry(&dev->work_list, @@ -210,6 +210,8 @@ static int vhost_worker(void *data) schedule(); } + unuse_mm(dev->mm); + return 0; } /* Helper to allocate iovec buffers for all vqs. */ @@ -402,15 +404,14 @@ void vhost_dev_cleanup(struct vhost_dev *dev) kfree(rcu_dereference_protected(dev->memory, lockdep_is_held(&dev->mutex))); RCU_INIT_POINTER(dev->memory, NULL); - if (dev->mm) - mmput(dev->mm); - dev->mm = NULL; - WARN_ON(!list_empty(&dev->work_list)); if (dev->worker) { kthread_stop(dev->worker); dev->worker = NULL; } + if (dev->mm) + mmput(dev->mm); + dev->mm = NULL; } static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) @@ -881,15 +882,15 @@ static int set_bit_to_user(int nr, void __user *addr) static int log_write(void __user *log_base, u64 write_address, u64 write_length) { + u64 write_page = write_address / VHOST_PAGE_SIZE; int r; if (!write_length) return 0; write_length += write_address % VHOST_PAGE_SIZE; - write_address /= VHOST_PAGE_SIZE; for (;;) { u64 base = (u64)(unsigned long)log_base; - u64 log = base + write_address / 8; - int bit = write_address % 8; + u64 log = base + write_page / 8; + int bit = write_page % 8; if ((u64)(unsigned long)log != log) return -EFAULT; r = set_bit_to_user(bit, (void __user *)(unsigned long)log); @@ -898,7 +899,7 @@ static int log_write(void __user *log_base, if (write_length <= VHOST_PAGE_SIZE) break; write_length -= VHOST_PAGE_SIZE; - write_address += 1; + write_page += 1; } return r; } @@ -1093,7 +1094,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, /* Check it isn't doing very strange things with descriptor numbers. */ last_avail_idx = vq->last_avail_idx; - if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) { + if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) { vq_err(vq, "Failed to access avail idx at %p\n", &vq->avail->idx); return -EFAULT; @@ -1114,8 +1115,8 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, /* Grab the next descriptor number they're advertising, and increment * the index we've seen. */ - if (unlikely(get_user(head, - &vq->avail->ring[last_avail_idx % vq->num]))) { + if (unlikely(__get_user(head, + &vq->avail->ring[last_avail_idx % vq->num]))) { vq_err(vq, "Failed to read head: idx %d address %p\n", last_avail_idx, &vq->avail->ring[last_avail_idx % vq->num]); @@ -1214,17 +1215,17 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len) /* The virtqueue contains a ring of used buffers. Get a pointer to the * next entry in that used ring. */ used = &vq->used->ring[vq->last_used_idx % vq->num]; - if (put_user(head, &used->id)) { + if (__put_user(head, &used->id)) { vq_err(vq, "Failed to write used id"); return -EFAULT; } - if (put_user(len, &used->len)) { + if (__put_user(len, &used->len)) { vq_err(vq, "Failed to write used len"); return -EFAULT; } /* Make sure buffer is written before we update index. */ smp_wmb(); - if (put_user(vq->last_used_idx + 1, &vq->used->idx)) { + if (__put_user(vq->last_used_idx + 1, &vq->used->idx)) { vq_err(vq, "Failed to increment used idx"); return -EFAULT; } @@ -1256,7 +1257,7 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, start = vq->last_used_idx % vq->num; used = vq->used->ring + start; - if (copy_to_user(used, heads, count * sizeof *used)) { + if (__copy_to_user(used, heads, count * sizeof *used)) { vq_err(vq, "Failed to write used"); return -EFAULT; } @@ -1317,7 +1318,7 @@ void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq) * interrupts. */ smp_mb(); - if (get_user(flags, &vq->avail->flags)) { + if (__get_user(flags, &vq->avail->flags)) { vq_err(vq, "Failed to get flags"); return; } @@ -1368,7 +1369,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq) /* They could have slipped one in as we were doing that: make * sure it's written, then check again. */ smp_mb(); - r = get_user(avail_idx, &vq->avail->idx); + r = __get_user(avail_idx, &vq->avail->idx); if (r) { vq_err(vq, "Failed to check avail idx at %p: %d\n", &vq->avail->idx, r); |