diff --git a/drivers/staging/android/ion/ion.c b/drivers/staging/android/ion/ion.c index d97117f3abb684..cf9fc781300432 100644 --- a/drivers/staging/android/ion/ion.c +++ b/drivers/staging/android/ion/ion.c @@ -383,7 +383,14 @@ static void ion_handle_get(struct ion_handle *handle) static int ion_handle_put(struct ion_handle *handle) { - return kref_put(&handle->ref, ion_handle_destroy); + struct ion_client *client = handle->client; + int ret; + + mutex_lock(&client->lock); + ret = kref_put(&handle->ref, ion_handle_destroy); + mutex_unlock(&client->lock); + + return ret; } static struct ion_handle *ion_handle_lookup(struct ion_client *client, @@ -403,14 +410,24 @@ static struct ion_handle *ion_handle_lookup(struct ion_client *client, return ERR_PTR(-EINVAL); } -static struct ion_handle *ion_uhandle_get(struct ion_client *client, int id) +static struct ion_handle *ion_handle_get_by_id(struct ion_client *client, + int id) { - return idr_find(&client->idr, id); + struct ion_handle *handle; + + mutex_lock(&client->lock); + handle = idr_find(&client->idr, id); + if (handle) + ion_handle_get(handle); + mutex_unlock(&client->lock); + + return handle ? handle : ERR_PTR(-EINVAL); } static bool ion_handle_validate(struct ion_client *client, struct ion_handle *handle) { - return (ion_uhandle_get(client, handle->id) == handle); + WARN_ON(!mutex_is_locked(&client->lock)); + return (idr_find(&client->idr, handle->id) == handle); } static int ion_handle_add(struct ion_client *client, struct ion_handle *handle) @@ -503,11 +520,11 @@ struct ion_handle *ion_alloc(struct ion_client *client, size_t len, mutex_lock(&client->lock); ret = ion_handle_add(client, handle); + mutex_unlock(&client->lock); if (ret) { ion_handle_put(handle); handle = ERR_PTR(ret); } - mutex_unlock(&client->lock); return handle; } @@ -527,8 +544,8 @@ void ion_free(struct ion_client *client, struct ion_handle *handle) mutex_unlock(&client->lock); return; } - ion_handle_put(handle); mutex_unlock(&client->lock); + ion_handle_put(handle); } EXPORT_SYMBOL(ion_free); @@ -1021,14 +1038,15 @@ struct dma_buf *ion_share_dma_buf(struct ion_client *client, mutex_lock(&client->lock); valid_handle = ion_handle_validate(client, handle); - mutex_unlock(&client->lock); if (!valid_handle) { WARN(1, "%s: invalid handle passed to share.\n", __func__); + mutex_unlock(&client->lock); return ERR_PTR(-EINVAL); } - buffer = handle->buffer; ion_buffer_get(buffer); + mutex_unlock(&client->lock); + dmabuf = dma_buf_export(buffer, &dma_buf_ops, buffer->size, O_RDWR); if (IS_ERR(dmabuf)) { ion_buffer_put(buffer); @@ -1081,18 +1099,24 @@ struct ion_handle *ion_import_dma_buf(struct ion_client *client, int fd) handle = ion_handle_lookup(client, buffer); if (!IS_ERR(handle)) { ion_handle_get(handle); + mutex_unlock(&client->lock); goto end; } + mutex_unlock(&client->lock); + handle = ion_handle_create(client, buffer); if (IS_ERR(handle)) goto end; + + mutex_lock(&client->lock); ret = ion_handle_add(client, handle); + mutex_unlock(&client->lock); if (ret) { ion_handle_put(handle); handle = ERR_PTR(ret); } + end: - mutex_unlock(&client->lock); dma_buf_put(dmabuf); return handle; } @@ -1156,12 +1180,11 @@ static long ion_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) if (copy_from_user(&data, (void __user *)arg, sizeof(struct ion_handle_data))) return -EFAULT; - mutex_lock(&client->lock); - handle = ion_uhandle_get(client, data.handle); - mutex_unlock(&client->lock); - if (!handle) - return -EINVAL; + handle = ion_handle_get_by_id(client, data.handle); + if (IS_ERR(handle)) + return PTR_ERR(handle); ion_free(client, handle); + ion_handle_put(handle); break; } case ION_IOC_SHARE: @@ -1172,8 +1195,11 @@ static long ion_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) if (copy_from_user(&data, (void __user *)arg, sizeof(data))) return -EFAULT; - handle = ion_uhandle_get(client, data.handle); + handle = ion_handle_get_by_id(client, data.handle); + if (IS_ERR(handle)) + return PTR_ERR(handle); data.fd = ion_share_dma_buf_fd(client, handle); + ion_handle_put(handle); if (copy_to_user((void __user *)arg, &data, sizeof(data))) return -EFAULT; if (data.fd < 0)