From ae67245130b5fdb5b85b3e58bcf66a4c96bf1525 Mon Sep 17 00:00:00 2001 From: muxi Date: Thu, 19 Dec 2024 21:07:12 +0800 Subject: [PATCH] fix: Extract common function try_retrieve_current_frame Signed-off-by: muxi --- ylong_http/src/h2/encoder.rs | 753 ++++++++++++++++------------------- 1 file changed, 354 insertions(+), 399 deletions(-) diff --git a/ylong_http/src/h2/encoder.rs b/ylong_http/src/h2/encoder.rs index adfbef3..7e1cd1c 100644 --- a/ylong_http/src/h2/encoder.rs +++ b/ylong_http/src/h2/encoder.rs @@ -371,36 +371,40 @@ impl FrameEncoder { self.remaining_header_payload = payload_size; } - fn encode_headers_frame(&mut self, buf: &mut [u8]) -> Result { + fn try_retrieve_current_frame(&self) -> Result<&Frame, FrameEncoderErr> { if let Some(frame) = &self.current_frame { - if let Payload::Headers(_) = frame.payload() { - // HTTP/2 frame header size is 9 bytes. - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); + return Ok(frame); + } + Err(FrameEncoderErr::NoCurrentFrame) + } - self.iterate_headers_header(frame, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - let bytes_written = bytes_to_write; - let mut payload_bytes_written = 0; + fn encode_headers_frame(&mut self, buf: &mut [u8]) -> Result { + let frame = self.try_retrieve_current_frame()?; + if let Payload::Headers(_) = frame.payload() { + // HTTP/2 frame header size is 9 bytes. + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 + } else { + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); - if self.encoded_bytes >= frame_header_size { - payload_bytes_written = self - .write_payload(&mut buf[bytes_written..], self.remaining_header_payload); - self.encoded_bytes += payload_bytes_written; - self.headers_header_status(); - } + self.iterate_headers_header(frame, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + let bytes_written = bytes_to_write; + let mut payload_bytes_written = 0; - Ok(bytes_written + payload_bytes_written) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + if self.encoded_bytes >= frame_header_size { + payload_bytes_written = + self.write_payload(&mut buf[bytes_written..], self.remaining_header_payload); + self.encoded_bytes += payload_bytes_written; + self.headers_header_status(); } + + Ok(bytes_written + payload_bytes_written) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -461,107 +465,98 @@ impl FrameEncoder { } fn encode_headers_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Headers(_) = frame.payload() { - if buf.is_empty() { - return Ok(0); - } + let frame = self.try_retrieve_current_frame()?; + if let Payload::Headers(_) = frame.payload() { + if buf.is_empty() { + return Ok(0); + } - let payload_bytes_written = self.write_payload(buf, self.remaining_header_payload); - self.encoded_bytes += payload_bytes_written; + let payload_bytes_written = self.write_payload(buf, self.remaining_header_payload); + self.encoded_bytes += payload_bytes_written; - // Updates the state based on the encoding progress - self.headers_header_status(); + // Updates the state based on the encoding progress + self.headers_header_status(); - Ok(payload_bytes_written) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) - } + Ok(payload_bytes_written) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } fn encode_continuation_frames(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Headers(_) = frame.payload() { - let available_space = buf.len(); - let frame_header_size = 9; - // TODO allow available_space < 9 - if available_space < frame_header_size { - return Ok(0); - } - // Encodes CONTINUATION frame header. - // And this value is always the remaining_header_payload. - let continuation_frame_len = self.remaining_header_payload.min(self.max_frame_size); - for (buf_index, item) in buf.iter_mut().enumerate().take(3) { - *item = ((continuation_frame_len >> (16 - (8 * buf_index))) & 0xFF) as u8; - } - buf[3] = FrameType::Continuation as u8; - let mut new_flags = FrameFlags::empty(); - if self.remaining_header_payload <= self.max_frame_size - && self.hpack_encoder.is_finished() - && self.is_end_headers - { - // Sets the END_HEADER flag on the last CONTINUATION frame. - new_flags.set_end_headers(true); - } - buf[4] = new_flags.bits(); - - for buf_index in 0..4 { - let stream_id_byte_index = buf_index; - buf[5 + buf_index] = - (frame.stream_id() >> (24 - (8 * stream_id_byte_index))) as u8; - } - - // Encodes CONTINUATION frame payload. - let payload_bytes_written = - self.write_payload(&mut buf[frame_header_size..], continuation_frame_len); - self.encoded_bytes += payload_bytes_written; + let frame = self.try_retrieve_current_frame()?; - // Updates the state based on the encoding progress. - self.headers_header_status(); + if let Payload::Headers(_) = frame.payload() { + let available_space = buf.len(); + let frame_header_size = 9; + // TODO allow available_space < 9 + if available_space < frame_header_size { + return Ok(0); + } + // Encodes CONTINUATION frame header. + // And this value is always the remaining_header_payload. + let continuation_frame_len = self.remaining_header_payload.min(self.max_frame_size); + for (buf_index, item) in buf.iter_mut().enumerate().take(3) { + *item = ((continuation_frame_len >> (16 - (8 * buf_index))) & 0xFF) as u8; + } + buf[3] = FrameType::Continuation as u8; + let mut new_flags = FrameFlags::empty(); + if self.remaining_header_payload <= self.max_frame_size + && self.hpack_encoder.is_finished() + && self.is_end_headers + { + // Sets the END_HEADER flag on the last CONTINUATION frame. + new_flags.set_end_headers(true); + } + buf[4] = new_flags.bits(); - Ok(frame_header_size + payload_bytes_written) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + for buf_index in 0..4 { + let stream_id_byte_index = buf_index; + buf[5 + buf_index] = (frame.stream_id() >> (24 - (8 * stream_id_byte_index))) as u8; } + + // Encodes CONTINUATION frame payload. + let payload_bytes_written = + self.write_payload(&mut buf[frame_header_size..], continuation_frame_len); + self.encoded_bytes += payload_bytes_written; + + // Updates the state based on the encoding progress. + self.headers_header_status(); + + Ok(frame_header_size + payload_bytes_written) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } fn encode_data_header(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Data(_) = frame.payload() { - // HTTP/2 frame header size is 9 bytes. - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); - - self.iterate_data_header( - frame, - buf, - min(self.remaining_payload_bytes, self.max_frame_size), - bytes_to_write, - )?; - - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size { - // data frame header is finished, reset encoded_bytes. - self.encoded_bytes = 0; - self.state = FrameEncoderState::EncodingDataPayload; - } - Ok(bytes_to_write) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Data(_) = frame.payload() { + // HTTP/2 frame header size is 9 bytes. + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); + + self.iterate_data_header( + frame, + buf, + min(self.remaining_payload_bytes, self.max_frame_size), + bytes_to_write, + )?; + + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size { + // data frame header is finished, reset encoded_bytes. + self.encoded_bytes = 0; + self.state = FrameEncoderState::EncodingDataPayload; } + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -607,85 +602,78 @@ impl FrameEncoder { } fn encode_data_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::Data(data_frame) = frame.payload() { - let payload = data_frame.data(); - let writen_bytes = self.encode_slice(buf, payload); - self.data_offset += writen_bytes; - self.remaining_payload_bytes -= writen_bytes; - - self.state = if self.remaining_payload_bytes == 0 { - self.data_offset = 0; - FrameEncoderState::DataComplete - } else if self.data_offset == self.max_frame_size { - self.data_offset = 0; - FrameEncoderState::EncodingDataHeader - } else { - FrameEncoderState::EncodingDataPayload - }; - - Ok(writen_bytes) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Data(data_frame) = frame.payload() { + let payload = data_frame.data(); + let writen_bytes = self.encode_slice(buf, payload); + self.data_offset += writen_bytes; + self.remaining_payload_bytes -= writen_bytes; + + self.state = if self.remaining_payload_bytes == 0 { + self.data_offset = 0; + FrameEncoderState::DataComplete + } else if self.data_offset == self.max_frame_size { + self.data_offset = 0; + FrameEncoderState::EncodingDataHeader } else { - Err(FrameEncoderErr::UnexpectedPayloadType) - } + FrameEncoderState::EncodingDataPayload + }; + + Ok(writen_bytes) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } fn encode_padding(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if frame.flags().is_padded() { - let padding_len = if let Payload::Data(data_frame) = frame.payload() { - data_frame.data().len() - } else { - return Err(FrameEncoderErr::UnexpectedPayloadType); - }; + let frame = self.try_retrieve_current_frame()?; - let remaining_padding_bytes = padding_len - self.encoded_bytes; - let bytes_to_write = remaining_padding_bytes.min(buf.len()); + if frame.flags().is_padded() { + let padding_len = if let Payload::Data(data_frame) = frame.payload() { + data_frame.data().len() + } else { + return Err(FrameEncoderErr::UnexpectedPayloadType); + }; - for item in buf.iter_mut().take(bytes_to_write) { - // Padding bytes are filled with 0. - *item = 0; - } + let remaining_padding_bytes = padding_len - self.encoded_bytes; + let bytes_to_write = remaining_padding_bytes.min(buf.len()); - self.encoded_bytes += bytes_to_write; + for item in buf.iter_mut().take(bytes_to_write) { + // Padding bytes are filled with 0. + *item = 0; + } - if self.encoded_bytes == padding_len { - self.state = FrameEncoderState::DataComplete; - } + self.encoded_bytes += bytes_to_write; - Ok(bytes_to_write) - } else { - Ok(0) // No padding to encode, so return 0 bytes written. + if self.encoded_bytes == padding_len { + self.state = FrameEncoderState::DataComplete; } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Ok(0) // No padding to encode, so return 0 bytes written. } } fn encode_goaway_frame(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Goaway(_) = frame.payload() { - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); - self.iterate_goaway_header(frame, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size { - self.state = FrameEncoderState::EncodingGoawayPayload; - } - Ok(bytes_to_write) + let frame = self.try_retrieve_current_frame()?; + + if let Payload::Goaway(_) = frame.payload() { + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); + self.iterate_goaway_header(frame, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size { + self.state = FrameEncoderState::EncodingGoawayPayload; } + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -725,25 +713,22 @@ impl FrameEncoder { } fn encode_goaway_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Goaway(goaway) = frame.payload() { - let payload_size = goaway.encoded_len(); - let remaining_payload_bytes = - payload_size.saturating_sub(self.encoded_bytes.saturating_sub(9)); - let bytes_to_write = remaining_payload_bytes.min(buf.len()); - - self.iterate_goaway_payload(goaway, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == 9 + payload_size { - self.state = FrameEncoderState::DataComplete; - } - - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Goaway(goaway) = frame.payload() { + let payload_size = goaway.encoded_len(); + let remaining_payload_bytes = + payload_size.saturating_sub(self.encoded_bytes.saturating_sub(9)); + let bytes_to_write = remaining_payload_bytes.min(buf.len()); + + self.iterate_goaway_payload(goaway, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == 9 + payload_size { + self.state = FrameEncoderState::DataComplete; } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -781,29 +766,26 @@ impl FrameEncoder { } fn encode_window_update_frame(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::WindowUpdate(_) = frame.payload() { - // HTTP/2 frame header size is 9 bytes. - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); - self.iterate_window_update_header(frame, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size { - self.state = FrameEncoderState::EncodingWindowUpdatePayload; - // Resets the encoded_bytes counter here. - self.encoded_bytes = 0; - } - Ok(bytes_to_write) + let frame = self.try_retrieve_current_frame()?; + if let Payload::WindowUpdate(_) = frame.payload() { + // HTTP/2 frame header size is 9 bytes. + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); + self.iterate_window_update_header(frame, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size { + self.state = FrameEncoderState::EncodingWindowUpdatePayload; + // Resets the encoded_bytes counter here. + self.encoded_bytes = 0; } + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -848,61 +830,55 @@ impl FrameEncoder { } fn encode_window_update_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::WindowUpdate(window_update) = frame.payload() { - let payload_size = 4usize; - let remaining_payload_bytes = - payload_size.saturating_sub(self.encoded_bytes.saturating_sub(9usize)); - let bytes_to_write = remaining_payload_bytes.min(buf.len()); - for (buf_index, buf_item) in buf.iter_mut().enumerate().take(bytes_to_write) { - let payload_byte_index = self - .encoded_bytes - .saturating_sub(9) - .saturating_add(buf_index); - let increment_byte_index = payload_byte_index; - *buf_item = - (window_update.get_increment() >> (24 - (8 * increment_byte_index))) as u8; - } - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == payload_size { - self.state = FrameEncoderState::DataComplete; - } - - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + let frame = self.try_retrieve_current_frame()?; + if let Payload::WindowUpdate(window_update) = frame.payload() { + let payload_size = 4usize; + let remaining_payload_bytes = + payload_size.saturating_sub(self.encoded_bytes.saturating_sub(9usize)); + let bytes_to_write = remaining_payload_bytes.min(buf.len()); + for (buf_index, buf_item) in buf.iter_mut().enumerate().take(bytes_to_write) { + let payload_byte_index = self + .encoded_bytes + .saturating_sub(9) + .saturating_add(buf_index); + let increment_byte_index = payload_byte_index; + *buf_item = + (window_update.get_increment() >> (24 - (8 * increment_byte_index))) as u8; + } + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == payload_size { + self.state = FrameEncoderState::DataComplete; } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } fn encode_settings_frame(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Settings(settings) = frame.payload() { - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); - self.iterate_settings_header( - frame, - buf, - settings.get_settings().len() * 6, - bytes_to_write, - )?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size { - self.state = FrameEncoderState::EncodingSettingsPayload; - } - Ok(bytes_to_write) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Settings(settings) = frame.payload() { + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); + self.iterate_settings_header( + frame, + buf, + settings.get_settings().len() * 6, + bytes_to_write, + )?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size { + self.state = FrameEncoderState::EncodingSettingsPayload; } + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -944,24 +920,21 @@ impl FrameEncoder { } fn encode_settings_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Settings(settings) = frame.payload() { - let settings_len = settings.get_settings().len() * 6; - let remaining_payload_bytes = - settings_len.saturating_sub(self.encoded_bytes.saturating_sub(9)); - let bytes_to_write = remaining_payload_bytes.min(buf.len()); - self.iterate_settings_payload(settings, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == 9 + settings_len { - self.state = FrameEncoderState::DataComplete; - } - - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Settings(settings) = frame.payload() { + let settings_len = settings.get_settings().len() * 6; + let remaining_payload_bytes = + settings_len.saturating_sub(self.encoded_bytes.saturating_sub(9)); + let bytes_to_write = remaining_payload_bytes.min(buf.len()); + self.iterate_settings_payload(settings, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == 9 + settings_len { + self.state = FrameEncoderState::DataComplete; } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -1005,28 +978,25 @@ impl FrameEncoder { } fn encode_priority_frame(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Priority(_) = frame.payload() { - // HTTP/2 frame header size is 9 bytes. - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); - - self.iterate_priority_header(frame, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size { - self.state = FrameEncoderState::EncodingPriorityPayload; - } - Ok(bytes_to_write) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Priority(_) = frame.payload() { + // HTTP/2 frame header size is 9 bytes. + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); + + self.iterate_priority_header(frame, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size { + self.state = FrameEncoderState::EncodingPriorityPayload; } + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -1067,25 +1037,22 @@ impl FrameEncoder { } fn encode_priority_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Priority(priority) = frame.payload() { - // HTTP/2 frame header size is 9 bytes. - let frame_header_size = 9; - let remaining_payload_bytes = 5 - (self.encoded_bytes - frame_header_size); - let bytes_to_write = remaining_payload_bytes.min(buf.len()); - - self.iterate_priority_payload(priority, buf, frame_header_size, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size + 5 { - self.state = FrameEncoderState::DataComplete - } + let frame = self.try_retrieve_current_frame()?; + if let Payload::Priority(priority) = frame.payload() { + // HTTP/2 frame header size is 9 bytes. + let frame_header_size = 9; + let remaining_payload_bytes = 5 - (self.encoded_bytes - frame_header_size); + let bytes_to_write = remaining_payload_bytes.min(buf.len()); - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + self.iterate_priority_payload(priority, buf, frame_header_size, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size + 5 { + self.state = FrameEncoderState::DataComplete } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -1125,107 +1092,97 @@ impl FrameEncoder { } fn encode_rst_stream_frame(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - let frame_header_size = 9; - if self.encoded_bytes >= frame_header_size { - return Ok(0); - } + let frame = self.try_retrieve_current_frame()?; + let frame_header_size = 9; + if self.encoded_bytes >= frame_header_size { + return Ok(0); + } - let bytes_to_write = (frame_header_size - self.encoded_bytes).min(buf.len()); + let bytes_to_write = (frame_header_size - self.encoded_bytes).min(buf.len()); - for (buf_index, item) in buf.iter_mut().enumerate().take(bytes_to_write) { - let header_byte_index = self.encoded_bytes + buf_index; - match header_byte_index { - 0..=2 => { - let payload_len = 4; - *item = ((payload_len >> (16 - (8 * buf_index))) & 0xFF) as u8; - } - 3 => { - *item = FrameType::RstStream as u8; - } - 4 => { - *item = frame.flags().bits(); - } - 5..=8 => { - let stream_id = frame.stream_id(); - *item = ((stream_id >> (24 - (8 * (buf_index - 5)))) & 0xFF) as u8; - } - _ => { - return Err(FrameEncoderErr::InternalError); - } + for (buf_index, item) in buf.iter_mut().enumerate().take(bytes_to_write) { + let header_byte_index = self.encoded_bytes + buf_index; + match header_byte_index { + 0..=2 => { + let payload_len = 4; + *item = ((payload_len >> (16 - (8 * buf_index))) & 0xFF) as u8; + } + 3 => { + *item = FrameType::RstStream as u8; + } + 4 => { + *item = frame.flags().bits(); + } + 5..=8 => { + let stream_id = frame.stream_id(); + *item = ((stream_id >> (24 - (8 * (buf_index - 5)))) & 0xFF) as u8; + } + _ => { + return Err(FrameEncoderErr::InternalError); } } - self.encoded_bytes += bytes_to_write; - - if self.encoded_bytes == frame_header_size { - self.state = FrameEncoderState::EncodingRstStreamPayload; - } + } + self.encoded_bytes += bytes_to_write; - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::NoCurrentFrame) + if self.encoded_bytes == frame_header_size { + self.state = FrameEncoderState::EncodingRstStreamPayload; } + + Ok(bytes_to_write) } fn encode_rst_stream_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::RstStream(rst_stream) = frame.payload() { - let frame_header_size = 9; - if self.encoded_bytes < frame_header_size { - return Ok(0); - } - - let payload_size = 4; - let encoded_payload_bytes = self.encoded_bytes - frame_header_size; + let frame = self.try_retrieve_current_frame()?; + if let Payload::RstStream(rst_stream) = frame.payload() { + let frame_header_size = 9; + if self.encoded_bytes < frame_header_size { + return Ok(0); + } - if encoded_payload_bytes >= payload_size { - return Ok(0); - } + let payload_size = 4; + let encoded_payload_bytes = self.encoded_bytes - frame_header_size; - let bytes_to_write = (payload_size - encoded_payload_bytes).min(buf.len()); + if encoded_payload_bytes >= payload_size { + return Ok(0); + } - for (buf_index, item) in buf.iter_mut().enumerate().take(bytes_to_write) { - let payload_byte_index = encoded_payload_bytes + buf_index; - *item = - ((rst_stream.error_code() >> (24 - (8 * payload_byte_index))) & 0xFF) as u8; - } + let bytes_to_write = (payload_size - encoded_payload_bytes).min(buf.len()); - self.encoded_bytes += bytes_to_write; + for (buf_index, item) in buf.iter_mut().enumerate().take(bytes_to_write) { + let payload_byte_index = encoded_payload_bytes + buf_index; + *item = ((rst_stream.error_code() >> (24 - (8 * payload_byte_index))) & 0xFF) as u8; + } - if self.encoded_bytes == frame_header_size + payload_size { - self.state = FrameEncoderState::DataComplete; - } + self.encoded_bytes += bytes_to_write; - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + if self.encoded_bytes == frame_header_size + payload_size { + self.state = FrameEncoderState::DataComplete; } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } fn encode_ping_frame(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Ping(_) = frame.payload() { - let frame_header_size = 9; - let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { - 0 - } else { - frame_header_size - self.encoded_bytes - }; - let bytes_to_write = remaining_header_bytes.min(buf.len()); - self.iterate_ping_header(frame, buf, bytes_to_write)?; - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == frame_header_size { - self.state = FrameEncoderState::EncodingPingPayload; - } - Ok(bytes_to_write) + let frame = self.try_retrieve_current_frame()?; + if let Payload::Ping(_) = frame.payload() { + let frame_header_size = 9; + let remaining_header_bytes = if self.encoded_bytes >= frame_header_size { + 0 } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + frame_header_size - self.encoded_bytes + }; + let bytes_to_write = remaining_header_bytes.min(buf.len()); + self.iterate_ping_header(frame, buf, bytes_to_write)?; + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == frame_header_size { + self.state = FrameEncoderState::EncodingPingPayload; } + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } @@ -1268,31 +1225,29 @@ impl FrameEncoder { } fn encode_ping_payload(&mut self, buf: &mut [u8]) -> Result { - if let Some(frame) = &self.current_frame { - if let Payload::Ping(ping) = frame.payload() { - // PING payload is always 8 bytes. - let payload_size = 8usize; - let remaining_payload_bytes = - payload_size.saturating_sub(self.encoded_bytes.saturating_sub(9usize)); - let bytes_to_write = remaining_payload_bytes.min(buf.len()); - for (buf_index, buf_item) in buf.iter_mut().enumerate().take(bytes_to_write) { - let payload_byte_index = self - .encoded_bytes - .saturating_sub(9) - .saturating_add(buf_index); - *buf_item = ping.data[payload_byte_index]; - } - self.encoded_bytes += bytes_to_write; - if self.encoded_bytes == 9 + 8 { - self.state = FrameEncoderState::DataComplete; - } - - Ok(bytes_to_write) - } else { - Err(FrameEncoderErr::UnexpectedPayloadType) + let frame = self.try_retrieve_current_frame()?; + + if let Payload::Ping(ping) = frame.payload() { + // PING payload is always 8 bytes. + let payload_size = 8usize; + let remaining_payload_bytes = + payload_size.saturating_sub(self.encoded_bytes.saturating_sub(9usize)); + let bytes_to_write = remaining_payload_bytes.min(buf.len()); + for (buf_index, buf_item) in buf.iter_mut().enumerate().take(bytes_to_write) { + let payload_byte_index = self + .encoded_bytes + .saturating_sub(9) + .saturating_add(buf_index); + *buf_item = ping.data[payload_byte_index]; + } + self.encoded_bytes += bytes_to_write; + if self.encoded_bytes == 9 + 8 { + self.state = FrameEncoderState::DataComplete; } + + Ok(bytes_to_write) } else { - Err(FrameEncoderErr::NoCurrentFrame) + Err(FrameEncoderErr::UnexpectedPayloadType) } } -- Gitee