设计一个 C++ 线程池

设计与实现了一个简单的 C++ 线程池,支持添加使用函数指针与参数的自定义任务

什么是线程池?

线程池是一种典型的 生产者消费者 模型,线程池管理多个线程并为它们分配任务,实现了对线程资源的复用,用于解决一些场景下频繁创建线程和销毁线程带来的性能损耗。下图来自 Wiki,演示了一个包含等待任务队列和完成任务队列的简单线程池的工作流程。

A sample thread pool (green boxes) with waiting tasks (blue) and completed tasks (yellow)

设计分析

初见线程池,实现一个含有基本功能的简易线程池还是比较简单的,我们只需要该线程池具备基础的添加任务并执行的功能,那么重点有 2 个,一个是用于管理各个工作线程的 线程池类,另一个是用于存放具体执行任务的 任务队列

任务队列

先来设计任务队列,我们将一个可执行的 函数对象 称为任务,任务队列用于存放用户提交的任务,以供线程池分配具体任务给池中的线程,对于这个简易的线程池,我们暂不考虑支持任务优先级以及任务抢占,只需关注任务的添加和取出即可。显然任务队列需要符合 FIFO(First Input First Output)的原则,那么我们可以使用 std::queue 来存放任务,此外我们需要关注的是任务队列的线程安全,多个线程同时操作单一任务队列时会产生线程不安全的问题,因此我们可以使用互斥锁来保障队列的线程安全。代码实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
template <typename T>
class TaskQueue
{
private:
    std::queue<T> queue_;
    std::mutex    queue_mutex_;

public:
    TaskQueue() {}
    ~TaskQueue() {}

    int size() {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        return queue_.size();
    }

    bool empty() {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        return queue_.empty();
    }

    void enqueue(T &t) {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        queue_.emplace(t);
    }

    bool dequeue(T &t) {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        if (queue_.empty()) return false;

        t = std::move(queue_.front());
        queue_.pop();

        return true;
    }
};

我们使用一个模板类来实现 TaskQueue,存放在任务队列中的任务类型可以是多样的,但在后续线程池的实现中我们将任务函数均包装为 std::function<void()> 以让代码更为简洁。该任务队列使用了 std::mutexstd::unique_lock 来保证线程安全,其中 std::mutex 为互斥锁,用于保护共享资源的访问,而 std::unique_lock 用于管理互斥锁,其利用了 RAII(Resource Acquisition Is Initialization)机制来实现锁的自动释放。该任务队列含有入队与出队的功能,线程池可将任务从中取出以分配给工作线程执行。

线程池类

线程池类符合 生产者消费者 模型,生产者是线程池,消费者是池中的工作线程。其中池最重要的功能便是实现任务的提交,即将任务加入任务队列并唤起线程去执行;工作线程则从任务队列中取出任务并执行后返回执行结果。

我们先来看看池中提交函数的实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
template <typename F, typename ...Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
    std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);

    auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);

    std::function<void()> wrapper_func = [task_ptr]() {
        (*task_ptr)();
    };
    
    task_queue_.enqueue(wrapper_func);

    // 唤醒一个线程执行新提交的任务
    thread_condition_.notify_one();

    return task_ptr->get_future();
}

我们需要支持用户通过传递函数指针与参数的方式提交任务,而参数的个数和类型是不确定的,因此我们需要一个模板函数,其中 typename ...Args 是一个模板参数包,用于接收不定数量的参数,可使用 args... 将参数展开。

再来看看这个函数头:
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
由于不同任务的返回值不同,这就到了发挥泛型编程能力的时候了,这里使用到了 尾置返回类型推导,实际上这个函数返回的是一个 std::future,其类型由 decltype(f(args...)) 推导得出,std::future 用于获取异步操作的返回值。

函数体内的操作简单来说就是将任务重新包装成一个 std::function<void()>,这样的目的是统一每个任务以便于将其加入任务队列,简化代码。由于我们通过 std::shared_ptr 定义了一个指向由 std::packaged_task 打包的任务的智能指针,最终可以通过 get_future() 接口获取与 std::packaged_task 关联的 std::future 对象以获得执行结果。

