Rust线程池源码拆解及实现

看了rust的threadpool crate,照虎画猫!

基本模型

基本模型是,存在一个任务channel,发送头在ThreadPool里,通过execute方法发送任务闭包,已初始化的若干个工作线程一直处在loop里,行为是从任务通道里取任务,执行,取任务……这样循环。因为有loop,所以线程就一直没退出,免去了回收、再创建的消耗。

threadpool.png

任务闭包签名为FnOnce() + Send + 'static,没有参数,没有返回值,所以需要对真实的调用函数做适配,捕获参数,构造闭包,如果有返回值就直接用共享的变量收集,比如下面这样,每个任务是产生输入参数的平方数,结果收集到Vec

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
fn do_something(i: usize) -> usize {
thread::sleep(Duration::from_micros(50 * i as u64));
i * i
}

fn main() {
let pool = ThreadPool::new(4);
let ans = Arc::new(Mutex::new(Vec::<usize>::new()));
for i in 0..10 {
let ans = ans.clone();
let job = move || {
let result = do_something(i);
(*ans.lock().unwrap()).push(result);
};
pool.execute(job)
}
pool.join();
assert_eq!(285usize, (*ans.lock().unwrap()).iter().sum());
}

搭建骨架

任务闭包用Box装箱来传指针,动态查找虽然有性能损耗,但是闭包结构的复制消耗也相对固定,免得有个什么大闭包结构被莫名复制到线程栈上。

1
type Job = Box<dyn FnOnce() + Send + 'static>;

初始化一个工作线程,就是把一个channel的接收端加锁,传给线程,同步地从channel中拿出任务来执行。Mutex加锁,Arc创建共享所有权。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
fn spawn_in_pool(receiver: Arc<Mutex<Receiver<Job>>>) {
thread::spawn(move || loop {
let recv_result = {
// 以最少时间占用锁
receiver.lock().expect("can't lock channel receiver").recv()
};

let job_fn = match recv_result {
Ok(job) => job,
Err(_) => break, // 这个错误说明另一端已经drop,退出该线程就可以
};
job_fn();
});
}

ThreadPool现在只要一个channel的传输入口,用于任务发送。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
struct ThreadPool {
sender: Sender<Job>,
}

impl ThreadPool {
...
fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
self.sender.send(Box::new(f)).unwrap();
}
...
}

初始化工作自然就是要建立起任务channel,然后调用spawn_in_pool,初始化要求个数的工作线程

1
2
3
4
5
6
7
8
9
10
11
impl ThreadPool {
fn new(threads_num: usize) -> ThreadPool {
let (tx, rx) = channel::<Job>(); // 建channel
let receiver = Arc::new(Mutex::new(rx)); // 初始化工作线程
for _ in 0..threads_num {
let receiver = receiver.clone();
spawn_in_pool(receiver);
}
ThreadPool { sender: tx } // 返回pool
}
}

join作用就等待工作线程完成全部任务,这个事情下阶段来做,这里先随便sleep替代一下,能通过测试用例就可以。

1
2
3
4
5
6
impl ThreadPool {
fn join(&self) {
println!("waiting...");
thread::sleep(Duration::from_secs(2));
}
}

很好,测试通过!🎉🎉🎉🎉线程池写完了!膨胀得先去吃个烧烤庆祝一下!

正确地join

什么时候说明任务执行完了?

  • 任务队列任务计数为0,记为queue_cnt == 0
  • 没有正在执行任务的线程,记为active_cnt == 0

所以我们需要共享这两个变量给所有工作线程。工作线程取出一个任务,queue_cnt -= 1,执行任务前active_cnt += 1,完成后active_cnt -= 1。显然ThreadPool也需要共享这两个变量,executequeue_cnt += 1

join时,可以简单粗暴地轮询这些两个变量判断。但是这种通知的同步场景,使用条件变量来挂起等待线程,有事件时重启线程更高效。

谁在等待?join方法,使用条件变量的wait挂起线程,等待事件发生。

