rusty_common/websocket/
zerocopy_frame.rs

1//! Zero-copy WebSocket frame parsing for ultra-fast binary protocol handling
2//!
3//! This module provides zerocopy-based WebSocket frame parsing that eliminates
4//! allocations when processing frame headers and enables direct access to payloads.
5
6use zerocopy::FromBytes;
7
8/// WebSocket frame header (2 bytes minimum)
9///
10/// ```text
11/// 0                   1
12/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
13/// +-+-+-+-+-------+-+-------------+
14/// |F|R|R|R| opcode|M| Payload len |
15/// |I|S|S|S|  (4)  |A|     (7)     |
16/// |N|V|V|V|       |S|             |
17/// | |1|2|3|       |K|             |
18/// +-+-+-+-+-------+-+-------------+
19/// ```
20#[derive(
21    Debug,
22    Clone,
23    Copy,
24    zerocopy_derive::FromBytes,
25    zerocopy_derive::Immutable,
26    zerocopy_derive::KnownLayout,
27)]
28#[repr(C, packed)]
29pub struct ZeroCopyFrameHeader {
30    /// First byte: FIN (1 bit) + RSV (3 bits) + Opcode (4 bits)
31    pub first_byte: u8,
32    /// Second byte: MASK (1 bit) + Payload length (7 bits)
33    pub second_byte: u8,
34}
35
36impl ZeroCopyFrameHeader {
37    /// Check if this is the final fragment
38    #[inline]
39    pub const fn is_final(&self) -> bool {
40        (self.first_byte & 0x80) != 0
41    }
42
43    /// Get the opcode
44    #[inline]
45    pub const fn opcode(&self) -> u8 {
46        self.first_byte & 0x0F
47    }
48
49    /// Check if the payload is masked
50    #[inline]
51    pub const fn is_masked(&self) -> bool {
52        (self.second_byte & 0x80) != 0
53    }
54
55    /// Get the initial payload length (0-125, 126, or 127)
56    #[inline]
57    pub const fn payload_len_indicator(&self) -> u8 {
58        self.second_byte & 0x7F
59    }
60
61    /// Calculate total header size based on payload length indicator
62    #[inline]
63    pub const fn header_size(&self) -> usize {
64        let base_size = 2; // First two bytes
65        let mask_size = if self.is_masked() { 4 } else { 0 };
66
67        match self.payload_len_indicator() {
68            126 => base_size + 2 + mask_size, // 16-bit extended length
69            127 => base_size + 8 + mask_size, // 64-bit extended length
70            _ => base_size + mask_size,       // 7-bit length
71        }
72    }
73}
74
75/// Extended payload length for medium-sized messages (126 indicator)
76#[derive(
77    Debug,
78    Clone,
79    Copy,
80    zerocopy_derive::FromBytes,
81    zerocopy_derive::Immutable,
82    zerocopy_derive::KnownLayout,
83)]
84#[repr(C)]
85pub struct ExtendedLength16 {
86    /// The length of the payload.
87    pub length: [u8; 2], // Big-endian u16
88}
89
90impl ExtendedLength16 {
91    /// Returns the length as a u64.
92    #[inline]
93    pub const fn as_u64(&self) -> u64 {
94        u16::from_be_bytes(self.length) as u64
95    }
96}
97
98/// Extended payload length for large messages (127 indicator)
99#[derive(
100    Debug,
101    Clone,
102    Copy,
103    zerocopy_derive::FromBytes,
104    zerocopy_derive::Immutable,
105    zerocopy_derive::KnownLayout,
106)]
107#[repr(C)]
108pub struct ExtendedLength64 {
109    /// The length of the payload.
110    pub length: [u8; 8], // Big-endian u64
111}
112
113impl ExtendedLength64 {
114    /// Returns the length as a u64.
115    #[inline]
116    pub const fn as_u64(&self) -> u64 {
117        u64::from_be_bytes(self.length)
118    }
119}
120
121/// Masking key for client-to-server messages
122#[derive(
123    Debug,
124    Clone,
125    Copy,
126    zerocopy_derive::FromBytes,
127    zerocopy_derive::Immutable,
128    zerocopy_derive::KnownLayout,
129)]
130#[repr(C)]
131pub struct MaskingKey {
132    /// The masking key.
133    pub key: [u8; 4],
134}
135
136/// Zero-copy WebSocket frame parser
137pub struct ZeroCopyFrameParser;
138
139impl ZeroCopyFrameParser {
140    /// Parse a WebSocket frame header with zero allocations
141    ///
142    /// Returns (header, payload_length, mask_key, payload_offset)
143    pub fn parse_header(
144        buffer: &[u8],
145    ) -> Option<(ZeroCopyFrameHeader, u64, Option<MaskingKey>, usize)> {
146        if buffer.len() < 2 {
147            return None;
148        }
149
150        // Parse the basic header
151        let header = ZeroCopyFrameHeader::ref_from_bytes(&buffer[..2]).ok()?;
152        let mut offset = 2;
153
154        // Parse extended payload length if needed
155        let payload_len = match header.payload_len_indicator() {
156            126 => {
157                if buffer.len() < offset + 2 {
158                    return None;
159                }
160                let ext_len = ExtendedLength16::ref_from_bytes(&buffer[offset..offset + 2]).ok()?;
161                offset += 2;
162                ext_len.as_u64()
163            }
164            127 => {
165                if buffer.len() < offset + 8 {
166                    return None;
167                }
168                let ext_len = ExtendedLength64::ref_from_bytes(&buffer[offset..offset + 8]).ok()?;
169                offset += 8;
170                ext_len.as_u64()
171            }
172            len => len as u64,
173        };
174
175        // Parse masking key if present
176        let mask_key = if header.is_masked() {
177            if buffer.len() < offset + 4 {
178                return None;
179            }
180            let key = MaskingKey::ref_from_bytes(&buffer[offset..offset + 4]).ok()?;
181            offset += 4;
182            Some(*key)
183        } else {
184            None
185        };
186
187        Some((*header, payload_len, mask_key, offset))
188    }
189
190    /// Get a reference to the payload data without copying
191    #[inline]
192    pub fn get_payload(buffer: &[u8], payload_offset: usize, payload_len: u64) -> Option<&[u8]> {
193        let end = payload_offset + payload_len as usize;
194        if buffer.len() >= end {
195            Some(&buffer[payload_offset..end])
196        } else {
197            None
198        }
199    }
200
201    /// Unmask payload data in-place (for client-to-server messages)
202    pub fn unmask_payload(payload: &mut [u8], mask_key: &MaskingKey) {
203        for (i, byte) in payload.iter_mut().enumerate() {
204            *byte ^= mask_key.key[i % 4];
205        }
206    }
207}
208
/// WebSocket opcodes.
///
/// Each discriminant is the on-the-wire 4-bit opcode value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum OpCode {
    /// Continuation frame.
    Continuation = 0x0,
    /// Text frame.
    Text = 0x1,
    /// Binary frame.
    Binary = 0x2,
    /// Close frame.
    Close = 0x8,
    /// Ping frame.
    Ping = 0x9,
    /// Pong frame.
    Pong = 0xA,
}