我们再来看看工作线程的实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Worker
{
private:
    int            id_;
    ThreadPoolDemo *pool_; // 所属的线程池

public:
    Worker(const int id, ThreadPoolDemo *pool)
    : id_(id)
    , pool_(pool) {}

    ~Worker() {}

    // 重载运算符
    void operator()() {
        std::function<void()> func;
        bool dequeued;

        while (!pool_->close_) {
            {
                // 使用函数块来管理互斥锁的生命周期
                std::unique_lock<std::mutex> lock(pool_->thread_mutex_);

                if (pool_->task_queue_.empty()) {
                    pool_->thread_condition_.wait(lock);
                }

                dequeued = pool_->task_queue_.dequeue(func);
            }
            if (dequeued) func();
        }
    }
};

由于 std::thread 在执行时会调用函数对象的 operator() 运算符,因此我们需要重载此运算符。

一般情况下,当单一线程被创建并执行完任务后它就会被自动销毁,那么我们如何保证工作线程的持久性呢?将 从任务队列中取任务并执行 这个操作作为一个工作线程的执行任务即可,我们可以使用一个循环去实现,那么当任务队列中没有任务时如何处理呢?这里我们使用 条件变量 来控制工作线程的状态,当任务队列中没有任务时,工作线程已经获得了互斥锁,调用 std::condition_variablewait() 接口可以让线程进入等待状态并释放它所持有的互斥锁,当该工作线程被 notify_one() 或者 notify_all() 通知后,会重新获得互斥锁,继续从任务队列中获取并执行任务。

完整代码

以下是线程池的完整代码实现:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// thread_pool_demo.h

#ifndef THREAD_POOL_DEMO_H
#define THREAD_POOL_DEMO_H

#include <queue>
#include <mutex>
#include <functional>
#include <vector>
#include <thread>
#include <condition_variable>
#include <future>

template <typename T>
class TaskQueue
{
private:
    std::queue<T> queue_;
    std::mutex    queue_mutex_;

public:
    TaskQueue() {}
    ~TaskQueue() {}

    int size() {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        return queue_.size();
    }

    bool empty() {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        return queue_.empty();
    }

    void enqueue(T &t) {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        queue_.emplace(t);
    }

    bool dequeue(T &t) {
        std::unique_lock<std::mutex> lock(queue_mutex_);

        if (queue_.empty()) return false;

        t = std::move(queue_.front());
        queue_.pop();

        return true;
    }
};

class ThreadPoolDemo
{
private:
    bool                             close_;
    TaskQueue<std::function<void()>> task_queue_;
    std::vector<std::thread>         threads_;
    std::mutex                       thread_mutex_;
    std::condition_variable          thread_condition_;

    class Worker
    {
    private:
        int            id_;
        ThreadPoolDemo *pool_; // 所属的线程池

    public:
        Worker(const int id, ThreadPoolDemo *pool)
        : id_(id)
        , pool_(pool) {}

        ~Worker() {}

        void operator()() {
            std::function<void()> func;
            bool dequeued;

            while (!pool_->close_) {
                {
                    std::unique_lock<std::mutex> lock(pool_->thread_mutex_);

                    if (pool_->task_queue_.empty()) {
                        pool_->thread_condition_.wait(lock);
                    }

                    dequeued = pool_->task_queue_.dequeue(func);
                }
                if (dequeued) func();
            }
        }
    };

public:
    ThreadPoolDemo(const unsigned int num_of_threads = std::thread::hardware_concurrency()) // 线程数默认为CPU的核数
    : close_(false)
    , threads_(std::vector<std::thread>(num_of_threads)) {}

    // 禁用拷贝构造函数
    ThreadPoolDemo(const ThreadPoolDemo &) = delete;
    ThreadPoolDemo(ThreadPoolDemo &&) = delete;
    ThreadPoolDemo &operator=(const ThreadPoolDemo &) = delete;
    ThreadPoolDemo &operator=(ThreadPoolDemo &&) = delete;

    ~ThreadPoolDemo() {}

    // 初始化,创建所有工作线程
    void init() {
        for (int i = 0; i < threads_.size(); ++i) {
            threads_.at(i) = std::thread(Worker(i, this));
        }
    }

    void close() {
        close_ = true;

        // 唤醒所有线程,完成剩余任务
        thread_condition_.notify_all();

        for (int i = 0; i < threads_.size(); ++i) {
            if (threads_.at(i).joinable()) {
                threads_.at(i).join();
            }
        }
    }

    template <typename F, typename ...Args>
    auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
        std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);

        auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);

        std::function<void()> wrapper_func = [task_ptr]() {
            (*task_ptr)();
        };
        
        task_queue_.enqueue(wrapper_func);
        thread_condition_.notify_one();

        return task_ptr->get_future();
    }
};

