// markrs/thread_pool.rs

1use std::{
2    fmt,
3    sync::{Arc, Mutex, mpsc},
4    thread,
5};
6
7use log::warn;
8
9pub struct ThreadPool {
10    workers: Vec<Worker>,
11    sender: mpsc::Sender<Job>,
12}
13
14impl ThreadPool {
15    pub fn build(size: usize) -> Result<Self, Error> {
16        if size == 0 {
17            return Err(Error::PoolCreation {
18                message: "Thread pool size must be greater than 0".to_string(),
19            });
20        }
21
22        let (sender, receiver) = mpsc::channel();
23        let receiver = Arc::new(Mutex::new(receiver));
24
25        let mut workers = Vec::with_capacity(size);
26        for id in 0..size {
27            workers.push(Worker::build(id, Arc::clone(&receiver)).map_err(|e| {
28                Error::PoolCreation {
29                    message: format!("Failed to create worker thread {}: {}", id, e),
30                }
31            })?);
32        }
33
34        Ok(ThreadPool { workers, sender })
35    }
36
37    pub fn join_all(self) {
38        drop(self.sender); // Close the channel to signal workers to exit
39
40        for worker in self.workers {
41            if let Err(e) = worker.thread.join() {
42                warn!("Worker thread {} failed to join: {:?}", worker.id, e);
43            }
44        }
45    }
46
47    pub fn execute<F>(&self, f: F) -> Result<(), Error>
48    where
49        F: FnOnce() + Send + 'static,
50    {
51        let job: Job = Box::new(f);
52
53        self.sender.send(job).map_err(|e| {
54            warn!("Failed to send job to thread pool: {e}");
55            Error::JobExecution {
56                message: format!("Failed to send job to thread pool: {e}"),
57            }
58        })
59    }
60}
61
/// One pool thread paired with the index it was created with.
///
/// The struct is private to this module, so its fields need no `pub` of their own.
struct Worker {
    /// Index handed to the worker at creation; it appears in log and error messages.
    id: usize,
    /// Handle used to wait for the worker's thread to finish.
    thread: thread::JoinHandle<()>,
}
66
67impl Worker {
68    fn build(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Result<Self, Error> {
69        let builder = thread::Builder::new();
70
71        let thread = builder
72            .spawn(move || {
73                loop {
74                    let job_result = {
75                        let receiver = receiver.lock().expect("Failed to lock receiver mutex");
76                        receiver.recv()
77                    };
78
79                    match job_result {
80                        Ok(job) => {
81                            job();
82                        }
83                        Err(_) => {
84                            break; // Exit the loop if the channel is closed
85                        }
86                    }
87                }
88            })
89            .map_err(|e| Error::WorkerCreation {
90                message: format!("Failed to spawn thread {id}: {e}"),
91            })?;
92
93        Ok(Worker { id, thread })
94    }
95}
96
/// A unit of work for the pool: a boxed closure that runs once and may be sent to another thread.
///
/// `Box<dyn Trait>` already defaults its object lifetime bound to `'static`, so the bound
/// is not spelled out.
type Job = Box<dyn FnOnce() + Send>;
98
/// Failures reported by the thread pool.
///
/// Every variant carries a human-readable `message` describing what went wrong.
#[derive(Debug)]
pub enum Error {
    /// The pool could not be built (zero size requested, or a worker failed to start).
    PoolCreation {
        /// Description of the failure.
        message: String,
    },
    /// A job could not be handed over to the workers.
    JobExecution {
        /// Description of the failure.
        message: String,
    },
    /// A worker thread could not be spawned.
    WorkerCreation {
        /// Description of the failure.
        message: String,
    },
}
105
106impl fmt::Display for Error {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        match self {
109            Error::PoolCreation { message } => write!(f, "Thread pool creation error: {message}"),
110            Error::JobExecution { message } => write!(f, "Job execution error: {message}"),
111            Error::WorkerCreation { message } => write!(f, "Worker creation error: {message}"),
112        }
113    }
114}
115
116impl std::error::Error for Error {}