summaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c62
1 files changed, 34 insertions, 28 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 94701ff3a23..ade0568c07a 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -15,6 +15,7 @@
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/mm.h>
+#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
#include <linux/rcupdate.h>
@@ -29,8 +30,6 @@
#include <linux/if_packet.h>
#include <linux/if_arp.h>
-#include <net/sock.h>
-
#include "vhost.h"
enum {
@@ -98,22 +97,26 @@ void vhost_poll_stop(struct vhost_poll *poll)
remove_wait_queue(poll->wqh, &poll->wait);
}
+static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
+ unsigned seq)
+{
+ int left;
+ spin_lock_irq(&dev->work_lock);
+ left = seq - work->done_seq;
+ spin_unlock_irq(&dev->work_lock);
+ return left <= 0;
+}
+
static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
{
unsigned seq;
- int left;
int flushing;
spin_lock_irq(&dev->work_lock);
seq = work->queue_seq;
work->flushing++;
spin_unlock_irq(&dev->work_lock);
- wait_event(work->done, ({
- spin_lock_irq(&dev->work_lock);
- left = seq - work->done_seq <= 0;
- spin_unlock_irq(&dev->work_lock);
- left;
- }));
+ wait_event(work->done, vhost_work_seq_done(dev, work, seq));
spin_lock_irq(&dev->work_lock);
flushing = --work->flushing;
spin_unlock_irq(&dev->work_lock);
@@ -157,7 +160,6 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->avail_idx = 0;
vq->last_used_idx = 0;
vq->used_flags = 0;
- vq->used_flags = 0;
vq->log_used = false;
vq->log_addr = -1ull;
vq->vhost_hlen = 0;
@@ -178,6 +180,8 @@ static int vhost_worker(void *data)
struct vhost_work *work = NULL;
unsigned uninitialized_var(seq);
+ use_mm(dev->mm);
+
for (;;) {
/* mb paired w/ kthread_stop */
set_current_state(TASK_INTERRUPTIBLE);
@@ -192,7 +196,7 @@ static int vhost_worker(void *data)
if (kthread_should_stop()) {
spin_unlock_irq(&dev->work_lock);
__set_current_state(TASK_RUNNING);
- return 0;
+ break;
}
if (!list_empty(&dev->work_list)) {
work = list_first_entry(&dev->work_list,
@@ -210,6 +214,8 @@ static int vhost_worker(void *data)
schedule();
}
+ unuse_mm(dev->mm);
+ return 0;
}
/* Helper to allocate iovec buffers for all vqs. */
@@ -402,15 +408,14 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
kfree(rcu_dereference_protected(dev->memory,
lockdep_is_held(&dev->mutex)));
RCU_INIT_POINTER(dev->memory, NULL);
- if (dev->mm)
- mmput(dev->mm);
- dev->mm = NULL;
-
WARN_ON(!list_empty(&dev->work_list));
if (dev->worker) {
kthread_stop(dev->worker);
dev->worker = NULL;
}
+ if (dev->mm)
+ mmput(dev->mm);
+ dev->mm = NULL;
}
static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
@@ -881,14 +886,15 @@ static int set_bit_to_user(int nr, void __user *addr)
static int log_write(void __user *log_base,
u64 write_address, u64 write_length)
{
+ u64 write_page = write_address / VHOST_PAGE_SIZE;
int r;
if (!write_length)
return 0;
- write_address /= VHOST_PAGE_SIZE;
+ write_length += write_address % VHOST_PAGE_SIZE;
for (;;) {
u64 base = (u64)(unsigned long)log_base;
- u64 log = base + write_address / 8;
- int bit = write_address % 8;
+ u64 log = base + write_page / 8;
+ int bit = write_page % 8;
if ((u64)(unsigned long)log != log)
return -EFAULT;
r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
@@ -897,7 +903,7 @@ static int log_write(void __user *log_base,
if (write_length <= VHOST_PAGE_SIZE)
break;
write_length -= VHOST_PAGE_SIZE;
- write_address += VHOST_PAGE_SIZE;
+ write_page += 1;
}
return r;
}
@@ -1092,7 +1098,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
/* Check it isn't doing very strange things with descriptor numbers. */
last_avail_idx = vq->last_avail_idx;
- if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) {
+ if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) {
vq_err(vq, "Failed to access avail idx at %p\n",
&vq->avail->idx);
return -EFAULT;
@@ -1113,8 +1119,8 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
/* Grab the next descriptor number they're advertising, and increment
* the index we've seen. */
- if (unlikely(get_user(head,
- &vq->avail->ring[last_avail_idx % vq->num]))) {
+ if (unlikely(__get_user(head,
+ &vq->avail->ring[last_avail_idx % vq->num]))) {
vq_err(vq, "Failed to read head: idx %d address %p\n",
last_avail_idx,
&vq->avail->ring[last_avail_idx % vq->num]);
@@ -1213,17 +1219,17 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
/* The virtqueue contains a ring of used buffers. Get a pointer to the
* next entry in that used ring. */
used = &vq->used->ring[vq->last_used_idx % vq->num];
- if (put_user(head, &used->id)) {
+ if (__put_user(head, &used->id)) {
vq_err(vq, "Failed to write used id");
return -EFAULT;
}
- if (put_user(len, &used->len)) {
+ if (__put_user(len, &used->len)) {
vq_err(vq, "Failed to write used len");
return -EFAULT;
}
/* Make sure buffer is written before we update index. */
smp_wmb();
- if (put_user(vq->last_used_idx + 1, &vq->used->idx)) {
+ if (__put_user(vq->last_used_idx + 1, &vq->used->idx)) {
vq_err(vq, "Failed to increment used idx");
return -EFAULT;
}
@@ -1255,7 +1261,7 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
start = vq->last_used_idx % vq->num;
used = vq->used->ring + start;
- if (copy_to_user(used, heads, count * sizeof *used)) {
+ if (__copy_to_user(used, heads, count * sizeof *used)) {
vq_err(vq, "Failed to write used");
return -EFAULT;
}
@@ -1316,7 +1322,7 @@ void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
* interrupts. */
smp_mb();
- if (get_user(flags, &vq->avail->flags)) {
+ if (__get_user(flags, &vq->avail->flags)) {
vq_err(vq, "Failed to get flags");
return;
}
@@ -1367,7 +1373,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
/* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */
smp_mb();
- r = get_user(avail_idx, &vq->avail->idx);
+ r = __get_user(avail_idx, &vq->avail->idx);
if (r) {
vq_err(vq, "Failed to check avail idx at %p: %d\n",
&vq->avail->idx, r);