impl OpCode {
    /// Creates an `OpCode` from a raw opcode value.
    ///
    /// Returns `None` for any value without a matching variant. This includes
    /// the reserved ranges `0x3..=0x7` and `0xB..=0xF`.
    #[inline]
    pub const fn from_u8(value: u8) -> Option<Self> {
        match value {
            0x0 => Some(OpCode::Continuation),
            0x1 => Some(OpCode::Text),
            0x2 => Some(OpCode::Binary),
            0x8 => Some(OpCode::Close),
            0x9 => Some(OpCode::Ping),
            0xA => Some(OpCode::Pong),
            _ => None,
        }
    }

    /// Returns `true` if the opcode is a control frame (Close, Ping, Pong).
    #[inline]
    pub const fn is_control(&self) -> bool {
        matches!(self, OpCode::Close | OpCode::Ping | OpCode::Pong)
    }

    /// Returns `true` if the opcode is a data frame (Text, Binary, Continuation).
    #[inline]
    pub const fn is_data(&self) -> bool {
        matches!(self, OpCode::Text | OpCode::Binary | OpCode::Continuation)
    }
}

/// Fallible conversion from a raw opcode value.
///
/// This is the standard-trait equivalent of [`OpCode::from_u8`]. On failure
/// it returns the rejected byte.
impl std::convert::TryFrom<u8> for OpCode {
    type Error = u8;

    #[inline]
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Self::from_u8(value).ok_or(value)
    }
}

