36 | Stage Practice (4): Building a Simple KV Server (Network Handling) #

Hello, I’m Chen Tian.

We built and refined the KV server in the basics and advanced sections, and its core functionality is now fairly complete. But have you noticed that we have been relying on a mysterious library, async-prost, that magically takes care of packing and unpacking TCP frames? How does it do that?

async-prost is a library I wrote, inspired by Jonhoo’s async-bincode, that handles protobuf frames. It can be adapted to many network protocols, including TCP, WebSocket, and HTTP/2. Because it aims to be general, it works at a fairly high level of abstraction and uses a lot of generic parameters. Its main flow is shown in the diagram below:

(Figure: the main processing flow of async-prost)

The core idea is simple. When serializing, prepend a header that carries the frame length. When deserializing, read that header first to learn the length, then read exactly that much data. If you’re interested, take a look at the code; we won’t go into the details here.
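To make the idea concrete, here is a minimal in-memory sketch of length-prefixed framing. It assumes a 4-byte big-endian length header (the layout used later in this lecture); the helper names `encode_frame_bytes` and `try_split_frame` are hypothetical and are not async-prost’s API. Note that the decoder returns `None` until a whole frame is in the buffer, because TCP may deliver a frame in several pieces.

```rust
/// Hypothetical helper: prepend a 4-byte big-endian length to `payload`.
fn encode_frame_bytes(payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 + payload.len());
    out.extend_from_slice(&(payload.len() as u32).to_be_bytes());
    out.extend_from_slice(payload);
    out
}

/// Hypothetical helper: if `buf` starts with a complete frame, remove it
/// from `buf` and return its payload; otherwise leave `buf` untouched.
fn try_split_frame(buf: &mut Vec<u8>) -> Option<Vec<u8>> {
    if buf.len() < 4 {
        return None; // header not complete yet
    }
    let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
    if buf.len() < 4 + len {
        return None; // payload not complete yet
    }
    let payload = buf[4..4 + len].to_vec();
    buf.drain(..4 + len); // discard the consumed header and payload
    Some(payload)
}
```

For example, if the buffer holds two complete frames followed by only part of a third, the first two calls return the two payloads, the third returns `None`, and the third payload comes out once the remaining bytes are appended.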

Today’s challenge is to take the KV server we completed last time, drop the dependency on async-prost, and handle packing and unpacking ourselves. Once you master this skill and combine it with protobuf, you’ll be able to design any protocol that can carry real business traffic.

How to Define the Protocol’s Frame? #

protobuf solves the problem of how to define protocol messages, but telling one message apart from the next is trickier. We need a suitable delimiter.

A delimiter plus the message data makes a frame. Back in Lecture 28 on network development, I briefly covered how to define a frame.

Many TCP-based protocols use \r\n as the delimiter, such as FTP; some use the message length as the delimiter, such as gRPC; others use both, such as Redis’s RESP; and more complex ones, such as HTTP, separate header lines with \r\n, separate the headers from the body with \r\n\r\n, and give the body length in a header (Content-Length).
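RESP shows what “both” looks like in practice: a bulk string is encoded as `$<length>\r\n<data>\r\n`, where the length tells the reader how many bytes to take and the trailing `\r\n` closes the element. Below is a minimal sketch of encoding and parsing one bulk string. The helper names are hypothetical, and error handling is deliberately simplified: any malformed or incomplete input, including RESP’s `$-1` null bulk string, just yields `None`.

```rust
/// Hypothetical helper: encode `data` as a RESP bulk string.
fn resp_bulk_encode(data: &[u8]) -> Vec<u8> {
    let mut out = format!("${}\r\n", data.len()).into_bytes();
    out.extend_from_slice(data);
    out.extend_from_slice(b"\r\n");
    out
}

/// Hypothetical helper: parse one bulk string from the start of `input`,
/// returning the payload and the number of bytes consumed.
fn resp_bulk_parse(input: &[u8]) -> Option<(Vec<u8>, usize)> {
    if input.first() != Some(&b'$') {
        return None;
    }
    // The length line ends at the first "\r\n".
    let crlf = input.windows(2).position(|w| w == b"\r\n")?;
    let len: usize = std::str::from_utf8(&input[1..crlf]).ok()?.parse().ok()?;
    let start = crlf + 2;
    let end = start.checked_add(len)?;
    let total = end.checked_add(2)?;
    // The payload must be complete and followed by its own "\r\n".
    if input.len() < total || &input[end..total] != b"\r\n" {
        return None;
    }
    Some((input[start..end].to_vec(), total))
}
```

Because the length is authoritative, a payload that itself contains `\r\n` still round-trips correctly; the final `\r\n` only marks the end of the element.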

Delimiters like “\r\n” suit protocols whose messages are ASCII text; binary protocols usually use a length as the delimiter. Our KV server carries binary protobuf, so we put a length in front of the payload to delimit each frame.
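A quick way to see why: binary data such as protobuf output can easily contain the bytes 0x0D 0x0A (0x0A, for instance, is the tag byte of a length-delimited field 1). In the sketch below, which uses a hand-written byte slice rather than real protobuf output and a hypothetical `split_crlf` helper, a single binary message that happens to contain `\r\n` comes out of naive line-style framing as two pieces.

```rust
/// Hypothetical helper: split a byte stream on "\r\n" (line-style framing).
/// Trailing bytes without a terminator are treated as incomplete and ignored.
fn split_crlf(stream: &[u8]) -> Vec<&[u8]> {
    let mut parts = Vec::new();
    let mut start = 0;
    let mut i = 0;
    while i + 1 < stream.len() {
        if stream[i] == b'\r' && stream[i + 1] == b'\n' {
            parts.push(&stream[start..i]); // one "line" ends here
            i += 2;
            start = i;
        } else {
            i += 1;
        }
    }
    parts
}
```

Sending the 4-byte message `[0x08, 0x0D, 0x0A, 0x01]` followed by `\r\n` yields two "lines", `[0x08]` and `[0x01]`, so the receiver can no longer reconstruct the original message.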

How big should this length field be? With 2 bytes, the maximum payload is 64 KiB (65,535 bytes); with 4 bytes, it can reach 4 GiB. 4 bytes is usually enough for typical applications. If you want more flexibility, you can use a varint.
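For a feel of the trade-off: a varint spends one byte per 7 bits of the value, so small frames pay for only a 1-byte header while large ones can still be described. Below is a minimal sketch of the unsigned base-128 varint scheme that protobuf uses (low 7 bits first, high bit set on every byte except the last); `encode_varint` and `decode_varint` are hypothetical names, not prost’s API.

```rust
/// Hypothetical helper: append `v` to `out` as an unsigned base-128 varint.
fn encode_varint(mut v: u64, out: &mut Vec<u8>) {
    while v >= 0x80 {
        // Low 7 bits, with the high bit set to mean "more bytes follow".
        out.push(((v as u8) & 0x7F) | 0x80);
        v >>= 7;
    }
    out.push(v as u8); // final byte, high bit clear
}

/// Hypothetical helper: decode a varint from the start of `buf`,
/// returning the value and how many bytes it occupied.
fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
    let mut value = 0u64;
    for (i, &byte) in buf.iter().enumerate().take(10) {
        value |= ((byte & 0x7F) as u64) << (7 * i);
        if byte & 0x80 == 0 {
            return Some((value, i + 1));
        }
    }
    None // truncated input, or longer than the 10 bytes a u64 can need
}
```

For example, 300 encodes to the two bytes `0xAC 0x02`, and any value below 128 takes a single byte, whereas a fixed `u32` header always costs 4 bytes.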

Tokio’s companion library tokio-util covers most of our needs for packing and unpacking frames. It includes LinesCodec (for frames delimited by \n or \r\n) and LengthDelimitedCodec (for length-delimited frames). Let’s try out its LengthDelimitedCodec.

First, add a dependency in Cargo.toml:

[dev-dependencies]
...
tokio-util = { version = "0.6", features = ["codec"]}
...

Then create the file examples/server_with_codec.rs and add the following code:

use anyhow::Result;
use futures::prelude::*;
use kv2::{CommandRequest, MemTable, Service, ServiceInner};
use prost::Message;
use tokio::net::TcpListener;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let service: Service = ServiceInner::new(MemTable::new()).into();
    let addr = "127.0.0.1:9527";
    let listener = TcpListener::bind(addr).await?;
    info!("Start listening on {}", addr);
    loop {
        let (stream, addr) = listener.accept().await?;
        info!("Client {:?} connected", addr);
        let svc = service.clone();
        tokio::spawn(async move {
            let mut stream = Framed::new(stream, LengthDelimitedCodec::new());
            while let Some(Ok(mut buf)) = stream.next().await {
                let cmd = CommandRequest::decode(&buf[..]).unwrap();
                info!("Got a new command: {:?}", cmd);
                let res = svc.execute(cmd);
                buf.clear();
                res.encode(&mut buf).unwrap();
                stream.send(buf.freeze()).await.unwrap();
            }
            info!("Client {:?} disconnected", addr);
        });
    }
}

Compared with the previous examples/server.rs, the main change is this line:

// let mut stream = AsyncProstStream::<_, CommandRequest, CommandResponse, _>::from(stream).for_async();
let mut stream = Framed::new(stream, LengthDelimitedCodec::new());

Once that’s done, open a terminal and run: RUST_LOG=info cargo run --example server_with_codec --quiet. Then, in another terminal, run: RUST_LOG=info cargo run --example client --quiet. The server and the client each receive the other’s requests and responses, and everything works as expected.

Are you a little puzzled that the client can talk to the server without any modification? That’s because, in this use case, the AsyncProst client is compatible with LengthDelimitedCodec.

How to Write Code that Handles Frames? #

LengthDelimitedCodec is very handy, and its code isn’t complicated; I highly recommend reading it when you have time. Since this lecture is about network development, though, let’s also try writing our own frame-handling code.

Following the analysis above, we put a 4-byte length in front of the protobuf payload. When the peer reads data, it first reads those 4 bytes, then reads exactly as many bytes as the length indicates, and finally decodes them into the corresponding data structure.
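The read procedure above maps directly onto `read_exact`. Here is a blocking, std-only sketch of it; the names `write_frame` and `read_frame_blocking` are hypothetical, and the lecture’s real version (further below) is async and built on tokio. Because the functions are generic over `Write` / `Read`, they can be exercised against an in-memory buffer instead of a socket. The sketch assumes a big-endian length, which is the byte order `BytesMut::put_u32` and tokio’s `read_u32` use.

```rust
use std::io::{self, Read, Write};

/// Hypothetical helper: write `payload` as [4-byte big-endian length][payload].
fn write_frame<W: Write>(w: &mut W, payload: &[u8]) -> io::Result<()> {
    let len = u32::try_from(payload.len())
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame too large"))?;
    w.write_all(&len.to_be_bytes())?;
    w.write_all(payload)
}

/// Hypothetical helper: read one frame, first the 4-byte header and then
/// exactly `len` payload bytes.
fn read_frame_blocking<R: Read>(r: &mut R) -> io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    r.read_exact(&mut header)?;
    let len = u32::from_be_bytes(header) as usize;
    let mut payload = vec![0u8; len]; // zero-filled, then overwritten
    r.read_exact(&mut payload)?;
    Ok(payload)
}
```

Writing several frames into a `Vec<u8>` and reading them back through a `std::io::Cursor` returns them one by one; once the data runs out, `read_exact` reports an `UnexpectedEof` error.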

To make it more realistic, we’ll use the highest bit of the 4-byte length as a compression flag. If it is set, the payload that follows is gzip-compressed protobuf; otherwise, it is plain protobuf:

(Figure: frame layout, a 4-byte length whose highest bit is the compression flag, followed by the payload)
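The header layout can be checked with plain integer arithmetic. This sketch packs a 31-bit length and the compression flag into one `u32` and unpacks them again. It mirrors the `COMPRESSION_BIT` idea used in the code below but is written with `u32` instead of `usize`, and `pack_header` / `unpack_header` are hypothetical names. Because the header is sent big-endian, the flag ends up as the top bit of the first byte on the wire.

```rust
/// Highest bit of the 4-byte header marks a gzip-compressed payload.
const COMPRESSION_BIT: u32 = 1 << 31;

/// Hypothetical helper: build the header from a length and a flag.
/// The length must fit in the remaining 31 bits.
fn pack_header(len: u32, compressed: bool) -> u32 {
    assert!(len < COMPRESSION_BIT, "length must fit in 31 bits");
    if compressed { len | COMPRESSION_BIT } else { len }
}

/// Hypothetical helper: split the header back into (length, compressed).
fn unpack_header(header: u32) -> (u32, bool) {
    (header & !COMPRESSION_BIT, (header & COMPRESSION_BIT) != 0)
}
```

For a 2000-byte compressed payload the header is `0x800007D0`, which goes on the wire as `80 00 07 D0`.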

As usual, let’s define the trait for handling this logic first:

pub trait FrameCoder
where
    Self: Message + Sized + Default,
{
    /// Encode a Message into a frame
    fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError>;
    /// Decode a complete frame into a Message
    fn decode_frame(buf: &mut BytesMut) -> Result<Self, KvError>;
}

Two methods are defined:

  • encode_frame() can wrap a message, such as CommandRequest, into a frame, written into the passed-in BytesMut;
  • decode_frame() can unwrap a complete frame stored in BytesMut into a message, such as CommandRequest.

To implement this trait, Self must implement prost::Message, must be Sized (its size known at compile time), and must implement Default (prost’s decode requires Default).

Alright, let’s write the implementation. First, create a src/network directory containing two files, mod.rs and frame.rs. Then, in src/network/mod.rs, declare the frame module and re-export FrameCoder:

mod frame;
pub use frame::FrameCoder;

Also declare the network module in src/lib.rs and re-export its contents:

mod network;
pub use network::*;

Since we need gzip compression, we also add flate2 to Cargo.toml. And because this lecture brings network-related operations and data structures into the library itself, tokio must move from dev-dependencies to dependencies. For simplicity, we enable its full feature set:

[dependencies]
...
flate2 = "1" # gzip compression
...
tokio = { version = "1", features = ["full"] } # Asynchronous network library
...

Then, in src/network/frame.rs, add the trait together with its default implementation:

use std::io::{Read, Write};

use crate::{CommandRequest, CommandResponse, KvError};
use bytes::{Buf, BufMut, BytesMut};
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use prost::Message;
use tokio::io::{AsyncRead, AsyncReadExt};
use tracing::debug;

/// The length header occupies 4 bytes
pub const LEN_LEN: usize = 4;
/// The length takes up 31 bits, so the maximum frame is 2GB
const MAX_FRAME: usize = 2 * 1024 * 1024 * 1024;
/// If the payload exceeds 1436 bytes, it will be compressed
const COMPRESSION_LIMIT: usize = 1436;
/// The bit representing compression (the highest bit of the entire 4-byte length)
const COMPRESSION_BIT: usize = 1 << 31;

/// Handling the encoding/decoding of Frames
pub trait FrameCoder
where
    Self: Message + Sized + Default,
{
    /// Encode a Message into a frame
    fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError> {
        let size = self.encoded_len();

        if size >= MAX_FRAME {
            return Err(KvError::FrameError);
        }

        // We first write the length, and if compression is needed, rewrite the length afterwards
        buf.put_u32(size as _);

        if size > COMPRESSION_LIMIT {
            let mut buf1 = Vec::with_capacity(size);
            self.encode(&mut buf1)?;
    
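            // BytesMut supports logical splits (which can be unsplit later).
            // split_off(LEN_LEN) leaves the 4-byte length in `buf` and hands us the
            // rest of the capacity; we then clear the length, which is rewritten
            // below once the compressed size is known.
            // Note: this assumes `buf` was empty when encode_frame() was called.
            let payload = buf.split_off(LEN_LEN);
            buf.clear();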
    
            // Deal with gzip compression, see flate2 documentation for details
            let mut encoder = GzEncoder::new(payload.writer(), Compression::default());
            encoder.write_all(&buf1[..])?;
    
            // After compression, get the BytesMut back from the gzip encoder
            let payload = encoder.finish()?.into_inner();
            debug!("Encode a frame: size {}({})", size, payload.len());
    
            // Write the compressed length
            buf.put_u32((payload.len() | COMPRESSION_BIT) as _);
    
            // Merge BytesMut back
            buf.unsplit(payload);
    
            Ok(())
        } else {
            self.encode(buf)?;
            Ok(())
        }
    }

    /// Decode a complete frame into a Message
    fn decode_frame(buf: &mut BytesMut) -> Result<Self, KvError> {
        // Take the first 4 bytes and extract the length and compression bit
        let header = buf.get_u32() as usize;
        let (len, compressed) = decode_header(header);
        debug!("Got a frame: msg len {}, compressed {}", len, compressed);

        if compressed {
            // Decompress
            let mut decoder = GzDecoder::new(&buf[..len]);
            let mut buf1 = Vec::with_capacity(len * 2);
            decoder.read_to_end(&mut buf1)?;
            buf.advance(len);
    
            // Decode into the corresponding message
            Ok(Self::decode(&buf1[..])?)
        } else {
            let msg = Self::decode(&buf[..len])?;
            buf.advance(len);
            Ok(msg)
        }
    }
}

impl FrameCoder for CommandRequest {}
impl FrameCoder for CommandResponse {}

fn decode_header(header: usize) -> (usize, bool) {
    let len = header & !COMPRESSION_BIT;
    let compressed = header & COMPRESSION_BIT == COMPRESSION_BIT;
    (len, compressed)
}

This code isn’t hard to follow. We give FrameCoder a default implementation, then add empty implementations for CommandRequest and CommandResponse. It uses BytesMut from the bytes crate and the newly introduced GzEncoder/GzDecoder from flate2; you can explore how these types are used with the source-reading approach from Lecture 20. Finally, we wrote a helper function, decode_header(), to keep decode_frame() readable.

You may be wondering why COMPRESSION_LIMIT is set to 1436.

That’s because Ethernet’s MTU is 1500 bytes. Subtracting a 20-byte IP header and a 20-byte TCP header leaves 1460. TCP segments usually carry some options (such as timestamps), and IP packets may as well, so we reserve another 20 bytes, leaving 1440. Subtracting our 4-byte length header gives 1436, the largest message that fits in a single packet. Anything larger may be split across multiple packets, so we might as well compress it.
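The arithmetic behind 1436, written out as constants. The names are illustrative, and the 20-byte options allowance is the reservation described above, not a protocol constant:

```rust
const ETHERNET_MTU: usize = 1500;
const IPV4_HEADER: usize = 20; // minimum IPv4 header, no options
const TCP_HEADER: usize = 20; // minimum TCP header, no options
const OPTIONS_RESERVE: usize = 20; // allowance for TCP/IP options such as timestamps
const LEN_LEN: usize = 4; // our frame's length header

/// Largest payload expected to still fit in a single packet.
const COMPRESSION_LIMIT: usize =
    ETHERNET_MTU - IPV4_HEADER - TCP_HEADER - OPTIONS_RESERVE - LEN_LEN;
```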

Now CommandRequest and CommandResponse can be encoded into and decoded from frames. Let’s write some tests to verify that this works, again in src/network/frame.rs:

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;
    use bytes::Bytes;

    #[test]
    fn command_request_encode_decode_should_work() {
        let mut buf = BytesMut::new();

        let cmd = CommandRequest::new_hdel("t1", "k1");
        cmd.encode_frame(&mut buf).unwrap();

        // The highest bit isn't set
        assert_eq!(is_compressed(&buf), false);

        let cmd1 = CommandRequest::decode_frame(&mut buf).unwrap();
        assert_eq!(cmd, cmd1);
    }

    #[test]
    fn command_response_encode_decode_should_work() {
        let mut buf = BytesMut::new();

        let values: Vec<Value> = vec![1.into(), "hello".into(), b"data".into()];
        let res: CommandResponse = values.into();
        res.encode_frame(&mut buf).unwrap();

        // The highest bit isn't set
        assert_eq!(is_compressed(&buf), false);

        let res1 = CommandResponse::decode_frame(&mut buf).unwrap();
        assert_eq!(res, res1);
    }

    #[test]
    fn command_response_compressed_encode_decode_should_work() {
        let mut buf = BytesMut::new();

        let value: Value = Bytes::from(vec![0u8; COMPRESSION_LIMIT + 1]).into();
        let res: CommandResponse = value.into();
        res.encode_frame(&mut buf).unwrap();

        // The highest bit is set
        assert_eq!(is_compressed(&buf), true);

        let res1 = CommandResponse::decode_frame(&mut buf).unwrap();
        assert_eq!(res, res1);
    }

    fn is_compressed(data: &[u8]) -> bool {
        if let &[v] = &data[..1] {
            v >> 7 == 1
        } else {
            false
        }
    }
}

This test code converts &[u8; N] to Value (b"data".into()) and Bytes to Value, so we need to add the corresponding From implementations in src/pb/mod.rs:

impl<const N: usize> From<&[u8; N]> for Value {
    fn from(buf: &[u8; N]) -> Self {
        Bytes::copy_from_slice(&buf[..]).into()
    }
}

impl From<Bytes> for Value {
    fn from(buf: Bytes) -> Self {
        Self {
            value: Some(value::Value::Binary(buf)),
        }
    }
}

Run cargo test; all tests should pass.

At this point, we have implemented frame serialization (encode_frame) and deserialization (decode_frame) and verified them with tests. In network development, it’s important to keep logic and IO as separate as possible: this makes the code easier to test and easier to adapt when the IO layer changes. So far, this code touches no socket IO at all; it is pure logic. Next, we need to connect it to the TcpStream that the server and client use.

Before writing more network code, one question remains: how do we get the BytesMut that decode_frame() consumes out of the socket? Clearly, we first read 4 bytes to get the length N, then read N more bytes. Since this detail is closely tied to frames, we add another helper, read_frame(), to src/network/frame.rs:

/// Read a complete frame from the stream
pub async fn read_frame<S>(stream: &mut S, buf: &mut BytesMut) -> Result<(), KvError>
where
    S: AsyncRead + Unpin + Send,
{
    let header = stream.read_u32().await? as usize;
    let (len, _compressed) = decode_header(header);
    // Make sure the buffer has room for at least one complete frame
    buf.reserve(LEN_LEN + len);
    buf.put_u32(header as _);
    // Zero-fill the payload area, then read into it. Unlike an unsafe
    // advance_mut(), this never exposes uninitialized memory as a &mut [u8]
    // (which the reader is allowed to inspect), and if the read fails midway
    // the buffer holds zeros rather than garbage. Zeroing is usually cheap
    // compared with the network IO itself.
    let start = buf.len();
    buf.resize(start + len, 0);
    stream.read_exact(&mut buf[start..]).await?;
    Ok(())
}

We don’t want read_frame() to be tied to TcpStream, which would be inflexible, so we use a generic parameter S and require it to satisfy AsyncRead + Unpin + Send. Let’s look at these three constraints.

AsyncRead is tokio’s trait for asynchronous reads; it has a single method, poll_read():

pub trait AsyncRead {
    fn poll_read(
        self: Pin<&mut Self>, 
        cx: &mut Context<'_>, 
        buf: &mut ReadBuf<'_>
    ) -> Poll<Result<()>>;
}

Any type that implements AsyncRead automatically gets AsyncReadExt, with as many as 29 helper methods, through this blanket implementation:

impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
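The pattern at work is an extension trait with a blanket implementation. Here is a std-only, blocking sketch of the same idea. The trait `ReadU32Ext` and its method `read_u32_be` are hypothetical, modeled on what tokio’s `read_u32` does: any type implementing `std::io::Read`, such as `&[u8]` or `TcpStream`, gets the helper for free.

```rust
use std::io::{self, Read};

/// Hypothetical extension trait: helpers built purely on top of `Read`.
trait ReadU32Ext: Read {
    /// Read a big-endian u32, the byte order tokio's `read_u32` also uses.
    fn read_u32_be(&mut self) -> io::Result<u32> {
        let mut bytes = [0u8; 4];
        self.read_exact(&mut bytes)?;
        Ok(u32::from_be_bytes(bytes))
    }
}

/// Blanket impl: every `Read` type, sized or not, gets `ReadU32Ext`.
impl<R: Read + ?Sized> ReadU32Ext for R {}
```

Reading from the byte slice `[0x00, 0x00, 0x01, 0x2C, 0xFF]` yields 300 and leaves `[0xFF]` unread. Implementors only write the one required method, and every helper comes from the extension trait.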

We haven’t formally learned how to handle asynchrony yet, but we