37 Practical Session (5): Building a Simple KV Server - Network Security #

Hello, I’m Chen Tian.

In the last lecture, we finished the networking layer of the KV server. Security is an inseparable part of networking, and we must consider it whenever we build networked applications. Of course, if we are not chasing extreme performance, we could use a framework such as gRPC, which offers good performance and also secures traffic with TLS.

So, how do we use TLS to ensure the security between client and server when our application’s architecture is based on TCP?

Generating x509 Certificates #

To use TLS, we first need x509 certificates. With them, the client can verify that the server is a trusted one, and the server can in turn verify the client to confirm that it is trusted as well.

For testing convenience, we need to be able to generate our own CA certificates, server certificates, and even client certificates. I won’t go into the details of certificate generation today, but I previously created a library called certify, which can be used to generate various certificates. We can add this library to our Cargo.toml:

[dev-dependencies]
...
certify = "0.3"
...

Then, create a fixtures directory at the project root to store the certificates, and create an examples/gen_cert.rs file with the following code:

use anyhow::Result;
use certify::{generate_ca, generate_cert, load_ca, CertType, CA};
use tokio::fs;

struct CertPem {
    cert_type: CertType,
    cert: String,
    key: String,
}

#[tokio::main]
async fn main() -> Result<()> {
    let pem = create_ca()?;
    gen_files(&pem).await?;
    let ca = load_ca(&pem.cert, &pem.key)?;
    let pem = create_cert(&ca, &["kvserver.acme.inc"], "Acme KV server", false)?;
    gen_files(&pem).await?;
    let pem = create_cert(&ca, &[], "awesome-device-id", true)?;
    gen_files(&pem).await?;
    Ok(())
}

fn create_ca() -> Result<CertPem> {
    let (cert, key) = generate_ca(
        &["acme.inc"],
        "CN",
        "Acme Inc.",
        "Acme CA",
        None,
        Some(10 * 365),
    )?;
    Ok(CertPem {
        cert_type: CertType::CA,
        cert,
        key,
    })
}

fn create_cert(ca: &CA, domains: &[&str], cn: &str, is_client: bool) -> Result<CertPem> {
    let (days, cert_type) = if is_client {
        (Some(365), CertType::Client)
    } else {
        (Some(5 * 365), CertType::Server)
    };
    let (cert, key) = generate_cert(ca, domains, "CN", "Acme Inc.", cn, None, is_client, days)?;

    Ok(CertPem {
        cert_type,
        cert,
        key,
    })
}

async fn gen_files(pem: &CertPem) -> Result<()> {
    let name = match pem.cert_type {
        CertType::Client => "client",
        CertType::Server => "server",
        CertType::CA => "ca",
    };
    fs::write(format!("fixtures/{}.cert", name), pem.cert.as_bytes()).await?;
    fs::write(format!("fixtures/{}.key", name), pem.key.as_bytes()).await?;
    Ok(())
}

This code is simple: it first generates a CA certificate, then uses that CA to generate a server certificate and a client certificate, and writes all of them to the newly created fixtures directory. Run cargo run --example gen_cert to generate the files; we will use these certificates and keys in the tests later.

Using TLS in the KV Server #

TLS is currently the main application layer security protocol, widely used to protect protocols that are built on top of TCP, such as MySQL, HTTP, and many others. Even if a network application is used only within an internal network, it is dangerous to operate without security protocols to protect it.

The following diagram, sourced from Wikimedia, shows the TLS handshake between the client and the server:

[Image: TLS handshake between client and server]

For the KV server, once TLS is in use, the protocol data is encapsulated as shown in the following diagram:

[Image: KV server protocol data encapsulated in TLS]

So today, we need to add TLS support to the network processing we did last time, ensuring that the communication between the KV server’s client and server is strictly protected and as secure as possible from eavesdropping, tampering, and counterfeiting by third parties.

Let’s see how to implement TLS.

Many people get nervous when they hear about TLS or SSL, because there have been many unpleasant experiences with openssl in the past. The openssl codebase is too large and unwieldy, the API is not user-friendly, and compiling and linking are troublesome.

However, the experience of using TLS in Rust is quite pleasant. Rust has excellent wrappers for openssl, as well as rustls, which is written in Rust and doesn’t rely on openssl. Tokio goes further and provides TLS support that fits the tokio ecosystem, with both openssl and rustls versions available.

Today we’ll use tokio-rustls to write TLS support. I believe that during the implementation process, you will see how effortless it is to add the TLS protocol to protect the network layer in an application.

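First, add tokio-rustls to your Cargo.toml (the code below also calls into the rustls-native-certs crate, so it needs to be listed as a dependency too):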

[dependencies]
...
tokio-rustls = "0.22"
...

Then create src/network/tls.rs with the following code (remember to include this file in src/network/mod.rs):

use std::io::Cursor;
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::{internal::pemfile, Certificate, ClientConfig, ServerConfig};
use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, PrivateKey, RootCertStore};
use tokio_rustls::webpki::DNSNameRef;
use tokio_rustls::TlsConnector;
use tokio_rustls::{
    client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream, TlsAcceptor,
};

use crate::KvError;

/// KV Server self-defined ALPN (Application-Layer Protocol Negotiation)
const ALPN_KV: &str = "kv";

/// Holds TLS ServerConfig and provides a method to transform the underlying protocol into TLS
#[derive(Clone)]
pub struct TlsServerAcceptor {
    inner: Arc<ServerConfig>,
}

/// Holds TLS Client and provides a method to transform the underlying protocol into TLS
#[derive(Clone)]
pub struct TlsClientConnector {
    pub config: Arc<ClientConfig>,
    pub domain: Arc<String>,
}

impl TlsClientConnector {
    /// Load client cert/CA cert and generate ClientConfig
    pub fn new(
        domain: impl Into<String>,
        identity: Option<(&str, &str)>,
        server_ca: Option<&str>,
    ) -> Result<Self, KvError> {
        let mut config = ClientConfig::new();

        // If a client certificate is provided, load it
        if let Some((cert, key)) = identity {
            let certs = load_certs(cert)?;
            let key = load_key(key)?;
            config.set_single_client_cert(certs, key)?;
        }

        // Load the local trusted root certificate chain
        config.root_store = match rustls_native_certs::load_native_certs() {
            Ok(store) | Err((Some(store), _)) => store,
            Err((None, error)) => return Err(error.into()),
        };

        // If the CA certificate that signed the server exists, load it, so even if the server certificate is not in the root certificate chain, the CA certificate can verify it
        if let Some(cert) = server_ca {
            let mut buf = Cursor::new(cert);
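            // Report a malformed CA certificate as an error instead of panicking
            config
                .root_store
                .add_pem_file(&mut buf)
                .map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;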
        }

        Ok(Self {
            config: Arc::new(config),
            domain: Arc::new(domain.into()),
        })
    }

    /// Initiate TLS protocol, converting the underlying stream into a TLS stream
    pub async fn connect<S>(&self, stream: S) -> Result<ClientTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())
            .map_err(|_| KvError::Internal("Invalid DNS name".into()))?;

        let stream = TlsConnector::from(self.config.clone())
            .connect(dns, stream)
            .await?;

        Ok(stream)
    }
}

impl TlsServerAcceptor {
    /// Load server cert/CA cert and generate ServerConfig
    pub fn new(cert: &str, key: &str, client_ca: Option<&str>) -> Result<Self, KvError> {
        let certs = load_certs(cert)?;
        let key = load_key(key)?;

        let mut config = match client_ca {
            None => ServerConfig::new(NoClientAuth::new()),
            Some(cert) => {
                // If the client certificate is issued by a certain CA, then load this CA certificate into the trust chain
                let mut cert = Cursor::new(cert);
                let mut client_root_cert_store = RootCertStore::empty();
                client_root_cert_store
                    .add_pem_file(&mut cert)
                    .map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;

                let client_auth = AllowAnyAuthenticatedClient::new(client_root_cert_store);
                ServerConfig::new(client_auth)
            }
        };

        // Load the server certificate
        config
            .set_single_cert(certs, key)
            .map_err(|_| KvError::CertifcateParseError("server", "cert"))?;
        config.set_protocols(&[Vec::from(&ALPN_KV[..])]);

        Ok(Self {
            inner: Arc::new(config),
        })
    }

    /// Initiate TLS protocol, converting the underlying stream into a TLS stream
    pub async fn accept<S>(&self, stream: S) -> Result<ServerTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let acceptor = TlsAcceptor::from(self.inner.clone());
        Ok(acceptor.accept(stream).await?)
    }
}

fn load_certs(cert: &str) -> Result<Vec<Certificate>, KvError> {
    let mut cert = Cursor::new(cert);
    pemfile::certs(&mut cert).map_err(|_| KvError::CertifcateParseError("server", "cert"))
}

fn load_key(key: &str) -> Result<PrivateKey, KvError> {
    let mut cursor = Cursor::new(key);

    // First try to load the private key using PKCS8
    if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) {
        if !keys.is_empty() {
            return Ok(keys.remove(0));
        }
    }

    // Then try to load the RSA key
    cursor.set_position(0);
    if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) {
        if !keys.is_empty() {
            return Ok(keys.remove(0));
        }
    }

    // Unsupported private key type
    Err(KvError::CertifcateParseError("private", "key"))
}

This code defines two structs, TlsServerAcceptor and TlsClientConnector. Although it runs to over 100 lines, its main job is just to build the ServerConfig/ClientConfig that tokio-rustls requires from the provided certificates.

Because TLS verifies that a certificate was issued by a trusted CA, the CA certificates must also be loaded. Although in web development we mostly use only server certificates, TLS actually supports mutual verification, where the server can also check whether the client’s certificate was issued by a CA it recognizes.
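The load_key() function above tries PKCS8 first and then falls back to an RSA key. In PEM files, these two encodings carry different labels: PKCS8 keys begin with `-----BEGIN PRIVATE KEY-----`, while RSA (PKCS#1) keys begin with `-----BEGIN RSA PRIVATE KEY-----`. As a rough sketch of that distinction, the following standard-library-only code inspects just the label. KeyFormat and detect_key_format are hypothetical names made up for illustration; they are not part of rustls, and the key body is not parsed at all.

```rust
/// Hypothetical helper type: the two PEM key encodings that `load_key` tries.
#[derive(Debug, PartialEq)]
enum KeyFormat {
    Pkcs8,
    Rsa,
}

/// Return the format of the first private-key block found in `pem`,
/// based only on the `-----BEGIN ...-----` label (the body is not parsed).
fn detect_key_format(pem: &str) -> Option<KeyFormat> {
    pem.lines().map(str::trim).find_map(|line| match line {
        "-----BEGIN PRIVATE KEY-----" => Some(KeyFormat::Pkcs8),
        "-----BEGIN RSA PRIVATE KEY-----" => Some(KeyFormat::Rsa),
        _ => None,
    })
}

fn main() {
    // Placeholder bodies: only the labels matter for this sketch.
    let pkcs8 = "-----BEGIN PRIVATE KEY-----\nMIIB...\n-----END PRIVATE KEY-----\n";
    let rsa = "-----BEGIN RSA PRIVATE KEY-----\nMIIB...\n-----END RSA PRIVATE KEY-----\n";

    println!("{:?}", detect_key_format(pkcs8)); // Some(Pkcs8)
    println!("{:?}", detect_key_format(rsa)); // Some(Rsa)
    println!("{:?}", detect_key_format("not a key")); // None
}
```

A key with any other label yields None here, which corresponds to the "unsupported private key type" branch at the end of load_key().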

With the config taken care of, the core logic of this code is the connect() method for the client and the accept() method for the server. Both accept a stream that satisfies AsyncRead + AsyncWrite + Unpin + Send:

/// Initiate TLS protocol, converting the underlying stream into a TLS stream
pub async fn connect<S>(&self, stream: S) -> Result<ClientTlsStream<S>, KvError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())
        .map_err(|_| KvError::Internal("Invalid DNS name".into()))?;

    let stream = TlsConnector::from(self.config.clone())
        .connect(dns, stream)
        .await?;

    Ok(stream)
}

/// Initiate TLS protocol, converting the underlying stream into a TLS stream
pub async fn accept<S>(&self, stream: S) -> Result<ServerTlsStream<S>, KvError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    let acceptor = TlsAcceptor::from(self.inner.clone());
    Ok(acceptor.accept(stream).await?)
}

After TlsConnector or TlsAcceptor has handled connect/accept, we get back a TlsStream, which also satisfies AsyncRead + AsyncWrite + Unpin + Send, so all subsequent operations can be performed on it. Handling TLS in just over a hundred lines of code is pretty painless, isn’t it?
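The key idea here is that a wrapper takes any stream satisfying certain traits and returns a new stream that satisfies the same traits, so the upper layers never notice the difference. To make that pattern concrete without pulling in tokio, here is a minimal synchronous sketch built on std::io::Read and std::io::Write. XorStream, send and recv are hypothetical names for illustration only, and XOR with a fixed byte is of course NOT encryption; it merely stands in for the transformation that TLS performs.

```rust
use std::io::{self, Cursor, Read, Write};

/// Toy stand-in for a TLS stream: XORs every byte with a fixed key.
/// This is NOT encryption; it only illustrates the wrapping pattern.
struct XorStream<S> {
    inner: S,
    key: u8,
}

impl<S: Read> Read for XorStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Read from the underlying stream, then decode in place.
        let n = self.inner.read(buf)?;
        for b in &mut buf[..n] {
            *b ^= self.key;
        }
        Ok(n)
    }
}

impl<S: Write> Write for XorStream<S> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Encode first, then hand the bytes to the underlying stream.
        let encoded: Vec<u8> = buf.iter().map(|b| b ^ self.key).collect();
        self.inner.write(&encoded)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

/// Upper-layer code depends only on the traits, not on the concrete stream.
fn send<S: Write>(stream: &mut S, msg: &[u8]) -> io::Result<()> {
    stream.write_all(msg)?;
    stream.flush()
}

fn recv<S: Read>(stream: &mut S, len: usize) -> io::Result<Vec<u8>> {
    let mut buf = vec![0u8; len];
    stream.read_exact(&mut buf)?;
    Ok(buf)
}

fn main() -> io::Result<()> {
    // The same `send` works on a plain in-memory stream...
    let mut plain = Cursor::new(Vec::new());
    send(&mut plain, b"hello world!")?;

    // ...and on the wrapped one.
    let mut wrapped = XorStream { inner: Cursor::new(Vec::new()), key: 0x5a };
    send(&mut wrapped, b"hello world!")?;

    // The bytes on the "wire" differ from the message...
    let wire = wrapped.inner.into_inner();
    assert_ne!(wire, b"hello world!");

    // ...but reading through the wrapper restores it.
    let mut reader = XorStream { inner: Cursor::new(wire), key: 0x5a };
    assert_eq!(recv(&mut reader, 12)?, b"hello world!");
    Ok(())
}
```

The async version in this lecture follows the same shape: only the trait bounds change from Read + Write to AsyncRead + AsyncWrite + Unpin + Send.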

Let’s continue writing a test:

#[cfg(test)]
mod tests {

    use std::net::SocketAddr;

    use super::*;
    use anyhow::Result;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
    };

    const CA_CERT: &str = include_str!("../../fixtures/ca.cert");
    const CLIENT_CERT: &str = include_str!("../../fixtures/client.cert");
    const CLIENT_KEY: &str = include_str!("../../fixtures/client.key");
    const SERVER_CERT: &str = include_str!("../../fixtures/server.cert");
    const SERVER_KEY: &str = include_str!("../../fixtures/server.key");

    #[tokio::test]
    async fn tls_should_work() -> Result<()> {
        let ca = Some(CA_CERT);

        let addr = start_server(None).await?;

        let connector = TlsClientConnector::new("kvserver.acme.inc", None, ca)?;
        let stream = TcpStream::connect(addr).await?;
        let mut stream = connector.connect(stream).await?;
        stream.write_all(b"hello world!").await?;
        let mut buf = [0; 12];
        stream.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"hello world!");

        Ok(())
    }

    #[tokio::test]
    async fn tls_with_client_cert_should_work() -> Result<()> {
        let client_identity = Some((CLIENT_CERT, CLIENT_KEY));
        let ca = Some(CA_CERT);

        let addr = start_server(ca.clone()).await?;

        let connector = TlsClientConnector::new("kvserver.acme.inc", client_identity, ca)?;
        let stream = TcpStream::connect(addr).await?;
        let mut stream = connector.connect(stream).await?;
        stream.write_all(b"hello world!").await?;
        let mut buf = [0; 12];
        stream.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"hello world!");

        Ok(())
    }

    #[tokio::test]
    async fn tls_with_bad_domain_should_not_work() -> Result<()> {
        let addr = start_server(None).await?;

        let connector = TlsClientConnector::new("kvserver1.acme.inc", None, Some(CA_CERT))?;
        let stream = TcpStream::connect(addr).await?;
        let result = connector.connect(stream).await;

        assert!(result.is_err());

        Ok(())
    }

    async fn start_server(ca: Option<&str>) -> Result<SocketAddr> {
        let acceptor = TlsServerAcceptor::new(SERVER_CERT, SERVER_KEY, ca)?;

        let echo = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = echo.local_addr().unwrap();

        tokio::spawn(async move {
            let (stream, _) = echo.accept().await.unwrap();
            let mut stream = acceptor.accept(stream).await.unwrap();
            let mut buf = [0; 12];
            stream.read_exact(&mut buf).await.unwrap();
            stream.write_all(&buf).await.unwrap();
        });

        Ok(addr)
    }
}

This test code uses the include_str! macro to load files into strings at compile-time, storing them in the RODATA segment. We test three scenarios: standard TLS connections, TLS connections with a client certificate, and a bad domain provided by the client. Running cargo test, all tests pass.
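The bad-domain test fails because the server certificate was issued for kvserver.acme.inc, while the client asks for kvserver1.acme.inc. The sketch below shows the gist of that name check using only the standard library. It is a deliberately simplified assumption of what happens: name_matches is a hypothetical function, and real verification in rustls/webpki does much more, including validating the certificate chain against the CA.

```rust
/// Hypothetical, simplified check: does the name requested by the client
/// match any DNS name listed in the server certificate?
/// DNS names are compared case-insensitively.
fn name_matches(cert_dns_names: &[&str], requested: &str) -> bool {
    cert_dns_names
        .iter()
        .any(|name| name.eq_ignore_ascii_case(requested))
}

fn main() {
    // The server certificate generated earlier covers this single domain.
    let cert_dns_names = ["kvserver.acme.inc"];

    println!("{}", name_matches(&cert_dns_names, "kvserver.acme.inc")); // true
    println!("{}", name_matches(&cert_dns_names, "kvserver1.acme.inc")); // false
}
```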

Making KV Client/Server Support TLS #

Once TLS tests pass, we can add TLS support to kvs and kvc.

Thanks to our interface design along the way, especially since ProstClientStream/ProstServerStream accept generic parameters, TLS code can be seamlessly integrated. For example, on the client side:

// Newly added code
let connector = TlsClientConnector::new("kvserver.acme.inc", None, Some(ca_cert))?;

let stream = TcpStream::connect(addr).await?;

// Newly added code
let stream = connector.connect(stream).await?;

let mut client = ProstClientStream::new(stream);

Apart from creating the connector, only one extra line is needed: it converts the TcpStream into a TlsStream before the stream is handed to ProstClientStream, and TLS is supported seamlessly.

Let’s look at the complete code, src/server.rs:

use anyhow::Result;
use kv3::{MemTable, ProstServerStream, Service, ServiceInner, TlsServerAcceptor};
use tokio::net::TcpListener;
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let addr = "127.0.0.1:9527";

    // Obtain from configuration file later
    let server_cert = include_str!("../fixtures/server.cert");
    let