summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKeiichi Watanabe <keiichiw@chromium.org>2021-05-10 21:17:24 +0900
committerSergio Lopez <slp@sinrega.org>2021-06-23 09:52:42 +0200
commit9982541776a603d30556a06df55e8c0491072763 (patch)
tree4b89a4ffe35e629c8d5d7c2f1a1e26a723214901
parent1a03a2aca700421f258b094b1a845f6420d8cfa9 (diff)
downloadvmm_vhost-9982541776a603d30556a06df55e8c0491072763.tar.gz
vhost_user: Stop passing around RawFd
Use `File` or `dyn AsRawFd` instead of `RawFd` to handle ownership easily. Fixes #37. Signed-off-by: Keiichi Watanabe <keiichiw@chromium.org> Change-Id: I6c79d73d1a54163d4612b0ca4d30bf7bd53f9b0f
-rw-r--r--coverage_config_x86_64.json2
-rw-r--r--src/vhost_user/connection.rs153
-rw-r--r--src/vhost_user/dummy_slave.rs36
-rw-r--r--src/vhost_user/master.rs53
-rw-r--r--src/vhost_user/master_req_handler.rs78
-rw-r--r--src/vhost_user/mod.rs2
-rw-r--r--src/vhost_user/slave_fs_cache.rs33
-rw-r--r--src/vhost_user/slave_req_handler.rs209
8 files changed, 235 insertions, 331 deletions
diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json
index 37ecef8..7db1dca 100644
--- a/coverage_config_x86_64.json
+++ b/coverage_config_x86_64.json
@@ -1 +1 @@
-{"coverage_score": 82.7, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"}
+{"coverage_score": 83.6, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"}
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
index fef7dac..ead84c5 100644
--- a/src/vhost_user/connection.rs
+++ b/src/vhost_user/connection.rs
@@ -5,9 +5,10 @@
#![allow(dead_code)]
+use std::fs::File;
use std::io::ErrorKind;
use std::marker::PhantomData;
-use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::{Path, PathBuf};
use std::{mem, slice};
@@ -305,7 +306,7 @@ impl<R: Req> Endpoint<R> {
}
/// Reads bytes from the socket into the given scatter/gather vectors with optional attached
- /// file descriptors.
+ /// file.
///
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
@@ -315,29 +316,37 @@ impl<R: Req> Endpoint<R> {
/// 2) message(packet) boundaries must be respected on the receive side.
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
/// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
///
/// # Return:
- /// * - (number of bytes received, [received fds]) on success
+ /// * - (number of bytes received, [received files]) on success
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
- pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
+ pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<File>>)> {
let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?;
- let rfds = match fds {
+
+ let files = match fds {
0 => None,
n => {
- let mut fds = Vec::with_capacity(n);
- fds.extend_from_slice(&fd_array[0..n]);
- Some(fds)
+ let files = fd_array
+ .iter()
+ .take(n)
+ .map(|fd| {
+ // Safe because we have the ownership of `fd`.
+ unsafe { File::from_raw_fd(*fd) }
+ })
+ .collect();
+ Some(files)
}
};
- Ok((bytes, rfds))
+ Ok((bytes, files))
}
/// Reads all bytes from the socket into the given scatter/gather vectors with optional
- /// attached file descriptors. Will loop until all data has been transfered.
+ /// attached files. Will loop until all data has been transferred.
///
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
@@ -347,6 +356,7 @@ impl<R: Req> Endpoint<R> {
/// 2) message(packet) boundaries must be respected on the receive side.
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
/// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
///
/// # Return:
/// * - (number of bytes received, [received fds]) on success
@@ -355,7 +365,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_into_iovec_all(
&mut self,
iovs: &mut [iovec],
- ) -> Result<(usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(usize, Option<Vec<File>>)> {
let mut data_read = 0;
let mut data_total = 0;
let mut rfds = None;
@@ -396,46 +406,46 @@ impl<R: Req> Endpoint<R> {
}
/// Reads bytes from the socket into a new buffer with optional attached
- /// file descriptors. Received file descriptors are set close-on-exec.
+ /// files. Received file descriptors are set close-on-exec and converted to `File`.
///
/// # Return:
- /// * - (number of bytes received, buf, [received fds]) on success.
+ /// * - (number of bytes received, buf, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
pub fn recv_into_buf(
&mut self,
buf_size: usize,
- ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
+ ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> {
let mut buf = vec![0u8; buf_size];
- let (bytes, rfds) = {
+ let (bytes, files) = {
let mut iovs = [iovec {
iov_base: buf.as_mut_ptr() as *mut c_void,
iov_len: buf_size,
}];
self.recv_into_iovec(&mut iovs)?
};
- Ok((bytes, buf, rfds))
+ Ok((bytes, buf, files))
}
- /// Receive a header-only message with optional attached file descriptors.
+ /// Receive a header-only message with optional attached files.
/// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, [received fds]) on success.
+ /// * - (message header, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
/// * - PartialMessage: received a partial message.
/// * - InvalidMessage: received a invalid message.
- pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut iovs = [iovec {
iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
}];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
@@ -443,7 +453,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, rfds))
+ Ok((hdr, files))
}
/// Receive a message with optional attached file descriptors.
@@ -451,7 +461,7 @@ impl<R: Req> Endpoint<R> {
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, message body, [received fds]) on success.
+ /// * - (message header, message body, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -459,7 +469,7 @@ impl<R: Req> Endpoint<R> {
/// * - InvalidMessage: received a invalid message.
pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
- ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut body: T = Default::default();
let mut iovs = [
@@ -472,7 +482,7 @@ impl<R: Req> Endpoint<R> {
iov_len: mem::size_of::<T>(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
if bytes != total {
@@ -481,7 +491,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, body, rfds))
+ Ok((hdr, body, files))
}
/// Receive a message with header and optional content. Callers need to
@@ -492,7 +502,7 @@ impl<R: Req> Endpoint<R> {
/// silently.
///
/// # Return:
- /// * - (message header, message size, [received fds]) on success.
+ /// * - (message header, message size, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -501,7 +511,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_body_into_buf(
&mut self,
buf: &mut [u8],
- ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut iovs = [
iovec {
@@ -513,7 +523,7 @@ impl<R: Req> Endpoint<R> {
iov_len: buf.len(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
@@ -521,7 +531,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
+ Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files))
}
/// Receive a message with optional payload and attached file descriptors.
@@ -529,7 +539,7 @@ impl<R: Req> Endpoint<R> {
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, message body, size of payload, [received fds]) on success.
+ /// * - (message header, message body, size of payload, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -539,7 +549,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
buf: &mut [u8],
- ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut body: T = Default::default();
let mut iovs = [
@@ -556,7 +566,7 @@ impl<R: Req> Endpoint<R> {
iov_len: buf.len(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
if bytes < total {
@@ -565,17 +575,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, body, bytes - total, rfds))
- }
-
- /// Close all raw file descriptors.
- pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
- if let Some(fds) = rfds {
- for fd in fds {
- // safe because the rawfds are valid and we don't care about the result.
- let _ = unsafe { libc::close(fd) };
- }
- }
+ Ok((hdr, body, bytes - total, files))
}
}
@@ -608,9 +608,7 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
#[cfg(test)]
mod tests {
use super::*;
- use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
- use std::os::unix::io::FromRawFd;
use vmm_sys_util::rand::rand_alphanumerics;
use vmm_sys_util::tempfile::TempFile;
@@ -685,14 +683,14 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
assert_eq!(bytes, 4);
assert_eq!(&buf1[..], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 1);
- let mut file = unsafe { File::from_raw_fd(fds[0]) };
+ assert_eq!(files.len(), 1);
+ let mut file = &files[0];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
@@ -710,23 +708,23 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 3);
- let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
assert_eq!(content, "test");
}
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should not work:
// Sending side: data(header, body) with fds
@@ -742,10 +740,10 @@ mod tests {
let (bytes, buf4) = slave.recv_data(2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf4[..]);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should work:
// Sending side: data, data with fds
@@ -760,28 +758,28 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 4);
assert_eq!(&buf1[..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 3);
- let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
assert_eq!(content, "test");
}
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should not work:
// Sending side: data1, data2 with fds
@@ -799,9 +797,9 @@ mod tests {
let (bytes, _) = slave.recv_data(5).unwrap();
assert_eq!(bytes, 5);
- let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 3);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// If the target fd array is too small, extra file descriptors will get lost.
let len = master
@@ -812,12 +810,9 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 4);
- assert!(rfds.is_some());
-
- Endpoint::<MasterReq>::close_rfds(rfds);
- Endpoint::<MasterReq>::close_rfds(None);
+ assert!(files.is_some());
}
#[test]
@@ -842,15 +837,15 @@ mod tests {
mem::size_of::<u64>(),
)
};
- let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
+ let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
assert_eq!(hdr1, hdr2);
assert_eq!(bytes, 8);
assert_eq!(features1, features2);
- assert!(rfds.is_none());
+ assert!(files.is_none());
master.send_header(&hdr1, None).unwrap();
- let (hdr2, rfds) = slave.recv_header().unwrap();
+ let (hdr2, files) = slave.recv_header().unwrap();
assert_eq!(hdr1, hdr2);
- assert!(rfds.is_none());
+ assert!(files.is_none());
}
}
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
index dc9eed5..cc9a9fb 100644
--- a/src/vhost_user/dummy_slave.rs
+++ b/src/vhost_user/dummy_slave.rs
@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0
use std::fs::File;
-use std::os::unix::io::{AsRawFd, RawFd};
use super::message::*;
use super::*;
@@ -21,9 +20,9 @@ pub struct DummySlaveReqHandler {
pub queue_num: usize,
pub vring_num: [u32; MAX_QUEUE_NUM],
pub vring_base: [u32; MAX_QUEUE_NUM],
- pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM],
- pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM],
- pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub call_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub kick_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub err_fd: [Option<File>; MAX_QUEUE_NUM],
pub vring_started: [bool; MAX_QUEUE_NUM],
pub vring_enabled: [bool; MAX_QUEUE_NUM],
pub inflight_file: Option<File>,
@@ -85,7 +84,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
Ok(())
}
- fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
+ fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _files: Vec<File>) -> Result<()> {
Ok(())
}
@@ -136,14 +135,10 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
))
}
- fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
if index as usize >= self.queue_num || index as usize > self.queue_num {
return Err(Error::InvalidParam);
}
- if self.kick_fd[index as usize].is_some() {
- // Close file descriptor set by previous operations.
- let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) };
- }
self.kick_fd[index as usize] = fd;
// Quotation from vhost-user spec:
@@ -157,26 +152,18 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
Ok(())
}
- fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
if index as usize >= self.queue_num || index as usize > self.queue_num {
return Err(Error::InvalidParam);
}
- if self.call_fd[index as usize].is_some() {
- // Close file descriptor set by previous operations.
- let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) };
- }
self.call_fd[index as usize] = fd;
Ok(())
}
- fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
if index as usize >= self.queue_num || index as usize > self.queue_num {
return Err(Error::InvalidParam);
}
- if self.err_fd[index as usize].is_some() {
- // Close file descriptor set by previous operations.
- let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) };
- }
self.err_fd[index as usize] = fd;
Ok(())
}
@@ -250,10 +237,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
fn get_inflight_fd(
&mut self,
inflight: &VhostUserInflight,
- ) -> Result<(VhostUserInflight, RawFd)> {
+ ) -> Result<(VhostUserInflight, File)> {
let file = tempfile::tempfile().unwrap();
- let fd = file.as_raw_fd();
- self.inflight_file = Some(file);
+ self.inflight_file = Some(file.try_clone().unwrap());
Ok((
VhostUserInflight {
mmap_size: 0x1000,
@@ -261,7 +247,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
num_queues: inflight.num_queues,
queue_size: inflight.queue_size,
},
- fd,
+ file,
))
}
@@ -273,7 +259,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
Ok(MAX_MEM_SLOTS as u64)
}
- fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> {
+ fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: File) -> Result<()> {
Ok(())
}
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
index 0e8b22c..7933f4d 100644
--- a/src/vhost_user/master.rs
+++ b/src/vhost_user/master.rs
@@ -5,7 +5,7 @@
use std::fs::File;
use std::mem;
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::Path;
use std::sync::{Arc, Mutex, MutexGuard};
@@ -14,6 +14,7 @@ use vmm_sys_util::eventfd::EventFd;
use super::connection::Endpoint;
use super::message::*;
+use super::slave_req_handler::take_single_file;
use super::{Error as VhostUserError, Result as VhostUserResult};
use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
use crate::{Error, Result};
@@ -50,7 +51,7 @@ pub trait VhostUserMaster: VhostBackend {
fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;
/// Setup slave communication channel.
- fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>;
+ fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>;
/// Retrieve shared buffer for inflight I/O tracking.
fn get_inflight_fd(
@@ -412,7 +413,6 @@ impl VhostUserMaster for Master {
let (body_reply, buf_reply, rfds) =
node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
if rfds.is_some() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return error_code(VhostUserError::InvalidMessage);
} else if body_reply.size == 0 {
return error_code(VhostUserError::SlaveInternalError);
@@ -445,13 +445,12 @@ impl VhostUserMaster for Master {
node.wait_for_ack(&hdr).map_err(|e| e.into())
}
- fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
+ fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> {
let mut node = self.node();
if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
}
-
- let fds = [fd];
+ let fds = [fd.as_raw_fd()];
let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
node.wait_for_ack(&hdr).map_err(|e| e.into())
}
@@ -466,20 +465,12 @@ impl VhostUserMaster for Master {
}
let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?;
- let (inflight, fds) = node.recv_reply_with_fds::<VhostUserInflight>(&hdr)?;
+ let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?;
- if let Some(fds) = &fds {
- if fds.len() == 1 && fds[0] >= 0 {
- // Safe because we know the fd is valid.
- let file = unsafe { File::from_raw_fd(fds[0]) };
- return Ok((inflight, file));
- }
+ match take_single_file(files) {
+ Some(file) => Ok((inflight, file)),
+ None => error_code(VhostUserError::IncorrectFds),
}
-
- // Make sure to close the fds before returning the error.
- Endpoint::<MasterReq>::close_rfds(fds);
-
- error_code(VhostUserError::IncorrectFds)
}
fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> {
@@ -685,33 +676,31 @@ impl MasterInternal {
let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
}
Ok(body)
}
- fn recv_reply_with_fds<T: Sized + Default + VhostUserMsgValidator>(
+ fn recv_reply_with_files<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
- ) -> VhostUserResult<(T, Option<Vec<RawFd>>)> {
+ ) -> VhostUserResult<(T, Option<Vec<File>>)> {
if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
return Err(VhostUserError::InvalidParam);
}
self.check_state()?;
- let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
- if !reply.is_reply_for(&hdr) || rfds.is_none() || !body.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
+ let (reply, body, files) = self.main_sock.recv_body::<T>()?;
+ if !reply.is_reply_for(&hdr) || files.is_none() || !body.is_valid() {
return Err(VhostUserError::InvalidMessage);
}
- Ok((body, rfds))
+ Ok((body, files))
}
fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
- ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
+ ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> {
if mem::size_of::<T>() > MAX_MSG_SIZE
|| hdr.get_size() as usize <= mem::size_of::<T>()
|| hdr.get_size() as usize > MAX_MSG_SIZE
@@ -722,18 +711,17 @@ impl MasterInternal {
self.check_state()?;
let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
- let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
+ let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
if !reply.is_reply_for(hdr)
|| reply.get_size() as usize != mem::size_of::<T>() + bytes
- || rfds.is_some()
+ || files.is_some()
|| !body.is_valid()
+ || bytes != buf.len()
{
- Endpoint::<MasterReq>::close_rfds(rfds);
- return Err(VhostUserError::InvalidMessage);
- } else if bytes != buf.len() {
return Err(VhostUserError::InvalidMessage);
}
- Ok((body, buf, rfds))
+
+ Ok((body, buf, files))
}
fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
@@ -746,7 +734,6 @@ impl MasterInternal {
let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
}
if body.value != 0 {
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
index 8cba188..0ecda4e 100644
--- a/src/vhost_user/master_req_handler.rs
+++ b/src/vhost_user/master_req_handler.rs
@@ -1,6 +1,7 @@
// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
+use std::fs::File;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -33,9 +34,7 @@ pub trait VhostUserMasterReqHandler {
}
/// Handle virtio-fs map file requests.
- fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
@@ -50,14 +49,12 @@ pub trait VhostUserMasterReqHandler {
}
/// Handle virtio-fs file IO requests.
- fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
// fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
- // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd);
}
/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
@@ -70,9 +67,7 @@ pub trait VhostUserMasterReqHandlerMut {
}
/// Handle virtio-fs map file requests.
- fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
@@ -87,9 +82,7 @@ pub trait VhostUserMasterReqHandlerMut {
}
/// Handle virtio-fs file IO requests.
- fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
@@ -102,7 +95,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
self.lock().unwrap().handle_config_change()
}
- fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
self.lock().unwrap().fs_slave_map(fs, fd)
}
@@ -114,7 +107,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
self.lock().unwrap().fs_slave_sync(fs)
}
- fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
self.lock().unwrap().fs_slave_io(fs, fd)
}
}
@@ -206,8 +199,8 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
// . recv optional message body and payload according size field in
// message header
// . validate message body and optional payload
- let (hdr, rfds) = self.sub_sock.recv_header()?;
- let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (hdr, files) = self.sub_sock.recv_header()?;
+ self.check_attached_files(&hdr, &files)?;
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
@@ -231,9 +224,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
SlaveReq::FS_MAP => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
- // check_attached_rfds() has validated rfds
+ // check_attached_files() has validated files
self.backend
- .fs_slave_map(&msg, rfds.unwrap()[0])
+ .fs_slave_map(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_UNMAP => {
@@ -250,9 +243,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
SlaveReq::FS_IO => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
- // check_attached_rfds() has validated rfds
+ // check_attached_files() has validated files
self.backend
- .fs_slave_io(&msg, rfds.unwrap()[0])
+ .fs_slave_io(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
_ => Err(Error::InvalidMessage),
@@ -286,34 +279,21 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
Ok(())
}
- fn check_attached_rfds(
+ fn check_attached_files(
&self,
hdr: &VhostUserMsgHeader<SlaveReq>,
- rfds: Option<Vec<RawFd>>,
- ) -> Result<Option<Vec<RawFd>>> {
+ files: &Option<Vec<File>>,
+ ) -> Result<()> {
match hdr.get_code() {
SlaveReq::FS_MAP | SlaveReq::FS_IO => {
- // Expect an fd set with a single fd.
- match rfds {
- None => Err(Error::InvalidMessage),
- Some(fds) => {
- if fds.len() != 1 {
- Endpoint::<SlaveReq>::close_rfds(Some(fds));
- Err(Error::InvalidMessage)
- } else {
- Ok(Some(fds))
- }
- }
- }
- }
- _ => {
- if rfds.is_some() {
- Endpoint::<SlaveReq>::close_rfds(rfds);
- Err(Error::InvalidMessage)
- } else {
- Ok(rfds)
+ // Expect a single file is passed.
+ match files {
+ Some(files) if files.len() == 1 => Ok(()),
+ _ => Err(Error::InvalidMessage),
}
}
+ _ if files.is_some() => Err(Error::InvalidMessage),
+ _ => Ok(()),
}
}
@@ -390,9 +370,11 @@ mod tests {
impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
/// Handle virtio-fs map file requests from the slave.
- fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_map(
+ &mut self,
+ _fs: &VhostUserFSSlaveMsg,
+ _fd: &dyn AsRawFd,
+ ) -> HandlerResult<u64> {
Ok(0)
}
@@ -437,7 +419,7 @@ mod tests {
});
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
.unwrap();
// When REPLY_ACK has not been negotiated, the master has no way to detect failure from
// slave side.
@@ -468,7 +450,7 @@ mod tests {
fs_cache.set_reply_ack_flag(true);
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
.unwrap();
fs_cache
.fs_slave_unmap(&VhostUserFSSlaveMsg::default())
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
index 52d97f7..3467c25 100644
--- a/src/vhost_user/mod.rs
+++ b/src/vhost_user/mod.rs
@@ -401,7 +401,7 @@ mod tests {
assert_eq!(offset, 0x100);
assert_eq!(reply_payload[0], 0xa5);
- master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap();
+ master.set_slave_request_fd(&eventfd).unwrap();
master.set_vring_enable(0, true).unwrap();
// unimplemented yet
diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs
index a9c4ed2..ee5fd9b 100644
--- a/src/vhost_user/slave_fs_cache.rs
+++ b/src/vhost_user/slave_fs_cache.rs
@@ -3,7 +3,7 @@
use std::io;
use std::mem;
-use std::os::unix::io::RawFd;
+use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex, MutexGuard};
@@ -55,7 +55,6 @@ impl SlaveFsCacheReqInternal {
let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
- Endpoint::<SlaveReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
if body.value != 0 {
@@ -129,8 +128,8 @@ impl SlaveFsCacheReq {
impl VhostUserMasterReqHandler for SlaveFsCacheReq {
/// Forward vhost-user-fs map file requests to the slave.
- fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd.as_raw_fd()]))
}
/// Forward vhost-user-fs unmap file requests to the master.
@@ -158,31 +157,21 @@ mod tests {
#[test]
fn test_slave_fs_cache_send_failure() {
let (p1, p2) = UnixStream::pair().unwrap();
- let fd = p2.as_raw_fd();
let fs_cache = SlaveFsCacheReq::from_stream(p1);
fs_cache.set_failed(libc::ECONNRESET);
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &p2)
.unwrap_err();
fs_cache
.fs_slave_unmap(&VhostUserFSSlaveMsg::default())
.unwrap_err();
fs_cache.node().error = None;
-
- drop(p2);
- fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
- .unwrap_err();
- fs_cache
- .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
- .unwrap_err();
}
#[test]
fn test_slave_fs_cache_recv_negative() {
let (p1, p2) = UnixStream::pair().unwrap();
- let fd = p2.as_raw_fd();
let fs_cache = SlaveFsCacheReq::from_stream(p1);
let mut master = Endpoint::<SlaveReq>::from_stream(p2);
@@ -194,33 +183,35 @@ mod tests {
);
let body = VhostUserU64::new(0);
- master.send_message(&hdr, &body, Some(&[fd])).unwrap();
+ master
+ .send_message(&hdr, &body, Some(&[master.as_raw_fd()]))
+ .unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap();
fs_cache.set_reply_ack_flag(true);
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap_err();
hdr.set_code(SlaveReq::FS_UNMAP);
master.send_message(&hdr, &body, None).unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap_err();
hdr.set_code(SlaveReq::FS_MAP);
let body = VhostUserU64::new(1);
master.send_message(&hdr, &body, None).unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap_err();
let body = VhostUserU64::new(0);
master.send_message(&hdr, &body, None).unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap();
}
}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
index 710b5f5..7c3de7d 100644
--- a/src/vhost_user/slave_req_handler.rs
+++ b/src/vhost_user/slave_req_handler.rs
@@ -3,7 +3,7 @@
use std::fs::File;
use std::mem;
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;
use std::sync::{Arc, Mutex};
@@ -39,7 +39,7 @@ pub trait VhostUserSlaveReqHandler {
fn reset_owner(&self) -> Result<()>;
fn get_features(&self) -> Result<u64>;
fn set_features(&self, features: u64) -> Result<()>;
- fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
fn set_vring_addr(
&self,
@@ -52,9 +52,9 @@ pub trait VhostUserSlaveReqHandler {
) -> Result<()>;
fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
- fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
fn set_protocol_features(&self, features: u64) -> Result<()>;
@@ -63,10 +63,10 @@ pub trait VhostUserSlaveReqHandler {
fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
- fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, RawFd)>;
+ fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
fn get_max_mem_slots(&self) -> Result<u64>;
- fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
}
@@ -79,7 +79,7 @@ pub trait VhostUserSlaveReqHandlerMut {
fn reset_owner(&mut self) -> Result<()>;
fn get_features(&mut self) -> Result<u64>;
fn set_features(&mut self, features: u64) -> Result<()>;
- fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
fn set_vring_addr(
&mut self,
@@ -92,9 +92,9 @@ pub trait VhostUserSlaveReqHandlerMut {
) -> Result<()>;
fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
- fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
fn set_protocol_features(&mut self, features: u64) -> Result<()>;
@@ -111,10 +111,10 @@ pub trait VhostUserSlaveReqHandlerMut {
fn get_inflight_fd(
&mut self,
inflight: &VhostUserInflight,
- ) -> Result<(VhostUserInflight, RawFd)>;
+ ) -> Result<(VhostUserInflight, File)>;
fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
fn get_max_mem_slots(&mut self) -> Result<u64>;
- fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
+ fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
}
@@ -135,8 +135,8 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().set_features(features)
}
- fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
- self.lock().unwrap().set_mem_table(ctx, fds)
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, files)
}
fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
@@ -165,15 +165,15 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().get_vring_base(index)
}
- fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
self.lock().unwrap().set_vring_kick(index, fd)
}
- fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
self.lock().unwrap().set_vring_call(index, fd)
}
- fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
self.lock().unwrap().set_vring_err(index, fd)
}
@@ -205,7 +205,7 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().set_slave_req_fd(vu_req)
}
- fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, RawFd)> {
+ fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
self.lock().unwrap().get_inflight_fd(inflight)
}
@@ -217,7 +217,7 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().get_max_mem_slots()
}
- fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> {
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
self.lock().unwrap().add_mem_region(region, fd)
}
@@ -307,8 +307,9 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
// . recv optional message body and payload according size field in
// message header
// . validate message body and optional payload
- let (hdr, rfds) = self.main_sock.recv_header()?;
- let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (hdr, files) = self.main_sock.recv_header()?;
+ self.check_attached_files(&hdr, &files)?;
+
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
@@ -347,7 +348,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_MEM_TABLE => {
- let res = self.set_mem_table(&hdr, size, &buf, rfds);
+ let res = self.set_mem_table(&hdr, size, &buf, files);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_NUM => {
@@ -383,20 +384,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_CALL => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
- let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.set_vring_call(index, rfds);
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_call(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
- let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.set_vring_kick(index, rfds);
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_kick(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
- let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.set_vring_err(index, rfds);
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_err(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
@@ -459,7 +460,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
- let res = self.set_slave_req_fd(rfds);
+ let res = self.set_slave_req_fd(files);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_INFLIGHT_FD => {
@@ -470,10 +471,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
- let (inflight, fd) = self.backend.get_inflight_fd(&msg)?;
+ let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
self.main_sock
- .send_message(&reply_hdr, &inflight, Some(&[fd]))?;
+ .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
}
MasterReq::SET_INFLIGHT_FD => {
if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
@@ -481,18 +482,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
{
return Err(Error::InvalidOperation);
}
- let file = if let Some(fds) = rfds {
- if fds.len() != 1 || fds[0] < 0 {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
- return Err(Error::IncorrectFds);
- }
-
- // Safe because we know the fd is valid.
- unsafe { File::from_raw_fd(fds[0]) }
- } else {
- return Err(Error::IncorrectFds);
- };
-
+ let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
let res = self.backend.set_inflight_fd(&msg, file);
self.send_ack_message(&hdr, res)?;
@@ -516,18 +506,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
{
return Err(Error::InvalidOperation);
}
- let fd = if let Some(fds) = &rfds {
- if fds.len() != 1 {
- return Err(Error::InvalidParam);
- }
- fds[0]
- } else {
+ let mut files = files.ok_or(Error::InvalidParam)?;
+ if files.len() != 1 {
return Err(Error::InvalidParam);
- };
-
+ }
let msg =
self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
- let res = self.backend.add_mem_region(&msg, fd);
+ let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
self.send_ack_message(&hdr, res)?;
}
MasterReq::REM_MEM_REG => {
@@ -555,37 +540,28 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
buf: &[u8],
- rfds: Option<Vec<RawFd>>,
+ files: Option<Vec<File>>,
) -> Result<()> {
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
// check message size is consistent
let hdrsize = mem::size_of::<VhostUserMemory>();
if size < hdrsize {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
if !msg.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
// validate number of fds matching number of memory regions
- let fds = match rfds {
- None => return Err(Error::InvalidMessage),
- Some(fds) => {
- if fds.len() != msg.num_regions as usize {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
- return Err(Error::InvalidMessage);
- }
- fds
- }
- };
+ let files = files.ok_or(Error::InvalidMessage)?;
+ if files.len() != msg.num_regions as usize {
+ return Err(Error::InvalidMessage);
+ }
// Validate memory regions
let regions = unsafe {
@@ -596,12 +572,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
};
for region in regions.iter() {
if !region.is_valid() {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
return Err(Error::InvalidMessage);
}
}
- self.backend.set_mem_table(&regions, &fds)
+ self.backend.set_mem_table(&regions, files)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
@@ -662,26 +637,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.backend.set_config(msg.offset, buf, flags)
}
- fn set_slave_req_fd(&mut self, rfds: Option<Vec<RawFd>>) -> Result<()> {
- if let Some(fds) = rfds {
- if fds.len() == 1 {
- let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
- let vu_req = SlaveFsCacheReq::from_stream(sock);
- self.backend.set_slave_req_fd(vu_req);
- Ok(())
- } else {
- Err(Error::InvalidMessage)
- }
- } else {
- Err(Error::InvalidMessage)
- }
+ fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
+ let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
+ let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) };
+ let vu_req = SlaveFsCacheReq::from_stream(sock);
+ self.backend.set_slave_req_fd(vu_req);
+ Ok(())
}
fn handle_vring_fd_request(
&mut self,
buf: &[u8],
- rfds: Option<Vec<RawFd>>,
- ) -> Result<(u8, Option<RawFd>)> {
+ files: Option<Vec<File>>,
+ ) -> Result<(u8, Option<File>)> {
if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
return Err(Error::InvalidMessage);
}
@@ -691,28 +659,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
// Bits (0-7) of the payload contain the vring index. Bit 8 is the
- // invalid FD flag. This flag is set when there is no file descriptor
+ // invalid FD flag. This bit is set when there is no file descriptor
// in the ancillary data. This signals that polling will be used
// instead of waiting for the call.
- let nofd = (msg.value & 0x100u64) == 0x100u64;
-
- let mut rfd = None;
- match rfds {
- Some(fds) => {
- if !nofd && fds.len() == 1 {
- rfd = Some(fds[0]);
- } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
- return Err(Error::InvalidMessage);
- }
- }
- None => {
- if !nofd {
- return Err(Error::InvalidMessage);
- }
- }
+ // If Bit 8 is unset, the data must contain a file descriptor.
+ let has_fd = (msg.value & 0x100u64) == 0;
+
+ let file = take_single_file(files);
+
+ if has_fd && file.is_none() || !has_fd && file.is_some() {
+ return Err(Error::InvalidMessage);
}
- Ok((msg.value as u8, rfd))
+
+ Ok((msg.value as u8, file))
}
fn check_state(&self) -> Result<()> {
@@ -738,29 +697,23 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Ok(())
}
- fn check_attached_rfds(
+ fn check_attached_files(
&self,
hdr: &VhostUserMsgHeader<MasterReq>,
- rfds: Option<Vec<RawFd>>,
- ) -> Result<Option<Vec<RawFd>>> {
+ files: &Option<Vec<File>>,
+ ) -> Result<()> {
match hdr.get_code() {
- MasterReq::SET_MEM_TABLE => Ok(rfds),
- MasterReq::SET_VRING_CALL => Ok(rfds),
- MasterReq::SET_VRING_KICK => Ok(rfds),
- MasterReq::SET_VRING_ERR => Ok(rfds),
- MasterReq::SET_LOG_BASE => Ok(rfds),
- MasterReq::SET_LOG_FD => Ok(rfds),
- MasterReq::SET_SLAVE_REQ_FD => Ok(rfds),
- MasterReq::SET_INFLIGHT_FD => Ok(rfds),
- MasterReq::ADD_MEM_REG => Ok(rfds),
- _ => {
- if rfds.is_some() {
- Endpoint::<MasterReq>::close_rfds(rfds);
- Err(Error::InvalidMessage)
- } else {
- Ok(rfds)
- }
- }
+ MasterReq::SET_MEM_TABLE
+ | MasterReq::SET_VRING_CALL
+ | MasterReq::SET_VRING_KICK
+ | MasterReq::SET_VRING_ERR
+ | MasterReq::SET_LOG_BASE
+ | MasterReq::SET_LOG_FD
+ | MasterReq::SET_SLAVE_REQ_FD
+ | MasterReq::SET_INFLIGHT_FD
+ | MasterReq::ADD_MEM_REG => Ok(()),
+ _ if files.is_some() => Err(Error::InvalidMessage),
+ _ => Ok(()),
}
}
@@ -850,6 +803,16 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
+/// Utility function to take the first element from option of a vector of files.
+/// Returns `None` if the vector contains no file or more than one file.
+pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> {
+ let mut files = files?;
+ if files.len() != 1 {
+ return None;
+ }
+ Some(files.swap_remove(0))
+}
+
impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
fn as_raw_fd(&self) -> RawFd {
self.main_sock.as_raw_fd()