谁来通知?工作线程,当完成一个任务后,检查任务是否全部完成,如果是,就notify_all(因为ThreadPool也可以在多个线程里并发地发送任务)。

所以现在要共享的变量又多了4个,只能单独写一个结构来封装他们,然后共享给工作线程。对于queue_cntactive_cnt这种基本类型,有AtomicXxx等原子类型供选择,但是内存顺序暂时还没看懂,先用Mutex代替吧,性能低点就低。为了方便,直接把判断和通知的方法实现在这个SharedData里。

条件变量和互斥锁配套使用,互斥锁的基本作用是,对条件变量的操作,无论是waitnotify_*,都是互斥进行的,保证条件变量在增加、移除暂停线程时的安全性,另外互斥锁本身也可以携带信息,作为条件判断的对象。但是这个里的cond_guard只起第一个作用,条件的判断由queue_cntactive_cnt承担。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
struct SharedData {
receiver: Mutex<Receiver<Job>>,
queue_cnt: Mutex<usize>,
active_cnt: Mutex<usize>,
cond_guard: Mutex<()>,
cond: Condvar,
}

impl SharedData {
fn has_task(&self) -> bool {
*self.queue_cnt.lock().expect("can't lock queue_cnt") > 0
|| *self.active_cnt.lock().expect("can't lock active_cnt") > 0
}

fn notify_when_no_tasks(&self) {
if !self.has_task() {
*self.cond_guard.lock().expect("can't lock cond_guard");
self.cond.notify_all(); // 独占地通知
}
}
}

接下来就是围绕这个新的中间结构,改造之前的骨架就可以,比如spawn_in_pool的要改签名,增加变量维护的逻辑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
fn spawn_in_pool(data: Arc<SharedData>) {
// ...取出任务
{
*data.queue_cnt.lock().expect("can't lock queue_cnt") -= 1;
}
{
*data.active_cnt.lock().expect("can't lock queue_cnt") += 1;
}
job_fn();
{
*data.active_cnt.lock().expect("can't lock queue_cnt") -= 1;
}

data.notify_when_no_tasks();
}

ThreadPool本身也多了一个data字段,方便joinexecute时使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
struct ThreadPool {
sender: Sender<Job>,
data: Arc<SharedData>,
}

impl ThreadPool {
fn new(threads_num: usize) -> ThreadPool {
let (tx, rx) = channel::<Job>();
let data = Arc::new(SharedData {
receiver: Mutex::new(rx),
queue_cnt: Mutex::new(0),
active_cnt: Mutex::new(0),
cond_guard: Mutex::new(()),
cond: Condvar::new(),
});

for _ in 0..threads_num {
let data = data.clone();
spawn_in_pool(data);
}
ThreadPool {
sender: tx,
data: data,
}
}

fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
{
*self.data.queue_cnt.lock().expect("can't lock queue_cnt") += 1;
}
self.sender.send(Box::new(f)).unwrap();
}

fn join(&self) {
if !self.data.has_task() { // 一个小优化,有机会避免一次cond_guard的加锁
return;
}
let mut guard = self.data.cond_guard.lock().expect("can't lock cond guard");
while self.data.has_task() {
guard = self.data.cond.wait(guard).unwrap();
}
}
}

🐮🍺,测试通过!🎉🎉🎉🎉阶段性成果!成功完成了正常工作的join

动态设置工作线程数

线程池创建之后,如果想要动态增加或者减少线程数量可咋搞?分两种情况来考虑。

  1. 工作线程数增加,差多少个,就调用多少次spawn_in_pool,补齐差距。
  2. 工作线程减少,那就要让一些工作线程自动地break,退出后自动被回收。怎么让这些多余的线程知道自己被下岗了呢?现在正在执行任务的线程不能动,那些执行完,再次循环去工作队列里拿任务前,就可以检查多少线程正在工作active_cnt,如果大于设定的最大工作线程数,那么自己就主动退出。