#endif

测试

以下是测试用例的完整代码实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// thread_pool_test.cpp

#include <iostream>
#include <random>
#include "thread_pool_demo.h"

std::random_device rd;
std::mt19937 mt(rd());
std::uniform_int_distribution<int> dist(-1000, 1000);
auto rand_sec = std::bind(dist, mt);

// 使用随机事件向线程池中添加任务
void simulate_hard_computation() {
    std::this_thread::sleep_for(std::chrono::milliseconds(1000 + rand_sec()));
}

void multiply(const int a, const int b) {
    simulate_hard_computation();
    const int res = a * b;
    std::cout << a << " * " << b << " = " << res << " Thread id: " << std::this_thread::get_id() << std::endl;
}

// 通过引用返回执行结果
void multiply_output(int &out, const int a, const int b) {
    simulate_hard_computation();
    out = a * b;
    std::cout << a << " * " << b << " = " << out << " Thread id: " << std::this_thread::get_id() << std::endl;
}

// 通过返回值返回执行结果
int multiply_return(const int a, const int b) {
    simulate_hard_computation();
    const int res = a * b;
    std::cout << a << " * " << b << " = " << res << " Thread id: " << std::this_thread::get_id() << std::endl;
    return res;
}

void test() {
    ThreadPoolDemo thread_pool(4);
    thread_pool.init();

    for (int i = 1; i < 4; ++i) {
        for (int j = 1; j < 11; ++j) {
            thread_pool.submit(multiply, i, j);
        }
    }

    int output = 0;
    auto future_1 = thread_pool.submit(multiply_output, std::ref(output), 5, 20);
    future_1.get();
    std::cout << "Last operation result is equals to " << output << std::endl;

    auto future_2 = thread_pool.submit(multiply_return, 13, 14);
    int res = future_2.get();
    std::cout << "Last operation result is equals to " << res << std::endl;

    thread_pool.close();
}

int main(void) 
{
    std::cout << "Test start ..." << std::endl;
    test();
    std::cout << "Test end ..." << std::endl;

    return 0;
}

测试结果:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
Test start ...
1 * 2 = 2 Thread id: 140183345608256
1 * 3 = 3 Thread id: 140183337215552
1 * 4 = 4 Thread id: 140183328822848
1 * 1 = 1 Thread id: 140183320430144
1 * 5 = 5 Thread id: 140183345608256
1 * 8 = 8 Thread id: 140183320430144
1 * 9 = 9 Thread id: 140183345608256
1 * 10 = 10 Thread id: 140183320430144
1 * 6 = 6 Thread id: 140183337215552
1 * 7 = 7 Thread id: 140183328822848
2 * 4 = 8 Thread id: 140183328822848
2 * 1 = 2 Thread id: 140183345608256
2 * 3 = 6 Thread id: 140183337215552
2 * 2 = 4 Thread id: 140183320430144
2 * 6 = 12 Thread id: 140183345608256
2 * 7 = 14 Thread id: 140183337215552
2 * 8 = 16 Thread id: 140183320430144
2 * 5 = 10 Thread id: 140183328822848
2 * 9 = 18 Thread id: 140183345608256
2 * 10 = 20 Thread id: 140183337215552
3 * 1 = 3 Thread id: 140183320430144
3 * 2 = 6 Thread id: 140183328822848
3 * 4 = 12 Thread id: 140183337215552
3 * 5 = 15 Thread id: 140183320430144
3 * 3 = 9 Thread id: 140183345608256
3 * 6 = 18 Thread id: 140183328822848
3 * 8 = 24 Thread id: 140183320430144
3 * 7 = 21 Thread id: 140183337215552
3 * 9 = 27 Thread id: 140183345608256
3 * 10 = 30 Thread id: 140183328822848
5 * 20 = 100 Thread id: 140183320430144
Last operation result is equals to 100
13 * 14 = 182 Thread id: 140183337215552
Last operation result is equals to 182
Test end ...

我们使用随机时间向线程池中提交任务,可以从输出的 Thread id 看到我们共使用了 4 个线程执行了任务,符合该线程池的设计目的。

参考

  1. Thread pool - Wikipedia
  2. 基于 C++11 实现线程池 - 知乎 (zhihu.com)
  3. progschj/ThreadPool: A simple C++11 Thread Pool implementation (github.com)