aboutsummaryrefslogtreecommitdiff
path: root/src/txbuf.rs
blob: ef718d7113676d6a5889bc261a29df9a1d406ec5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause

use std::{io::Write, num::Wrapping};

use vm_memory::{bitmap::BitmapSlice, VolatileSlice};

use crate::vhu_vsock::{Error, Result};

/// Fixed-capacity ring buffer for data waiting to be written to a host-side
/// stream (see `flush_to`).
///
/// `head` and `tail` are free-running `u32` counters rather than indices: the
/// number of buffered bytes is `tail - head` (wrapping), and a counter is turned
/// into a position in `buf` by taking it modulo `buf.len()`.
///
/// NOTE(review): because the counters wrap at 2^32, that modulo mapping is only
/// continuous across the wrap when `buf.len()` divides 2^32 (i.e. is a power of
/// two); with other sizes data would be misplaced once the counters wrap (after
/// about 4 GiB has passed through). Confirm callers only use power-of-two sizes.
#[derive(Debug)]
pub(crate) struct LocalTxBuf {
    /// Buffer holding data to be forwarded to a host-side application
    buf: Vec<u8>,
    /// Index into buffer from which data can be consumed from the buffer
    /// (free-running counter; reduce modulo `buf.len()` for the position in `buf`)
    head: Wrapping<u32>,
    /// Index into buffer from which data can be added to the buffer
    /// (free-running counter; reduce modulo `buf.len()` for the position in `buf`)
    tail: Wrapping<u32>,
}

impl LocalTxBuf {
    /// Create a new instance of LocalTxBuf.
    pub fn new(buf_size: u32) -> Self {
        Self {
            buf: vec![0; buf_size as usize],
            head: Wrapping(0),
            tail: Wrapping(0),
        }
    }

    /// Get the buffer size
    pub fn get_buf_size(&self) -> u32 {
        self.buf.len() as u32
    }

    /// Check if the buf is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Add new data to the tx buffer, push all or none.
    /// Returns LocalTxBufFull error if space not sufficient.
    pub fn push<B: BitmapSlice>(&mut self, data_buf: &VolatileSlice<B>) -> Result<()> {
        if self.get_buf_size() as usize - self.len() < data_buf.len() {
            // Tx buffer is full
            return Err(Error::LocalTxBufFull);
        }

        // Get index into buffer at which data can be inserted
        let tail_idx = self.tail.0 as usize % self.get_buf_size() as usize;

        // Check if we can fit the data buffer between head and end of buffer
        let len = std::cmp::min(self.get_buf_size() as usize - tail_idx, data_buf.len());
        let txbuf = &mut self.buf[tail_idx..tail_idx + len];
        data_buf.copy_to(txbuf);

        // Check if there is more data to be wrapped around
        if len < data_buf.len() {
            let remain_txbuf = &mut self.buf[..(data_buf.len() - len)];
            data_buf.copy_to(remain_txbuf);
        }

        // Increment tail by the amount of data that has been added to the buffer
        self.tail += Wrapping(data_buf.len() as u32);

        Ok(())
    }

    /// Flush buf data to stream.
    pub fn flush_to<S: Write>(&mut self, stream: &mut S) -> Result<usize> {
        if self.is_empty() {
            // No data to be flushed
            return Ok(0);
        }

        // Get index into buffer from which data can be read
        let head_idx = self.head.0 as usize % self.get_buf_size() as usize;

        // First write from head to end of buffer
        let len = std::cmp::min(self.get_buf_size() as usize - head_idx, self.len());
        let written = stream
            .write(&self.buf[head_idx..(head_idx + len)])
            .map_err(Error::LocalTxBufFlush)?;

        // Increment head  by amount of data that has been flushed to the stream
        self.head += Wrapping(written as u32);

        // If written length is less than the expected length we can try again in the future
        if written < len {
            return Ok(written);
        }

        // The head index has wrapped around the end of the buffer, we call self again
        Ok(written + self.flush_to(stream).unwrap_or(0))
    }