所以思路就是

  • SharedData里增加一个max_threads_cnt字段;
  • ThreadPool实现set_threads_num方法,如果是增加,就调用spawn_in_pool补齐;
  • 工作线程的每次取任务前都检查一下active_cnt是否大于max_threads_cnt,是就break;

后两点变动的代码是

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
impl ThreadPool {
// ...
fn set_threads_num(&self, size: usize) {
let mut prev_cnt = size;
{
let mut p_thread_cnt = self
.data
.max_threads_cnt
.lock()
.expect("can't lock max_threads_cnt");
prev_cnt = *p_thread_cnt;
*p_thread_cnt = size;
}
if let Some(n) = size.checked_sub(prev_cnt) {
for _ in 0..n {
let data = self.data.clone();
spawn_in_pool(data);
}
}
}
}


fn spawn_in_pool(data: Arc<SharedData>) {
{
if *data.active_cnt.lock().unwrap() >= *data.max_threads_cnt.lock().unwrap() {
break
}
}
// ...取出任务
}

工作线程panic怎么办

如果任务闭包job_fn执行产生了panic,那么会产生什么影响?首先是线程退出

  • 线程退出,工作线程不等于当初设定的线程数
  • job_fn后面的代码没有执行:active_cnt没有减一、没有通知join

所以因panic退出时,得执行被跳过的逻辑,并且重启线程。

panic退出时还要执行代码?这咋办呢,只能是某个结构实现Droptrait,在drop方法中执行没有执行完的逻辑。下面是源码中Sentinel哨兵的实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
impl<'a> Sentinel<'a> {
// 在工作线程进入loop前,new一个新鲜的sentinel
fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
Sentinel {
shared_data: shared_data,
active: true,
}
}

/// loop正常break出来,执行cancel.
fn cancel(mut self) {
self.active = false;
}
}

impl<'a> Drop for Sentinel<'a> {
fn drop(&mut self) {
if self.active { // 说明不是break正常退出
self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
if thread::panicking() {
self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
}
self.shared_data.no_work_notify_all();
spawn_in_pool(self.shared_data.clone())
}
}
}

然后我的疑问是:为什么会有哨兵的active字段和thread::panicking()的双重检查?非正常退出除了panic还有其他的情况?如果是其他情况,为什么进入active分支后直接对active_cnt进行减1操作,这个操作的潜在假设是逻辑流一定中止于闭包调用呀,那么除了job_fn本身panic,还能有什么可能满足这个假设呢?我还试过了,源代码实现中,如果获取获取工作队列锁失败,expect产生的panic也会进入drop方法,造成active_count减1后溢出,行为完全失控。

我个人看法,正确的解决办法,应该是drop只应对一种情况,就是任务闭包panic,这是唯一不可控的代码,其他情况如果出现error,就地处理,比如工作队列获取锁,如果失败,标记sentinel的drop不要处理,主动spand_in_pool挽救。像我这里没有使用AtomicXxx,获取状态字段时出错,那么就是严重的运行时问题,已经无法维持线程池的正确状态,也是标记drop不要处理,同时标记线程池停止工作,抛出panic。

还有什么可以做的吗?

  • 现在异常只会计数,是不是可以给每个任务增加名字,join之后可以重新执行失败的任务?
  • 任务超时怎么办?ThreadPool结构就得主动杀死线程,所以还得共享一个<线程id-状态>的映射?
  • 把工作队列的性质也做成初始化参数?用户可以选择使用有界队列和无界队列?
  • 工作线程panic的问题,还待确认解决方案。
  • 源代码中的test_threads_num_decreasing测试用例有误,第二次execute的数量应该多于new_threads_num

OK,线程池的源码解析就到这里吧,挺有收获的

  1. 终于明白了线程池的模型,原来是队列+多个loop线程👏
  2. 熟悉了rust中的线程同步工具。算是对Unique+Share角度的复习,这个角度思考真的非常有用!ArcMutex啥的看得更清楚了。

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!