diff options
author | Keiichi Watanabe <keiichiw@chromium.org> | 2021-05-10 21:17:24 +0900 |
---|---|---|
committer | Sergio Lopez <slp@sinrega.org> | 2021-06-23 09:52:42 +0200 |
commit | 9982541776a603d30556a06df55e8c0491072763 (patch) | |
tree | 4b89a4ffe35e629c8d5d7c2f1a1e26a723214901 | |
parent | 1a03a2aca700421f258b094b1a845f6420d8cfa9 (diff) | |
download | vmm_vhost-9982541776a603d30556a06df55e8c0491072763.tar.gz |
vhost_user: Stop passing around RawFd
Use `File` or `dyn AsRawFd` instead of `RawFd` to handle ownership
easily.
Fixes #37.
Signed-off-by: Keiichi Watanabe <keiichiw@chromium.org>
Change-Id: I6c79d73d1a54163d4612b0ca4d30bf7bd53f9b0f
-rw-r--r-- | coverage_config_x86_64.json | 2 | ||||
-rw-r--r-- | src/vhost_user/connection.rs | 153 | ||||
-rw-r--r-- | src/vhost_user/dummy_slave.rs | 36 | ||||
-rw-r--r-- | src/vhost_user/master.rs | 53 | ||||
-rw-r--r-- | src/vhost_user/master_req_handler.rs | 78 | ||||
-rw-r--r-- | src/vhost_user/mod.rs | 2 | ||||
-rw-r--r-- | src/vhost_user/slave_fs_cache.rs | 33 | ||||
-rw-r--r-- | src/vhost_user/slave_req_handler.rs | 209 |
8 files changed, 235 insertions, 331 deletions
diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index 37ecef8..7db1dca 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1 +1 @@ -{"coverage_score": 82.7, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} +{"coverage_score": 83.6, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index fef7dac..ead84c5 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -5,9 +5,10 @@ #![allow(dead_code)] +use std::fs::File; use std::io::ErrorKind; use std::marker::PhantomData; -use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::{mem, slice}; @@ -305,7 +306,7 @@ impl<R: Req> Endpoint<R> { } /// Reads bytes from the socket into the given scatter/gather vectors with optional attached - /// file descriptors. + /// file. /// /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little /// tricky to pass file descriptors through such a communication channel. Let's assume that a @@ -315,29 +316,37 @@ impl<R: Req> Endpoint<R> { /// 2) message(packet) boundaries must be respected on the receive side. /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. /// /// # Return: - /// * - (number of bytes received, [received fds]) on success + /// * - (number of bytes received, [received files]) on success /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. - pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> { + pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<File>>)> { let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?; - let rfds = match fds { + + let files = match fds { 0 => None, n => { - let mut fds = Vec::with_capacity(n); - fds.extend_from_slice(&fd_array[0..n]); - Some(fds) + let files = fd_array + .iter() + .take(n) + .map(|fd| { + // Safe because we have the ownership of `fd`. + unsafe { File::from_raw_fd(*fd) } + }) + .collect(); + Some(files) } }; - Ok((bytes, rfds)) + Ok((bytes, files)) } /// Reads all bytes from the socket into the given scatter/gather vectors with optional - /// attached file descriptors. Will loop until all data has been transfered. + /// attached files. Will loop until all data has been transferred. /// /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little /// tricky to pass file descriptors through such a communication channel. Let's assume that a @@ -347,6 +356,7 @@ impl<R: Req> Endpoint<R> { /// 2) message(packet) boundaries must be respected on the receive side. /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. /// /// # Return: /// * - (number of bytes received, [received fds]) on success @@ -355,7 +365,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_into_iovec_all( &mut self, iovs: &mut [iovec], - ) -> Result<(usize, Option<Vec<RawFd>>)> { + ) -> Result<(usize, Option<Vec<File>>)> { let mut data_read = 0; let mut data_total = 0; let mut rfds = None; @@ -396,46 +406,46 @@ impl<R: Req> Endpoint<R> { } /// Reads bytes from the socket into a new buffer with optional attached - /// file descriptors. Received file descriptors are set close-on-exec. + /// files. Received file descriptors are set close-on-exec and converted to `File`. /// /// # Return: - /// * - (number of bytes received, buf, [received fds]) on success. + /// * - (number of bytes received, buf, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. pub fn recv_into_buf( &mut self, buf_size: usize, - ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> { + ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> { let mut buf = vec![0u8; buf_size]; - let (bytes, rfds) = { + let (bytes, files) = { let mut iovs = [iovec { iov_base: buf.as_mut_ptr() as *mut c_void, iov_len: buf_size, }]; self.recv_into_iovec(&mut iovs)? }; - Ok((bytes, buf, rfds)) + Ok((bytes, buf, files)) } - /// Receive a header-only message with optional attached file descriptors. + /// Receive a header-only message with optional attached files. /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, [received fds]) on success. + /// * - (message header, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. /// * - PartialMessage: received a partial message. /// * - InvalidMessage: received a invalid message. - pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> { + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut iovs = [iovec { iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), }]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { return Err(Error::PartialMessage); @@ -443,7 +453,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, rfds)) + Ok((hdr, files)) } /// Receive a message with optional attached file descriptors. @@ -451,7 +461,7 @@ impl<R: Req> Endpoint<R> { /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, message body, [received fds]) on success. + /// * - (message header, message body, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -459,7 +469,7 @@ impl<R: Req> Endpoint<R> { /// * - InvalidMessage: received a invalid message. pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>( &mut self, - ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut body: T = Default::default(); let mut iovs = [ @@ -472,7 +482,7 @@ impl<R: Req> Endpoint<R> { iov_len: mem::size_of::<T>(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); if bytes != total { @@ -481,7 +491,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, body, rfds)) + Ok((hdr, body, files)) } /// Receive a message with header and optional content. Callers need to @@ -492,7 +502,7 @@ impl<R: Req> Endpoint<R> { /// silently. /// /// # Return: - /// * - (message header, message size, [received fds]) on success. + /// * - (message header, message size, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -501,7 +511,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_body_into_buf( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut iovs = [ iovec { @@ -513,7 +523,7 @@ impl<R: Req> Endpoint<R> { iov_len: buf.len(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { return Err(Error::PartialMessage); @@ -521,7 +531,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds)) + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files)) } /// Receive a message with optional payload and attached file descriptors. @@ -529,7 +539,7 @@ impl<R: Req> Endpoint<R> { /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - (message header, message body, size of payload, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -539,7 +549,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut body: T = Default::default(); let mut iovs = [ @@ -556,7 +566,7 @@ impl<R: Req> Endpoint<R> { iov_len: buf.len(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); if bytes < total { @@ -565,17 +575,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, body, bytes - total, rfds)) - } - - /// Close all raw file descriptors. - pub fn close_rfds(rfds: Option<Vec<RawFd>>) { - if let Some(fds) = rfds { - for fd in fds { - // safe because the rawfds are valid and we don't care about the result. - let _ = unsafe { libc::close(fd) }; - } - } + Ok((hdr, body, bytes - total, files)) } } @@ -608,9 +608,7 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { #[cfg(test)] mod tests { use super::*; - use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; - use std::os::unix::io::FromRawFd; use vmm_sys_util::rand::rand_alphanumerics; use vmm_sys_util::tempfile::TempFile; @@ -685,14 +683,14 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap(); assert_eq!(bytes, 4); assert_eq!(&buf1[..], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 1); - let mut file = unsafe { File::from_raw_fd(fds[0]) }; + assert_eq!(files.len(), 1); + let mut file = &files[0]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); @@ -710,23 +708,23 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 3); - let mut file = unsafe { File::from_raw_fd(fds[1]) }; + assert_eq!(files.len(), 3); + let mut file = &files[1]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); assert_eq!(content, "test"); } - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should not work: // Sending side: data(header, body) with fds @@ -742,10 +740,10 @@ mod tests { let (bytes, buf4) = slave.recv_data(2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf4[..]); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should work: // Sending side: data, data with fds @@ -760,28 +758,28 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 4); assert_eq!(&buf1[..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 3); - let mut file = unsafe { File::from_raw_fd(fds[1]) }; + assert_eq!(files.len(), 3); + let mut file = &files[1]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); assert_eq!(content, "test"); } - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should not work: // Sending side: data1, data2 with fds @@ -799,9 +797,9 @@ mod tests { let (bytes, _) = slave.recv_data(5).unwrap(); assert_eq!(bytes, 5); - let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 3); - assert!(rfds.is_none()); + assert!(files.is_none()); // If the target fd array is too small, extra file descriptors will get lost. let len = master @@ -812,12 +810,9 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 4); - assert!(rfds.is_some()); - - Endpoint::<MasterReq>::close_rfds(rfds); - Endpoint::<MasterReq>::close_rfds(None); + assert!(files.is_some()); } #[test] @@ -842,15 +837,15 @@ mod tests { mem::size_of::<u64>(), ) }; - let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap(); + let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap(); assert_eq!(hdr1, hdr2); assert_eq!(bytes, 8); assert_eq!(features1, features2); - assert!(rfds.is_none()); + assert!(files.is_none()); master.send_header(&hdr1, None).unwrap(); - let (hdr2, rfds) = slave.recv_header().unwrap(); + let (hdr2, files) = slave.recv_header().unwrap(); assert_eq!(hdr1, hdr2); - assert!(rfds.is_none()); + assert!(files.is_none()); } } diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs index dc9eed5..cc9a9fb 100644 --- a/src/vhost_user/dummy_slave.rs +++ b/src/vhost_user/dummy_slave.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 use std::fs::File; -use std::os::unix::io::{AsRawFd, RawFd}; use super::message::*; use super::*; @@ -21,9 +20,9 @@ pub struct DummySlaveReqHandler { pub queue_num: usize, pub vring_num: [u32; MAX_QUEUE_NUM], pub vring_base: [u32; MAX_QUEUE_NUM], - pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM], - pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM], - pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub call_fd: [Option<File>; MAX_QUEUE_NUM], + pub kick_fd: [Option<File>; MAX_QUEUE_NUM], + pub err_fd: [Option<File>; MAX_QUEUE_NUM], pub vring_started: [bool; MAX_QUEUE_NUM], pub vring_enabled: [bool; MAX_QUEUE_NUM], pub inflight_file: Option<File>, @@ -85,7 +84,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { Ok(()) } - fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { + fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _files: Vec<File>) -> Result<()> { Ok(()) } @@ -136,14 +135,10 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { )) } - fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> { if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } - if self.kick_fd[index as usize].is_some() { - // Close file descriptor set by previous operations. - let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) }; - } self.kick_fd[index as usize] = fd; // Quotation from vhost-user spec: @@ -157,26 +152,18 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { Ok(()) } - fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> { if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } - if self.call_fd[index as usize].is_some() { - // Close file descriptor set by previous operations. - let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) }; - } self.call_fd[index as usize] = fd; Ok(()) } - fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> { if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } - if self.err_fd[index as usize].is_some() { - // Close file descriptor set by previous operations. - let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) }; - } self.err_fd[index as usize] = fd; Ok(()) } @@ -250,10 +237,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { fn get_inflight_fd( &mut self, inflight: &VhostUserInflight, - ) -> Result<(VhostUserInflight, RawFd)> { + ) -> Result<(VhostUserInflight, File)> { let file = tempfile::tempfile().unwrap(); - let fd = file.as_raw_fd(); - self.inflight_file = Some(file); + self.inflight_file = Some(file.try_clone().unwrap()); Ok(( VhostUserInflight { mmap_size: 0x1000, @@ -261,7 +247,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { num_queues: inflight.num_queues, queue_size: inflight.queue_size, }, - fd, + file, )) } @@ -273,7 +259,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { Ok(MAX_MEM_SLOTS as u64) } - fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> { + fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: File) -> Result<()> { Ok(()) } diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index 0e8b22c..7933f4d 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -5,7 +5,7 @@ use std::fs::File; use std::mem; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; use std::path::Path; use std::sync::{Arc, Mutex, MutexGuard}; @@ -14,6 +14,7 @@ use vmm_sys_util::eventfd::EventFd; use super::connection::Endpoint; use super::message::*; +use super::slave_req_handler::take_single_file; use super::{Error as VhostUserError, Result as VhostUserResult}; use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; use crate::{Error, Result}; @@ -50,7 +51,7 @@ pub trait VhostUserMaster: VhostBackend { fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>; /// Setup slave communication channel. - fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>; + fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>; /// Retrieve shared buffer for inflight I/O tracking. fn get_inflight_fd( @@ -412,7 +413,6 @@ impl VhostUserMaster for Master { let (body_reply, buf_reply, rfds) = node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?; if rfds.is_some() { - Endpoint::<MasterReq>::close_rfds(rfds); return error_code(VhostUserError::InvalidMessage); } else if body_reply.size == 0 { return error_code(VhostUserError::SlaveInternalError); @@ -445,13 +445,12 @@ impl VhostUserMaster for Master { node.wait_for_ack(&hdr).map_err(|e| e.into()) } - fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> { + fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> { let mut node = self.node(); if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return error_code(VhostUserError::InvalidOperation); } - - let fds = [fd]; + let fds = [fd.as_raw_fd()]; let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?; node.wait_for_ack(&hdr).map_err(|e| e.into()) } @@ -466,20 +465,12 @@ impl VhostUserMaster for Master { } let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?; - let (inflight, fds) = node.recv_reply_with_fds::<VhostUserInflight>(&hdr)?; + let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?; - if let Some(fds) = &fds { - if fds.len() == 1 && fds[0] >= 0 { - // Safe because we know the fd is valid. - let file = unsafe { File::from_raw_fd(fds[0]) }; - return Ok((inflight, file)); - } + match take_single_file(files) { + Some(file) => Ok((inflight, file)), + None => error_code(VhostUserError::IncorrectFds), } - - // Make sure to close the fds before returning the error. - Endpoint::<MasterReq>::close_rfds(fds); - - error_code(VhostUserError::IncorrectFds) } fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> { @@ -685,33 +676,31 @@ impl MasterInternal { let (reply, body, rfds) = self.main_sock.recv_body::<T>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(VhostUserError::InvalidMessage); } Ok(body) } - fn recv_reply_with_fds<T: Sized + Default + VhostUserMsgValidator>( + fn recv_reply_with_files<T: Sized + Default + VhostUserMsgValidator>( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, - ) -> VhostUserResult<(T, Option<Vec<RawFd>>)> { + ) -> VhostUserResult<(T, Option<Vec<File>>)> { if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { return Err(VhostUserError::InvalidParam); } self.check_state()?; - let (reply, body, rfds) = self.main_sock.recv_body::<T>()?; - if !reply.is_reply_for(&hdr) || rfds.is_none() || !body.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); + let (reply, body, files) = self.main_sock.recv_body::<T>()?; + if !reply.is_reply_for(&hdr) || files.is_none() || !body.is_valid() { return Err(VhostUserError::InvalidMessage); } - Ok((body, rfds)) + Ok((body, files)) } fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, - ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> { + ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> { if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.get_size() as usize <= mem::size_of::<T>() || hdr.get_size() as usize > MAX_MSG_SIZE @@ -722,18 +711,17 @@ impl MasterInternal { self.check_state()?; let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()]; - let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; + let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; if !reply.is_reply_for(hdr) || reply.get_size() as usize != mem::size_of::<T>() + bytes - || rfds.is_some() + || files.is_some() || !body.is_valid() + || bytes != buf.len() { - Endpoint::<MasterReq>::close_rfds(rfds); - return Err(VhostUserError::InvalidMessage); - } else if bytes != buf.len() { return Err(VhostUserError::InvalidMessage); } - Ok((body, buf, rfds)) + + Ok((body, buf, files)) } fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> { @@ -746,7 +734,6 @@ impl MasterInternal { let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(VhostUserError::InvalidMessage); } if body.value != 0 { diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs index 8cba188..0ecda4e 100644 --- a/src/vhost_user/master_req_handler.rs +++ b/src/vhost_user/master_req_handler.rs @@ -1,6 +1,7 @@ // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::fs::File; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -33,9 +34,7 @@ pub trait VhostUserMasterReqHandler { } /// Handle virtio-fs map file requests. - fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } @@ -50,14 +49,12 @@ pub trait VhostUserMasterReqHandler { } /// Handle virtio-fs file IO requests. - fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); - // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd); } /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. @@ -70,9 +67,7 @@ pub trait VhostUserMasterReqHandlerMut { } /// Handle virtio-fs map file requests. - fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } @@ -87,9 +82,7 @@ pub trait VhostUserMasterReqHandlerMut { } /// Handle virtio-fs file IO requests. - fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } @@ -102,7 +95,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { self.lock().unwrap().handle_config_change() } - fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { self.lock().unwrap().fs_slave_map(fs, fd) } @@ -114,7 +107,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { self.lock().unwrap().fs_slave_sync(fs) } - fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { self.lock().unwrap().fs_slave_io(fs, fd) } } @@ -206,8 +199,8 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { // . recv optional message body and payload according size field in // message header // . validate message body and optional payload - let (hdr, rfds) = self.sub_sock.recv_header()?; - let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (hdr, files) = self.sub_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { @@ -231,9 +224,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { } SlaveReq::FS_MAP => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; - // check_attached_rfds() has validated rfds + // check_attached_files() has validated files self.backend - .fs_slave_map(&msg, rfds.unwrap()[0]) + .fs_slave_map(&msg, &files.unwrap()[0]) .map_err(Error::ReqHandlerError) } SlaveReq::FS_UNMAP => { @@ -250,9 +243,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { } SlaveReq::FS_IO => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; - // check_attached_rfds() has validated rfds + // check_attached_files() has validated files self.backend - .fs_slave_io(&msg, rfds.unwrap()[0]) + .fs_slave_io(&msg, &files.unwrap()[0]) .map_err(Error::ReqHandlerError) } _ => Err(Error::InvalidMessage), @@ -286,34 +279,21 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { Ok(()) } - fn check_attached_rfds( + fn check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<Option<Vec<RawFd>>> { + files: &Option<Vec<File>>, + ) -> Result<()> { match hdr.get_code() { SlaveReq::FS_MAP | SlaveReq::FS_IO => { - // Expect an fd set with a single fd. - match rfds { - None => Err(Error::InvalidMessage), - Some(fds) => { - if fds.len() != 1 { - Endpoint::<SlaveReq>::close_rfds(Some(fds)); - Err(Error::InvalidMessage) - } else { - Ok(Some(fds)) - } - } - } - } - _ => { - if rfds.is_some() { - Endpoint::<SlaveReq>::close_rfds(rfds); - Err(Error::InvalidMessage) - } else { - Ok(rfds) + // Expect a single file is passed. + match files { + Some(files) if files.len() == 1 => Ok(()), + _ => Err(Error::InvalidMessage), } } + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), } } @@ -390,9 +370,11 @@ mod tests { impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { /// Handle virtio-fs map file requests from the slave. - fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_map( + &mut self, + _fs: &VhostUserFSSlaveMsg, + _fd: &dyn AsRawFd, + ) -> HandlerResult<u64> { Ok(0) } @@ -437,7 +419,7 @@ mod tests { }); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) .unwrap(); // When REPLY_ACK has not been negotiated, the master has no way to detect failure from // slave side. @@ -468,7 +450,7 @@ mod tests { fs_cache.set_reply_ack_flag(true); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) .unwrap(); fs_cache .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 52d97f7..3467c25 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -401,7 +401,7 @@ mod tests { assert_eq!(offset, 0x100); assert_eq!(reply_payload[0], 0xa5); - master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); + master.set_slave_request_fd(&eventfd).unwrap(); master.set_vring_enable(0, true).unwrap(); // unimplemented yet diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs index a9c4ed2..ee5fd9b 100644 --- a/src/vhost_user/slave_fs_cache.rs +++ b/src/vhost_user/slave_fs_cache.rs @@ -3,7 +3,7 @@ use std::io; use std::mem; -use std::os::unix::io::RawFd; +use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; use std::sync::{Arc, Mutex, MutexGuard}; @@ -55,7 +55,6 @@ impl SlaveFsCacheReqInternal { let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { - Endpoint::<SlaveReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } if body.value != 0 { @@ -129,8 +128,8 @@ impl SlaveFsCacheReq { impl VhostUserMasterReqHandler for SlaveFsCacheReq { /// Forward vhost-user-fs map file requests to the slave. - fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd])) + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd.as_raw_fd()])) } /// Forward vhost-user-fs unmap file requests to the master. @@ -158,31 +157,21 @@ mod tests { #[test] fn test_slave_fs_cache_send_failure() { let (p1, p2) = UnixStream::pair().unwrap(); - let fd = p2.as_raw_fd(); let fs_cache = SlaveFsCacheReq::from_stream(p1); fs_cache.set_failed(libc::ECONNRESET); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &p2) .unwrap_err(); fs_cache .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) .unwrap_err(); fs_cache.node().error = None; - - drop(p2); - fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) - .unwrap_err(); - fs_cache - .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) - .unwrap_err(); } #[test] fn test_slave_fs_cache_recv_negative() { let (p1, p2) = UnixStream::pair().unwrap(); - let fd = p2.as_raw_fd(); let fs_cache = SlaveFsCacheReq::from_stream(p1); let mut master = Endpoint::<SlaveReq>::from_stream(p2); @@ -194,33 +183,35 @@ mod tests { ); let body = VhostUserU64::new(0); - master.send_message(&hdr, &body, Some(&[fd])).unwrap(); + master + .send_message(&hdr, &body, Some(&[master.as_raw_fd()])) + .unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap(); fs_cache.set_reply_ack_flag(true); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap_err(); hdr.set_code(SlaveReq::FS_UNMAP); master.send_message(&hdr, &body, None).unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap_err(); hdr.set_code(SlaveReq::FS_MAP); let body = VhostUserU64::new(1); master.send_message(&hdr, &body, None).unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap_err(); let body = VhostUserU64::new(0); master.send_message(&hdr, &body, None).unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap(); } } diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 710b5f5..7c3de7d 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -3,7 +3,7 @@ use std::fs::File; use std::mem; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::os::unix::net::UnixStream; use std::slice; use std::sync::{Arc, Mutex}; @@ -39,7 +39,7 @@ pub trait VhostUserSlaveReqHandler { fn reset_owner(&self) -> Result<()>; fn get_features(&self) -> Result<u64>; fn set_features(&self, features: u64) -> Result<()>; - fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; fn set_vring_addr( &self, @@ -52,9 +52,9 @@ pub trait VhostUserSlaveReqHandler { ) -> Result<()>; fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; - fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>; fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; fn set_protocol_features(&self, features: u64) -> Result<()>; @@ -63,10 +63,10 @@ pub trait VhostUserSlaveReqHandler { fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} - fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, RawFd)>; + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&self) -> Result<u64>; - fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; } @@ -79,7 +79,7 @@ pub trait VhostUserSlaveReqHandlerMut { fn reset_owner(&mut self) -> Result<()>; fn get_features(&mut self) -> Result<u64>; fn set_features(&mut self, features: u64) -> Result<()>; - fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; fn set_vring_addr( &mut self, @@ -92,9 +92,9 @@ pub trait VhostUserSlaveReqHandlerMut { ) -> Result<()>; fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; - fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; fn set_protocol_features(&mut self, features: u64) -> Result<()>; @@ -111,10 +111,10 @@ pub trait VhostUserSlaveReqHandlerMut { fn get_inflight_fd( &mut self, inflight: &VhostUserInflight, - ) -> Result<(VhostUserInflight, RawFd)>; + ) -> Result<(VhostUserInflight, File)>; fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&mut self) -> Result<u64>; - fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; } @@ -135,8 +135,8 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().set_features(features) } - fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { - self.lock().unwrap().set_mem_table(ctx, fds) + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, files) } fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { @@ -165,15 +165,15 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().get_vring_base(index) } - fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_kick(index, fd) } - fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_call(index, fd) } - fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_err(index, fd) } @@ -205,7 +205,7 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().set_slave_req_fd(vu_req) } - fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, RawFd)> { + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { self.lock().unwrap().get_inflight_fd(inflight) } @@ -217,7 +217,7 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().get_max_mem_slots() } - fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> { + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { self.lock().unwrap().add_mem_region(region, fd) } @@ -307,8 +307,9 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { // . recv optional message body and payload according size field in // message header // . validate message body and optional payload - let (hdr, rfds) = self.main_sock.recv_header()?; - let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (hdr, files) = self.main_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; + let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { @@ -347,7 +348,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { self.send_ack_message(&hdr, res)?; } MasterReq::SET_MEM_TABLE => { - let res = self.set_mem_table(&hdr, size, &buf, rfds); + let res = self.set_mem_table(&hdr, size, &buf, files); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_NUM => { @@ -383,20 +384,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_VRING_CALL => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_call(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_call(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_KICK => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_kick(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_kick(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ERR => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_err(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_err(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::GET_PROTOCOL_FEATURES => { @@ -459,7 +460,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, hdr.get_size() as usize)?; - let res = self.set_slave_req_fd(rfds); + let res = self.set_slave_req_fd(files); self.send_ack_message(&hdr, res)?; } MasterReq::GET_INFLIGHT_FD => { @@ -470,10 +471,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; - let (inflight, fd) = self.backend.get_inflight_fd(&msg)?; + let (inflight, file) = self.backend.get_inflight_fd(&msg)?; let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; self.main_sock - .send_message(&reply_hdr, &inflight, Some(&[fd]))?; + .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; } MasterReq::SET_INFLIGHT_FD => { if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() @@ -481,18 +482,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { { return Err(Error::InvalidOperation); } - let file = if let Some(fds) = rfds { - if fds.len() != 1 || fds[0] < 0 { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::IncorrectFds); - } - - // Safe because we know the fd is valid. - unsafe { File::from_raw_fd(fds[0]) } - } else { - return Err(Error::IncorrectFds); - }; - + let file = take_single_file(files).ok_or(Error::IncorrectFds)?; let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; let res = self.backend.set_inflight_fd(&msg, file); self.send_ack_message(&hdr, res)?; @@ -516,18 +506,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { { return Err(Error::InvalidOperation); } - let fd = if let Some(fds) = &rfds { - if fds.len() != 1 { - return Err(Error::InvalidParam); - } - fds[0] - } else { + let mut files = files.ok_or(Error::InvalidParam)?; + if files.len() != 1 { return Err(Error::InvalidParam); - }; - + } let msg = self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; - let res = self.backend.add_mem_region(&msg, fd); + let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); self.send_ack_message(&hdr, res)?; } MasterReq::REM_MEM_REG => { @@ -555,37 +540,28 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], - rfds: Option<Vec<RawFd>>, + files: Option<Vec<File>>, ) -> Result<()> { self.check_request_size(&hdr, size, hdr.get_size() as usize)?; // check message size is consistent let hdrsize = mem::size_of::<VhostUserMemory>(); if size < hdrsize { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; if !msg.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } // validate number of fds matching number of memory regions - let fds = match rfds { - None => return Err(Error::InvalidMessage), - Some(fds) => { - if fds.len() != msg.num_regions as usize { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::InvalidMessage); - } - fds - } - }; + let files = files.ok_or(Error::InvalidMessage)?; + if files.len() != msg.num_regions as usize { + return Err(Error::InvalidMessage); + } // Validate memory regions let regions = unsafe { @@ -596,12 +572,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { }; for region in regions.iter() { if !region.is_valid() { - Endpoint::<MasterReq>::close_rfds(Some(fds)); return Err(Error::InvalidMessage); } } - self.backend.set_mem_table(®ions, &fds) + self.backend.set_mem_table(®ions, files) } fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { @@ -662,26 +637,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { self.backend.set_config(msg.offset, buf, flags) } - fn set_slave_req_fd(&mut self, rfds: Option<Vec<RawFd>>) -> Result<()> { - if let Some(fds) = rfds { - if fds.len() == 1 { - let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; - let vu_req = SlaveFsCacheReq::from_stream(sock); - self.backend.set_slave_req_fd(vu_req); - Ok(()) - } else { - Err(Error::InvalidMessage) - } - } else { - Err(Error::InvalidMessage) - } + fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> { + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) }; + let vu_req = SlaveFsCacheReq::from_stream(sock); + self.backend.set_slave_req_fd(vu_req); + Ok(()) } fn handle_vring_fd_request( &mut self, buf: &[u8], - rfds: Option<Vec<RawFd>>, - ) -> Result<(u8, Option<RawFd>)> { + files: Option<Vec<File>>, + ) -> Result<(u8, Option<File>)> { if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { return Err(Error::InvalidMessage); } @@ -691,28 +659,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } // Bits (0-7) of the payload contain the vring index. Bit 8 is the - // invalid FD flag. This flag is set when there is no file descriptor + // invalid FD flag. This bit is set when there is no file descriptor // in the ancillary data. This signals that polling will be used // instead of waiting for the call. - let nofd = (msg.value & 0x100u64) == 0x100u64; - - let mut rfd = None; - match rfds { - Some(fds) => { - if !nofd && fds.len() == 1 { - rfd = Some(fds[0]); - } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::InvalidMessage); - } - } - None => { - if !nofd { - return Err(Error::InvalidMessage); - } - } + // If Bit 8 is unset, the data must contain a file descriptor. + let has_fd = (msg.value & 0x100u64) == 0; + + let file = take_single_file(files); + + if has_fd && file.is_none() || !has_fd && file.is_some() { + return Err(Error::InvalidMessage); } - Ok((msg.value as u8, rfd)) + + Ok((msg.value as u8, file)) } fn check_state(&self) -> Result<()> { @@ -738,29 +697,23 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Ok(()) } - fn check_attached_rfds( + fn check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<Option<Vec<RawFd>>> { + files: &Option<Vec<File>>, + ) -> Result<()> { match hdr.get_code() { - MasterReq::SET_MEM_TABLE => Ok(rfds), - MasterReq::SET_VRING_CALL => Ok(rfds), - MasterReq::SET_VRING_KICK => Ok(rfds), - MasterReq::SET_VRING_ERR => Ok(rfds), - MasterReq::SET_LOG_BASE => Ok(rfds), - MasterReq::SET_LOG_FD => Ok(rfds), - MasterReq::SET_SLAVE_REQ_FD => Ok(rfds), - MasterReq::SET_INFLIGHT_FD => Ok(rfds), - MasterReq::ADD_MEM_REG => Ok(rfds), - _ => { - if rfds.is_some() { - Endpoint::<MasterReq>::close_rfds(rfds); - Err(Error::InvalidMessage) - } else { - Ok(rfds) - } - } + MasterReq::SET_MEM_TABLE + | MasterReq::SET_VRING_CALL + | MasterReq::SET_VRING_KICK + | MasterReq::SET_VRING_ERR + | MasterReq::SET_LOG_BASE + | MasterReq::SET_LOG_FD + | MasterReq::SET_SLAVE_REQ_FD + | MasterReq::SET_INFLIGHT_FD + | MasterReq::ADD_MEM_REG => Ok(()), + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), } } @@ -850,6 +803,16 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } } +/// Utility function to take the first element from option of a vector of files. +/// Returns `None` if the vector contains no file or more than one file. +pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> { + let mut files = files?; + if files.len() != 1 { + return None; + } + Some(files.swap_remove(0)) +} + impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> { fn as_raw_fd(&self) -> RawFd { self.main_sock.as_raw_fd() |