    /// Return amount of data in the buffer.
    fn len(&self) -> usize {
        (self.tail - self.head).0 as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const CONN_TX_BUF_SIZE: u32 = 64 * 1024;

    #[test]
    fn test_txbuf_len() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);

        // Zero length tx buf
        assert_eq!(loc_tx_buf.len(), 0);

        // finite length tx buf
        loc_tx_buf.head = Wrapping(0);
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE);
        assert_eq!(loc_tx_buf.len(), CONN_TX_BUF_SIZE as usize);

        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE / 2);
        assert_eq!(loc_tx_buf.len(), (CONN_TX_BUF_SIZE / 2) as usize);

        loc_tx_buf.head = Wrapping(256);
        assert_eq!(loc_tx_buf.len(), 32512);
    }

    #[test]
    fn test_txbuf_is_empty() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);

        // empty tx buffer
        assert!(loc_tx_buf.is_empty());

        // non empty tx buffer
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE);
        assert!(!loc_tx_buf.is_empty());
    }

    #[test]
    fn test_txbuf_push() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);
        let mut buf = [0; CONN_TX_BUF_SIZE as usize];
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        // push data into empty tx buffer
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        assert_eq!(loc_tx_buf.head, Wrapping(0));
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE));

        // push data into full tx buffer
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_err());

        // head and tail wrap at full
        loc_tx_buf.head = Wrapping(CONN_TX_BUF_SIZE);
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE * 2));

        // only tail wraps at full
        let mut buf = vec![1; 4];
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };
        let mut cmp_data = vec![1; 4];
        cmp_data.append(&mut vec![0; (CONN_TX_BUF_SIZE - 4) as usize]);
        loc_tx_buf.head = Wrapping(4);
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE);
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        assert_eq!(loc_tx_buf.head, Wrapping(4));
        assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE + 4));
        assert_eq!(loc_tx_buf.buf, cmp_data);
    }

    #[test]
    fn test_txbuf_flush_to() {
        let mut loc_tx_buf = LocalTxBuf::new(CONN_TX_BUF_SIZE);

        // data to be flushed
        let mut buf = vec![1; CONN_TX_BUF_SIZE as usize];
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        // target to which data is flushed
        let mut cmp_vec = Vec::with_capacity(data.len());

        // flush no data
        let res_flush = loc_tx_buf.flush_to(&mut cmp_vec);
        assert!(res_flush.is_ok());
        assert_eq!(res_flush.unwrap(), 0);

        // flush data of CONN_TX_BUF_SIZE amount
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        let res_flush = loc_tx_buf.flush_to(&mut cmp_vec);
        if let Ok(n) = res_flush {
            assert_eq!(loc_tx_buf.head, Wrapping(n as u32));
            assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE));
            assert_eq!(n, cmp_vec.len());
            assert_eq!(cmp_vec, buf[..n]);
        }

        // wrapping head flush
        let mut buf = vec![0; (CONN_TX_BUF_SIZE / 2) as usize];
        buf.append(&mut vec![1; (CONN_TX_BUF_SIZE / 2) as usize]);
        // SAFETY: Safe as the buffer is guaranteed to be valid here.
        let data = unsafe { VolatileSlice::new(buf.as_mut_ptr(), buf.len()) };

        loc_tx_buf.head = Wrapping(0);
        loc_tx_buf.tail = Wrapping(0);
        let res_push = loc_tx_buf.push(&data);
        assert!(res_push.is_ok());
        cmp_vec.clear();
        loc_tx_buf.head = Wrapping(CONN_TX_BUF_SIZE / 2);
        loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2));
        let res_flush = loc_tx_buf.flush_to(&mut cmp_vec);
        if let Ok(n) = res_flush {
            assert_eq!(
                loc_tx_buf.head,
                Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2))
            );
            assert_eq!(
                loc_tx_buf.tail,
                Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2))
            );
            assert_eq!(n, cmp_vec.len());
            let mut data = vec![1; (CONN_TX_BUF_SIZE / 2) as usize];
            data.append(&mut vec![0; (CONN_TX_BUF_SIZE / 2) as usize]);
            assert_eq!(cmp_vec, data[..n]);
        }
    }
}