35 Practical Project: How to Implement a Basic MPSC Channel #

Hello, I’m Chen Tian.

Through the last two lectures, I believe you have realized that although concurrency primitives seem low-level and mysterious, they are not as hard to implement as you might imagine, especially in Rust. In [Lecture 33], we implemented a simple SpinLock in just a few dozen lines of code.

That may not feel entirely satisfying, though, since a SpinLock is not a commonly used concurrency primitive. So today, let's try to implement the very widely used MPSC channel, shall we?

Previously we discussed using an MPSC channel with a search engine's index writer: many contexts may update the index (they could be threads or asynchronous tasks), but the IndexWriter must be unique. To avoid locking on every access to the IndexWriter, we can have the many contexts send messages over an MPSC channel, then read those messages in the single thread that exclusively owns the IndexWriter, which is very efficient.

Okay, let's take a look at the basic functionality of the MPSC channel we will implement today. For simplicity's sake, we're only concerned with the unbounded MPSC channel: when the queue's capacity is insufficient, it automatically grows, so at any given time a producer can write data without being blocked; when there is no data in the queue, however, the consumer is blocked.

Test-Driven Design #

Previously we designed interfaces and data structures from the perspective of requirements. Today, we will try another method, completely from the user’s perspective, using examples (tests) to drive the design of interfaces and data structures.

Requirement 1 #

To implement the MPSC channel we just described, what are our requirements? First, producers can produce data and consumers can consume it; that is, the basic send/recv. We describe this requirement with the following unit test 1:

#[test]
fn channel_should_work() {
    let (mut s, mut r) = unbounded();
    s.send("hello world!".to_string()).unwrap();
    let msg = r.recv().unwrap();
    assert_eq!(msg, "hello world!");
}

Here, the unbounded() function creates a sender and a receiver. Sender has a send() method to send data, and Receiver has a recv() method to receive it. The overall interface is designed to be consistent with std::sync::mpsc to reduce the cognitive burden on users.
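
For reference, the same flow through the standard library's channel looks almost identical, which is exactly the familiarity we are aiming for:

use std::sync::mpsc;

fn main() {
    // std's channel() is also unbounded; note how the shape of the API matches ours
    let (tx, rx) = mpsc::channel();
    tx.send("hello world!".to_string()).unwrap();
    let msg = rx.recv().unwrap();
    assert_eq!(msg, "hello world!");
}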

To implement such an interface, what data structure do we need? First, the producer and consumer share a queue. Last time, we mentioned that we could use a VecDeque. Obviously, this queue needs to be protected by a Mutex during insertion and removal. So we could have a structure like this:

struct Shared<T> {
    queue: Mutex<VecDeque<T>>,
}

pub struct Sender<T> {
    shared: Arc<Shared<T>>,
}

pub struct Receiver<T> {
    shared: Arc<Shared<T>>,
}

This data structure should be able to satisfy unit test 1.
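
To see why, here is a minimal sketch, assuming for the moment that we neither block nor report errors, of how send/recv could work on top of this structure (the real implementation comes later):

use std::{
    collections::VecDeque,
    sync::{Arc, Mutex},
};

struct Shared<T> {
    queue: Mutex<VecDeque<T>>,
}

pub struct Sender<T> {
    shared: Arc<Shared<T>>,
}

pub struct Receiver<T> {
    shared: Arc<Shared<T>>,
}

impl<T> Sender<T> {
    /// Push to the back of the shared queue; the lock is held only for the push
    pub fn send(&mut self, t: T) {
        self.shared.queue.lock().unwrap().push_back(t);
    }
}

impl<T> Receiver<T> {
    /// Pop from the front; returns None when the queue is empty (no blocking yet)
    pub fn recv(&mut self) -> Option<T> {
        self.shared.queue.lock().unwrap().pop_front()
    }
}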

Requirement 2 #

Since multiple senders are allowed to send data to the channel, we describe this requirement with unit test 2:

#[test]
fn multiple_senders_should_work() {
    let (mut s, mut r) = unbounded();
    let mut s1 = s.clone();
    let mut s2 = s.clone();
    let t = thread::spawn(move || {
        s.send(1).unwrap();
    });
    let t1 = thread::spawn(move || {
        s1.send(2).unwrap();
    });
    let t2 = thread::spawn(move || {
        s2.send(3).unwrap();
    });
    for handle in [t, t1, t2] {
        handle.join().unwrap();
    }

    let mut result = [r.recv().unwrap(), r.recv().unwrap(), r.recv().unwrap()];
    // In this test, the order of data arrival is non-deterministic, so we sort them before assert
    result.sort();

    assert_eq!(result, [1, 2, 3]);
}

The existing data structure already satisfies this requirement; we just need Sender to implement the Clone trait. Writing this test, however, feels a bit awkward, since this line of code keeps being repeated:

let mut result = [r.recv().unwrap(), r.recv().unwrap(), r.recv().unwrap()];

As we have emphasized before, DRY (Don't Repeat Yourself) matters in test code too. So while writing this test, you might wonder: could we provide an Iterator implementation for the Receiver? Let's hold on to this idea.
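
To preview where that idea leads, one minimal way to write such an implementation, assuming recv() returns a Result<T> as in our interface, is to end the iteration as soon as recv() errors out:

impl<T> Iterator for Receiver<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Map any recv() error (e.g. "no sender left") to the end of iteration
        self.recv().ok()
    }
}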

Requirement 3 #

Next, consider the requirement that when the queue is empty, the thread the receiver runs on should be blocked. How do we test this? It isn't simple: we have no direct way to observe the state of a thread.

However, we can infer indirectly that the thread is blocked by checking whether it has exited. The reasoning is straightforward: if the thread is neither doing work nor exiting, it must be blocked. Once it is blocked, we send more data; the consumer thread should be woken up and continue working, so in the end the queue length should be 0. Let's look at unit test 3:

#[test]
fn receiver_should_be_blocked_when_nothing_to_read() {
    let (mut s, r) = unbounded();
    let mut s1 = s.clone();
    thread::spawn(move || {
        for (idx, i) in r.into_iter().enumerate() {
            // If data is read, ensure it is consistent with the sent data
            assert_eq!(idx, i);
        }
        // If there's nothing to read, the thread should be blocked, so this line
        // should never run; reaching it indicates a logic error
        unreachable!("the receiver should have been blocked");
    });

    thread::spawn(move || {
        for i in 0..100usize {
            s.send(i).unwrap();
        }
    });

    // 1ms is enough for the producer to send 100 messages and the consumer to consume 100 messages and be blocked
    thread::sleep(Duration::from_millis(1));

    // Send data again to wake up the consumer
    for i in 100..200usize {
        s1.send(i).unwrap();
    }

    // Leave some time for the receiver to process
    thread::sleep(Duration::from_millis(1));

    // If the receiver is awakened normally to process, then the data in the queue will be read off entirely
    assert_eq!(s1.total_queued_items(), 0);
}

In this test code, we assume that receiver implements Iterator, and we also assume that sender provides a method total_queued_items(). These can be handled during implementation.

Take some time to read this code carefully and think through its logic. Although the code is straightforward and easy to understand, turning a complete requirement into appropriate test code takes quite a bit of thought.

Good, to support blocking when the queue is empty, we need to use Condvar. So Shared needs to be modified a bit:

struct Shared<T> {
    queue: Mutex<VecDeque<T>>,
    available: Condvar,
}

This way, when implementing the Receiver’s recv() method, we can block the thread if no data can be read:

// Take the lock on the queue
let mut inner = self.shared.queue.lock().unwrap();
// ... suppose no data can be read ...
// Hand the MutexGuard to the Condvar, which releases the lock and suspends
// the current thread; wait() returns a new guard once the thread is woken up
inner = self.shared.available.wait(inner).unwrap();

Requirement 4 #

Continuing the multiple-senders line of thought: what if all the Senders go out of scope while the Receiver keeps receiving? Once there is no more data to read, shouldn't an error occur, so the caller knows there is no longer a producer on the other side of the channel and no more data can be read?

Let’s write unit test 4:

#[test]
fn last_sender_drop_should_error_when_receive() {
    let (s, mut r) = unbounded();
    let s1 = s.clone();
    let senders = [s, s1];
    let total = senders.len();

    // Use and drop the sender
    for mut sender in senders {
        thread::spawn(move || {
            sender.send("hello").unwrap();
            // Sender is dropped here
        })
        .join()
        .unwrap();
    }

    // Although there are no senders left, the receiver can still receive data already in the queue
    for _ in 0..total {
        r.recv().unwrap();
    }

    // However, reading more data will result in an error
    assert!(r.recv().is_err());
}

This test is straightforward. You can imagine what kind of data structure could achieve this purpose.

First, each time a Sender is cloned, the sender count should be incremented; when a Sender is dropped, the count should be decremented. We then give the Receiver a total_senders() method to read that count, and when the count is 0 and there is no data left in the queue, recv() reports an error.

With this idea in mind, think about it, what data structure should this counter use? Should it be protected with a lock?

Ha, you must have thought of atomics. Right, we can use an AtomicUsize. So, the Shared data structure needs an update:

struct Shared<T> {
    queue: Mutex<VecDeque<T>>,
    available: Condvar,
    senders: AtomicUsize,
}
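
As a sketch of how this counter could be maintained (an assumption on my part; only send() and recv() are fleshed out later in this lecture), Clone increments it and Drop decrements it, with one subtlety: when the last Sender drops, a receiver blocked on an empty queue has to be woken up so it can observe the error condition. Assuming the Ordering import:

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        // One more producer now shares the same Shared
        self.shared.senders.fetch_add(1, Ordering::AcqRel);
        Self {
            shared: self.shared.clone(),
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        let old = self.shared.senders.fetch_sub(1, Ordering::AcqRel);
        // If this was the last sender, wake any blocked receiver so that
        // recv() can see there are no senders left and return an error
        if old <= 1 {
            self.shared.available.notify_all();
        }
    }
}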

Requirement 5 #

Since recv() reports an error when no Senders are left, shouldn't send() likewise return an error when there is no Receiver left? This requirement is similar to the previous one, so we won't explain it again. Look at unit test 5:

#[test]
fn receiver_drop_should_error_when_send() {
    let (mut s1, mut s2) = {
        let (s, _) = unbounded();
        let s1 = s.clone();
        let s2 = s.clone();
        (s1, s2)
    };

    assert!(s1.send(1).is_err());
    assert!(s2.send(1).is_err());
}

Here, we create a channel, and as soon as two Senders are generated, the Receiver is immediately dropped. Both Senders will encounter errors when sending.

Similarly, the Shared data structure needs an update:

struct Shared<T> {
    queue: Mutex<VecDeque<T>>,
    available: Condvar,
    senders: AtomicUsize,
    receivers: AtomicUsize,
}

Implementing the MPSC Channel #

We have now written five unit tests, we’ve thoroughly understood the requirements, and we have a basic design for the interface and data structure. Next, let’s write the implementation code.

Create a new project with cargo new con_utils --lib. Add anyhow as a dependency in Cargo.toml. In lib.rs, write just pub mod channel;, then create src/channel.rs and put in all the test cases, plus the data structures and interfaces they use:

use anyhow::{anyhow, Result};
use std::{
    collections::VecDeque,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Condvar, Mutex,
    },
};

/// Sender
pub struct Sender<T> {
    shared: Arc<Shared<T>>,
}

/// Receiver
pub struct Receiver<T> {
    shared: Arc<Shared<T>>,
}

/// Sender and receiver share a VecDeque protected by a Mutex and signaled via a
/// Condvar; we also record the number of senders and receivers
struct Shared<T> {
    queue: Mutex<VecDeque<T>>,
    available: Condvar,
    senders: AtomicUsize,
    receivers: AtomicUsize,
}

impl<T> Sender<T> {
    /// The producer writes a piece of data
    pub fn send(&mut self, t: T) -> Result<()> {
        todo!()
    }

    pub fn total_receivers(&self) -> usize {
        todo!()
    }

    pub fn total_queued_items(&self) -> usize {
        todo!()
    }
}

impl<T> Receiver<T> {
    pub fn recv(&mut self) -> Result<T> {
        todo!()
    }

    pub fn total_senders(&self) -> usize {
        todo!()
    }
}

impl<T> Iterator for Receiver<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        todo!()
    }
}

/// Clone the sender
impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        todo!()
    }
}

/// Drop sender
impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        todo!()
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        todo!()
    }
}

/// Create an unbounded channel
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    todo!()
}

#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;
    // All test cases are omitted here
}

Currently this code compiles, but since nothing is implemented, cargo test will fail. Next, we will implement the functionality step by step.

Creating an Unbounded Channel #

The interface for creating an unbounded channel is very simple:

pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    let shared = Shared::default();
    let shared = Arc::new(shared);
    (
        Sender {
            shared: shared.clone(),
        },
        Receiver { shared },
    )
}

const INITIAL_SIZE: usize = 32;
impl<T> Default for Shared<T> {
    fn default() -> Self {
        Self {
            queue: Mutex::new(VecDeque::with_capacity(INITIAL_SIZE)),
            available: Condvar::new(),
            senders: AtomicUsize::new(1),
            receivers: AtomicUsize::new(1),
        }
    }
}

Since default() is used here to create the Shared structure, we need to implement Default for it. Initially, we have 1 producer and 1 consumer.

Implementing the Consumer #

For the consumer, we mainly need to implement the recv method.

In recv, if there is data in the queue, we return it directly; if there is no data and all producers have left, we return an error; if there is no data but producers remain, we block the consumer's thread:

impl<T> Receiver<T> {
    pub fn recv(&mut self) -> Result<T> {
        // Lock the queue
        let mut inner = self.shared.queue.lock().unwrap();
        loop {
            match inner.pop_front() {
                // Return the data when it is read, the lock is released
                Some(t) => {
                    return Ok(t);
                }
                // No data can be read, and all producers have exited, release the lock and return an error
                None if self.total_senders() == 0 => return Err(anyhow!("no sender left")),
                // No data to read, hand the lock over to the available Condvar, it will release the lock and suspend the thread, waiting for notification
                None => {
                    // Once Condvar is awakened, it will return a MutexGuard, we can loop back to get the data
                    // This is why Condvar should be used within a loop
                    inner = self
                        .shared
                        .available
                        .wait(inner)
                        .map_err(|_| anyhow!("lock poisoned"))?;
                }
            }
        }
    }

    pub fn total_senders(&self) -> usize {
        self.shared.senders.load(Ordering::SeqCst)
    }
}

Take note of the use of Condvar here.

Its wait() method takes a MutexGuard, releases the underlying Mutex, and suspends the thread. When the thread is notified, wait() re-acquires the lock, obtains a new MutexGuard, and returns it. So here we write:

inner = self.shared.available.wait(inner).map_err(|_| anyhow!("lock poisoned"))?;

Since recv() must return a value, after the thread is suspended and later woken up, we loop back and try to read data again. This is why the logic is wrapped in a loop {}. We also decided during design that when a sender sends data, it should notify the suspended consumer; so when implementing Sender's send(), we need to do the corresponding notify handling.

Remember the drop for the consumer:

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        self.shared.receivers.fetch_sub(1, Ordering::AcqRel);
    }
}

It's simple: when a consumer goes away, the receiver count is decremented by one. Since senders never block in an unbounded channel, there is nobody to notify here.

Implementing the Producer #

Next, let’s see how to implement the functionality of the producer.

First, if there are no consumers left, sending should produce an error. Ideally we would define our own error type with thiserror, but to keep the code simple here we use the anyhow! macro to generate an ad-hoc error. If consumers remain, we take the lock on the VecDeque, push the data, and release the lock immediately:

impl<T> Sender<T> {
    /// The producer writes a piece of data
    pub fn send(&mut self, t: T) -> Result<()> {
        // If there are no consumers left, raise an error when writing
        if self.total_receivers() == 0 {
            return Err(anyhow!("no receiver left"));
        }

        // Lock, access VecDeque, push in data, then immediately release the lock
        let was_empty = {
            let mut inner = self.shared.queue.lock().unwrap();
            let empty = inner.is_empty();
            inner.push_back(t);
            empty
        };

        // Notify any suspended and waiting consumers there is data
        if was_empty {
            self.shared.available.notify_one();
        }

        Ok(())
    }

    pub fn total_receivers(&self) -> usize {
        self.shared.receivers.load(Ordering::SeqCst)
    }

    pub fn total_queued_items(&self) -> usize {
        let queue = self.shared.queue.lock().unwrap();
        queue.len()
    }
}

Here, total_receivers() loads the counter with Ordering::SeqCst, which guarantees that all threads see the operations on receivers in a single global order, so the value we read is the most recent one.

**When pushing data, we need to check whether the queue was empty beforehand, because only when the queue is empty can the consumer be suspended on the Condvar waiting for data.** In that case we call notify_one() to wake it up; if the queue was not empty, no consumer can be blocked, so the notification can be skipped.