diff --git a/ylong_http/src/h2/encoder.rs b/ylong_http/src/h2/encoder.rs index 4380a31ded3ac6e39dd07fb136dab7df206cc8dd..ee5c14dc205deffdaf278068e2434dca2d2f2d9d 100644 --- a/ylong_http/src/h2/encoder.rs +++ b/ylong_http/src/h2/encoder.rs @@ -11,6 +11,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cmp::min; + use crate::h2::frame::{FrameFlags, FrameType, Payload, Priority, Setting}; use crate::h2::{Frame, Goaway, HpackEncoder, Settings}; @@ -93,6 +95,7 @@ pub struct FrameEncoder { hpack_encoder: HpackEncoder, state: FrameEncoderState, encoded_bytes: usize, + data_offset: usize, // `remaining_header_payload` will always be smaller than the minimum max_frame_size, // because the `header_payload_buffer` length is the minimum max_frame_size. remaining_header_payload: usize, @@ -114,6 +117,7 @@ impl FrameEncoder { hpack_encoder: HpackEncoder::with_max_size(max_header_list_size), state: FrameEncoderState::Idle, encoded_bytes: 0, + data_offset: 0, remaining_header_payload: 0, remaining_payload_bytes: 0, is_end_stream: false, @@ -151,7 +155,13 @@ impl FrameEncoder { Payload::Priority(_) => self.state = FrameEncoderState::EncodingPriorityFrame, Payload::RstStream(_) => self.state = FrameEncoderState::EncodingRstStreamFrame, Payload::Ping(_) => self.state = FrameEncoderState::EncodingPingFrame, - Payload::Data(_) => self.state = FrameEncoderState::EncodingDataHeader, + Payload::Data(data) => { + self.state = { + self.data_offset = 0; + self.remaining_payload_bytes = data.size(); + FrameEncoderState::EncodingDataHeader + } + } Payload::Settings(_) => self.state = FrameEncoderState::EncodingSettingsFrame, Payload::Goaway(_) => self.state = FrameEncoderState::EncodingGoawayFrame, Payload::WindowUpdate(_) => { @@ -387,6 +397,8 @@ impl FrameEncoder { } fn headers_header_status(&mut self) { + // Headers encoding does not need to consider max_frame_size + // because header_payload_buffer must be smaller than max_frame_size. self.state = if self.header_payload_index < self.remaining_header_payload { FrameEncoderState::EncodingHeadersPayload } else if self.hpack_encoder.is_finished() { @@ -479,7 +491,10 @@ impl FrameEncoder { } buf[3] = FrameType::Continuation as u8; let mut new_flags = FrameFlags::empty(); - if self.remaining_header_payload <= self.max_frame_size && self.hpack_encoder.is_finished() && self.is_end_headers { + if self.remaining_header_payload <= self.max_frame_size + && self.hpack_encoder.is_finished() + && self.is_end_headers + { // Sets the END_HEADER flag on the last CONTINUATION frame. new_flags.set_end_headers(true); } @@ -510,7 +525,7 @@ impl FrameEncoder { fn encode_data_header(&mut self, buf: &mut [u8]) -> Result { if let Some(frame) = &self.current_frame { - if let Payload::Data(data_frame) = frame.payload() { + if let Payload::Data(_) = frame.payload() { // HTTP/2 frame header size is 9 bytes. let frame_header_size = 9; let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { @@ -520,12 +535,18 @@ impl FrameEncoder { }; let bytes_to_write = remaining_header_bytes.min(buf.len()); - self.iterate_data_header(frame, buf, data_frame.data().len(), bytes_to_write)?; + self.iterate_data_header( + frame, + buf, + min(self.remaining_payload_bytes, self.max_frame_size), + bytes_to_write, + )?; self.encoded_bytes += bytes_to_write; if self.encoded_bytes == frame_header_size { + // data frame header is finished, reset encoded_bytes. + self.encoded_bytes = 0; self.state = FrameEncoderState::EncodingDataPayload; - self.remaining_payload_bytes = data_frame.data().len(); } Ok(bytes_to_write) } else { @@ -556,7 +577,12 @@ impl FrameEncoder { } // The 5th byte represents the frame flags in the frame header. 4 => { - *item = frame.flags().bits(); + if self.remaining_payload_bytes <= self.max_frame_size { + *item = frame.flags().bits(); + } else { + // When max_frame_size is exceeded, a data frame cannot send all data. + *item = frame.flags().bits() & 0xFE; + } } // The last 4 bytes (6th to 9th) represent the stream identifier in the // frame header. @@ -575,24 +601,22 @@ impl FrameEncoder { fn encode_data_payload(&mut self, buf: &mut [u8]) -> Result { if let Some(frame) = self.current_frame.as_ref() { if let Payload::Data(data_frame) = frame.payload() { - // HTTP/2 frame header size is 9 bytes. - let frame_header_size = 9; - let encoded_payload_bytes = self.encoded_bytes - frame_header_size; let payload = data_frame.data(); - let bytes_to_write = self.encode_slice(buf, payload, encoded_payload_bytes); - self.encoded_bytes += bytes_to_write; - - if self.remaining_payload_bytes == 0 { - self.state = if self.is_end_stream { - FrameEncoderState::DataComplete - } else { - FrameEncoderState::EncodingDataPayload - }; - } else if self.remaining_payload_bytes > self.max_frame_size { - self.state = FrameEncoderState::EncodingDataPayload; - } + let writen_bytes = self.encode_slice(buf, payload); + self.data_offset += writen_bytes; + self.remaining_payload_bytes -= writen_bytes; + + self.state = if self.remaining_payload_bytes == 0 { + self.data_offset = 0; + FrameEncoderState::DataComplete + } else if self.data_offset == self.max_frame_size { + self.data_offset = 0; + FrameEncoderState::EncodingDataHeader + } else { + FrameEncoderState::EncodingDataPayload + }; - Ok(bytes_to_write) + Ok(writen_bytes) } else { Err(FrameEncoderErr::UnexpectedPayloadType) } @@ -1264,10 +1288,10 @@ impl FrameEncoder { } } - fn encode_slice(&self, buf: &mut [u8], data: &[u8], start: usize) -> usize { - let data_len = data.len(); - let remaining_data_bytes = data_len.saturating_sub(start); - let bytes_to_write = remaining_data_bytes.min(buf.len()); + fn encode_slice(&self, buf: &mut [u8], data: &[u8]) -> usize { + let start = self.data_offset; + let allow_to_write = (self.max_frame_size - start).min(self.remaining_payload_bytes); + let bytes_to_write = allow_to_write.min(buf.len()); buf[..bytes_to_write].copy_from_slice(&data[start..start + bytes_to_write]); bytes_to_write @@ -1895,4 +1919,108 @@ mod ut_frame_encoder { frame_encoder.current_frame = None; assert!(frame_encoder.encode_padding(&mut buf).is_err()); } + + /// UT test cases for `FrameEncoder` encoding data frame. + /// + /// # Brief + /// 1. Creates a `FrameEncoder`. + /// 2. Creates a `Frame` smaller than max_frame_size with `Payload::Data` + /// and sets the flags. + /// 3. Sets the frame for the encoder. + /// 4. Encode the data frame with a buf. + /// 5. Checks whether the result is correct. + #[test] + fn ut_encode_small_data_frame() { + let mut encoder = FrameEncoder::new(100, 4096); + let data_payload = vec![b'a'; 10]; + let mut buf = [0u8; 10]; + encode_small_frame(&mut encoder, &mut buf, data_payload.clone()); + + // buf is larger than frame but smaller than max_frame_size; + let mut buf = [0u8; 50]; + encode_small_frame(&mut encoder, &mut buf, data_payload.clone()); + + // buf is larger than max_frame_size; + let mut buf = [0u8; 200]; + encode_small_frame(&mut encoder, &mut buf, data_payload); + } + + fn encode_small_frame(encoder: &mut FrameEncoder, buf: &mut [u8], data_payload: Vec) { + let data_frame = Frame::new( + 1, + FrameFlags::new(1), + Payload::Data(Data::new(data_payload)), + ); + encoder.set_frame(data_frame).unwrap(); + + let mut result = [b'0'; 10 + 9]; + let total = assert_encoded_data(encoder, &mut result, buf); + assert_eq!(total, 10 + 9); + assert_eq!(result[4], 0x1); + assert_eq!(encoder.state, FrameEncoderState::DataComplete); + } + + /// UT test cases for `FrameEncoder` encoding data frame. + /// + /// # Brief + /// 1. Creates a `FrameEncoder`. + /// 2. Creates a `Frame` larger than max_frame_size with `Payload::Data` and + /// sets the flags. + /// 3. Sets the frame for the encoder. + /// 4. Encode the data frame with a buf. + /// 5. Checks whether the result is correct. + #[test] + fn ut_encode_large_data_frame() { + let mut encoder = FrameEncoder::new(100, 4096); + let data_payload = vec![b'a'; 1024]; + let mut buf = [0u8; 10]; + + let data_frame = Frame::new( + 1, + FrameFlags::new(0), + Payload::Data(Data::new(data_payload.clone())), + ); + let mut result = [b'0'; 1024 + 9 * 11]; + encoder.set_frame(data_frame).unwrap(); + + let total = assert_encoded_data(&mut encoder, &mut result, &mut buf); + // This is because the data frame flag is 0. + assert_eq!(result[4 + 10 * (9 + 100)], 0x0); + assert_eq!(total, 1024 + 9 * 11); + for index in 0..=9 { + assert_eq!(result[4 + index * (9 + 100)], 0x0); + } + assert_eq!(encoder.state, FrameEncoderState::DataComplete); + + // finished + let data_frame = Frame::new( + 1, + FrameFlags::new(1), + Payload::Data(Data::new(data_payload)), + ); + let mut result = [b'0'; 1024 + 9 * 11]; + encoder.set_frame(data_frame).unwrap(); + + let total = assert_encoded_data(&mut encoder, &mut result, &mut buf); + // This is because the data frame flag is 0. + assert_eq!(result[4 + 10 * (9 + 100)], 0x1); + assert_eq!(total, 1024 + 9 * 11); + for index in 0..=9 { + assert_eq!(result[4 + index * (9 + 100)], 0x0); + } + assert_eq!(encoder.state, FrameEncoderState::DataComplete); + } + + fn assert_encoded_data(encoder: &mut FrameEncoder, result: &mut [u8], buf: &mut [u8]) -> usize { + let mut total = 0; + loop { + let size = encoder.encode(buf).unwrap(); + result[total..total + size].copy_from_slice(&buf[..size]); + total += size; + if size == 0 { + break; + } + } + total + } }