/// Converts an opcode back into its on-the-wire value.
impl From<OpCode> for u8 {
    #[inline]
    fn from(opcode: OpCode) -> Self {
        opcode as u8
    }
}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_frame_header() {
        // Text frame, FIN=1, no mask, 5 bytes payload
        let data: &[u8] = &[0x81, 0x05, b'h', b'e', b'l', b'l', b'o'];

        let (header, payload_len, mask_key, offset) =
            ZeroCopyFrameParser::parse_header(data).unwrap();

        assert!(header.is_final());
        assert_eq!(header.opcode(), 0x1); // Text
        assert!(!header.is_masked());
        assert_eq!(payload_len, 5);
        assert!(mask_key.is_none());
        assert_eq!(offset, 2);
        assert_eq!(header.header_size(), offset);

        let payload = ZeroCopyFrameParser::get_payload(data, offset, payload_len).unwrap();
        assert_eq!(payload, b"hello");
    }

    #[test]
    fn test_non_final_fragment() {
        // Text frame, FIN=0 (more fragments follow), 3 bytes payload
        let data: &[u8] = &[0x01, 0x03, b'a', b'b', b'c'];

        let (header, payload_len, _, offset) = ZeroCopyFrameParser::parse_header(data).unwrap();

        assert!(!header.is_final());
        assert_eq!(header.opcode(), 0x1);
        assert_eq!(payload_len, 3);
        assert_eq!(
            ZeroCopyFrameParser::get_payload(data, offset, payload_len),
            Some(&b"abc"[..])
        );
    }

    #[test]
    fn test_extended_16bit_length() {
        // Binary frame with 16-bit extended length (300 bytes)
        let mut data = vec![0x82, 126, 0x01, 0x2C]; // 300 = 0x012C
        data.resize(4 + 300, 0);

        let (header, payload_len, _, offset) = ZeroCopyFrameParser::parse_header(&data).unwrap();

        assert_eq!(header.opcode(), 0x2); // Binary
        assert_eq!(payload_len, 300);
        assert_eq!(offset, 4);
        assert_eq!(header.header_size(), offset);
        assert_eq!(
            ZeroCopyFrameParser::get_payload(&data, offset, payload_len).map(<[u8]>::len),
            Some(300)
        );
    }

    #[test]
    fn test_extended_64bit_length() {
        // Binary frame with 64-bit extended length (65 536 = 0x0001_0000); payload omitted
        let data: &[u8] = &[0x82, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00];

        let (header, payload_len, mask_key, offset) =
            ZeroCopyFrameParser::parse_header(data).unwrap();

        assert_eq!(header.opcode(), 0x2); // Binary
        assert_eq!(payload_len, 65_536);
        assert!(mask_key.is_none());
        assert_eq!(offset, 10);
        assert_eq!(header.header_size(), offset);
        // The header is complete, but the payload has not arrived yet.
        assert!(ZeroCopyFrameParser::get_payload(data, offset, payload_len).is_none());
    }

    #[test]
    fn test_incomplete_header_returns_none() {
        let truncated: [&[u8]; 6] = [
            &[],
            &[0x81],
            &[0x82, 126, 0x01],              // 16-bit length cut short
            &[0x82, 127, 0x00, 0x00, 0x00],  // 64-bit length cut short
            &[0x81, 0x85, 0x12, 0x34],       // masking key cut short
            &[0x81, 0xFE, 0x01, 0x2C, 0x12], // masked + 16-bit length, key cut short
        ];

        for buffer in &truncated {
            assert!(
                ZeroCopyFrameParser::parse_header(buffer).is_none(),
                "expected None for {:?}",
                buffer
            );
        }
    }

    #[test]
    fn test_masked_frame() {
        // Masked text frame
        // To get masked "hello", we XOR each byte with the mask key:
        // 'h' (0x68) ^ 0x12 = 0x7A
        // 'e' (0x65) ^ 0x34 = 0x51
        // 'l' (0x6C) ^ 0x56 = 0x3A
        // 'l' (0x6C) ^ 0x78 = 0x14
        // 'o' (0x6F) ^ 0x12 = 0x7D (mask wraps around, so uses first byte again)
        let data: &[u8] = &[
            0x81, 0x85, // FIN=1, Text, MASK=1, len=5
            0x12, 0x34, 0x56, 0x78, // Mask key
            0x7A, 0x51, 0x3A, 0x14, 0x7D, // Masked "hello"
        ];

        let (header, payload_len, mask_key, offset) =
            ZeroCopyFrameParser::parse_header(data).unwrap();

        assert!(header.is_masked());
        assert_eq!(payload_len, 5);
        assert!(mask_key.is_some());
        assert_eq!(offset, 6);
        assert_eq!(header.header_size(), offset);

        // Test unmasking
        let mut payload = data[offset..offset + 5].to_vec();
        ZeroCopyFrameParser::unmask_payload(&mut payload, &mask_key.unwrap());
        assert_eq!(&payload, b"hello");
    }

    #[test]
    fn test_control_frame() {
        // Ping frame
        let data: &[u8] = &[0x89, 0x00]; // FIN=1, Ping, no payload

        let (header, payload_len, _, _) = ZeroCopyFrameParser::parse_header(data).unwrap();

        assert_eq!(header.opcode(), 0x9); // Ping
        assert_eq!(payload_len, 0);

        let opcode = OpCode::from_u8(header.opcode()).unwrap();
        assert!(opcode.is_control());
        assert!(!opcode.is_data());

        // Reserved opcodes have no variant.
        assert_eq!(OpCode::from_u8(0x3), None);
    }
}