Skip to content

Commit 85d923c

Browse files
committed
Cleanup/refactor of thread pool.
1 parent c459ad2 commit 85d923c

29 files changed

+184
-197
lines changed

Common/Cpp/Concurrency/AsyncDispatcher.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,6 @@
1515
namespace PokemonAutomation{
1616

1717

18-
AsyncTask::~AsyncTask(){
19-
std::unique_lock<std::mutex> lg(m_lock);
20-
m_cv.wait(lg, [this]{ return m_finished; });
21-
}
22-
bool AsyncTask::is_finished() const{
23-
std::lock_guard<std::mutex> lg(m_lock);
24-
return m_finished;
25-
}
26-
void AsyncTask::rethrow_exceptions(){
27-
if (!m_stopped_with_error.load(std::memory_order_acquire)){
28-
return;
29-
}
30-
std::unique_lock<std::mutex> lg(m_lock);
31-
if (m_exception){
32-
std::rethrow_exception(m_exception);
33-
}
34-
}
35-
void AsyncTask::wait_and_rethrow_exceptions(){
36-
std::unique_lock<std::mutex> lg(m_lock);
37-
m_cv.wait(lg, [this]{ return m_finished; });
38-
if (m_exception){
39-
std::rethrow_exception(m_exception);
40-
}
41-
}
42-
void AsyncTask::signal(){
43-
std::lock_guard<std::mutex> lg(m_lock);
44-
m_finished = true;
45-
m_cv.notify_all();
46-
}
4718

4819

4920
#if 0

Common/Cpp/Concurrency/AsyncDispatcher.h

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
* From: https://github.com/PokemonAutomation/
44
*
55
* This class is meant for asynchronous tasks, not for parallel computation.
6-
* This class will always spawn enough threads run all tasks in parallel.
6+
* This class will always spawn enough threads run all tasks in parallel.
77
*
8-
* If you need to spam a bunch of compute tasks in parallel, use ParallelTaskRunner.
8+
* If you need to spam a bunch of compute tasks in parallel, use ComputationThreadPool.
99
*
1010
*/
1111

@@ -15,53 +15,14 @@
1515
#include <vector>
1616
#include <deque>
1717
#include <functional>
18-
#include <atomic>
1918
#include <mutex>
2019
#include <condition_variable>
2120
#include <thread>
22-
#include <exception>
21+
#include "AsyncTask.h"
2322

2423
namespace PokemonAutomation{
2524

2625

27-
class AsyncTask{
28-
public:
29-
// Wait for the task to finish before destructing. Doesn't rethrow exceptions.
30-
~AsyncTask();
31-
32-
bool is_finished() const;
33-
34-
// If the task ended with an exception, rethrow it here.
35-
// This does not clear the exception.
36-
void rethrow_exceptions();
37-
38-
// Wait for the task to finish. Will rethrow any exceptions.
39-
void wait_and_rethrow_exceptions();
40-
41-
42-
private:
43-
template <class... Args>
44-
AsyncTask(Args&&... args)
45-
: m_task(std::forward<Args>(args)...)
46-
, m_finished(false)
47-
, m_stopped_with_error(false)
48-
{}
49-
void signal();
50-
51-
private:
52-
friend class FireForgetDispatcher;
53-
friend class AsyncDispatcher;
54-
friend class ParallelTaskRunnerCore;
55-
56-
std::function<void()> m_task;
57-
bool m_finished;
58-
std::atomic<bool> m_stopped_with_error;
59-
std::exception_ptr m_exception;
60-
mutable std::mutex m_lock;
61-
std::condition_variable m_cv;
62-
};
63-
64-
6526

6627
class AsyncDispatcher{
6728
public:
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Async Task
2+
*
3+
* From: https://github.com/PokemonAutomation/
4+
*
5+
*/
6+
7+
#include "AsyncTask.h"
8+
9+
namespace PokemonAutomation{
10+
11+
12+
13+
AsyncTask::~AsyncTask(){
14+
std::unique_lock<std::mutex> lg(m_lock);
15+
m_cv.wait(lg, [this]{ return m_finished; });
16+
}
17+
bool AsyncTask::is_finished() const{
18+
std::lock_guard<std::mutex> lg(m_lock);
19+
return m_finished;
20+
}
21+
void AsyncTask::rethrow_exceptions(){
22+
if (!m_stopped_with_error.load(std::memory_order_acquire)){
23+
return;
24+
}
25+
std::unique_lock<std::mutex> lg(m_lock);
26+
if (m_exception){
27+
std::rethrow_exception(m_exception);
28+
}
29+
}
30+
void AsyncTask::wait_and_rethrow_exceptions(){
31+
std::unique_lock<std::mutex> lg(m_lock);
32+
m_cv.wait(lg, [this]{ return m_finished; });
33+
if (m_exception){
34+
std::rethrow_exception(m_exception);
35+
}
36+
}
37+
void AsyncTask::signal(){
38+
std::lock_guard<std::mutex> lg(m_lock);
39+
m_finished = true;
40+
m_cv.notify_all();
41+
}
42+
43+
44+
45+
46+
47+
}

Common/Cpp/Concurrency/AsyncTask.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Async Task
2+
*
3+
* From: https://github.com/PokemonAutomation/
4+
*
5+
*/
6+
7+
#ifndef PokemonAutomation_AsyncTask_H
8+
#define PokemonAutomation_AsyncTask_H
9+
10+
#include <functional>
11+
#include <atomic>
12+
#include <mutex>
13+
14+
namespace PokemonAutomation{
15+
16+
17+
class AsyncTask{
18+
public:
19+
// Wait for the task to finish before destructing. Doesn't rethrow exceptions.
20+
~AsyncTask();
21+
22+
bool is_finished() const;
23+
24+
// If the task ended with an exception, rethrow it here.
25+
// This does not clear the exception.
26+
void rethrow_exceptions();
27+
28+
// Wait for the task to finish. Will rethrow any exceptions.
29+
void wait_and_rethrow_exceptions();
30+
31+
32+
private:
33+
template <class... Args>
34+
AsyncTask(Args&&... args)
35+
: m_task(std::forward<Args>(args)...)
36+
, m_finished(false)
37+
, m_stopped_with_error(false)
38+
{}
39+
void signal();
40+
41+
private:
42+
friend class FireForgetDispatcher;
43+
friend class AsyncDispatcher;
44+
friend class ComputationThreadPoolCore;
45+
46+
std::function<void()> m_task;
47+
bool m_finished;
48+
std::atomic<bool> m_stopped_with_error;
49+
std::exception_ptr m_exception;
50+
mutable std::mutex m_lock;
51+
std::condition_variable m_cv;
52+
};
53+
54+
55+
56+
}
57+
#endif

Common/Cpp/Concurrency/ParallelTaskRunner.cpp renamed to Common/Cpp/Concurrency/ComputationThreadPool.cpp

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
*
55
*/
66

7-
#if _WIN32
8-
#include <Windows.h>
9-
#endif
7+
#include <deque>
108
#include "Common/Cpp/PanicDump.h"
119
#include "Common/Cpp/Containers/Pimpl.tpp"
1210
#include "Common/Cpp/CpuUtilization/CpuUtilization.h"
1311
#include "Common/Cpp/Stopwatch.h"
14-
#include "ParallelTaskRunner.h"
12+
#include "AsyncTask.h"
13+
#include "ComputationThreadPool.h"
1514

1615
//#include <iostream>
1716
//using std::cout;
@@ -25,14 +24,14 @@ namespace PokemonAutomation{
2524

2625

2726

28-
class ParallelTaskRunnerCore final{
27+
class ComputationThreadPoolCore final{
2928
public:
30-
ParallelTaskRunnerCore(
29+
ComputationThreadPoolCore(
3130
std::function<void()>&& new_thread_callback,
3231
size_t starting_threads,
3332
size_t max_threads
3433
);
35-
~ParallelTaskRunnerCore();
34+
~ComputationThreadPoolCore();
3635

3736
size_t current_threads() const{
3837
std::lock_guard<std::mutex> lg(m_lock);
@@ -94,7 +93,7 @@ class ParallelTaskRunnerCore final{
9493

9594

9695

97-
ParallelTaskRunner::ParallelTaskRunner(
96+
ComputationThreadPool::ComputationThreadPool(
9897
std::function<void()>&& new_thread_callback,
9998
size_t starting_threads,
10099
size_t max_threads
@@ -106,29 +105,29 @@ ParallelTaskRunner::ParallelTaskRunner(
106105
max_threads
107106
)
108107
{}
109-
ParallelTaskRunner::~ParallelTaskRunner() = default;
110-
size_t ParallelTaskRunner::current_threads() const{
108+
ComputationThreadPool::~ComputationThreadPool() = default;
109+
size_t ComputationThreadPool::current_threads() const{
111110
return m_core->current_threads();
112111
}
113-
size_t ParallelTaskRunner::max_threads() const{
112+
size_t ComputationThreadPool::max_threads() const{
114113
return m_core->max_threads();
115114
}
116-
WallDuration ParallelTaskRunner::cpu_time() const{
115+
WallDuration ComputationThreadPool::cpu_time() const{
117116
return m_core->cpu_time();
118117
}
119-
void ParallelTaskRunner::ensure_threads(size_t threads){
118+
void ComputationThreadPool::ensure_threads(size_t threads){
120119
m_core->ensure_threads(threads);
121120
}
122121
//void ParallelTaskRunner::wait_for_everything(){
123122
// m_core->wait_for_everything();
124123
//}
125-
std::unique_ptr<AsyncTask> ParallelTaskRunner::blocking_dispatch(std::function<void()>&& func){
124+
std::unique_ptr<AsyncTask> ComputationThreadPool::blocking_dispatch(std::function<void()>&& func){
126125
return m_core->blocking_dispatch(std::move(func));
127126
}
128-
std::unique_ptr<AsyncTask> ParallelTaskRunner::try_dispatch(std::function<void()>& func){
127+
std::unique_ptr<AsyncTask> ComputationThreadPool::try_dispatch(std::function<void()>& func){
129128
return m_core->try_dispatch(func);
130129
}
131-
void ParallelTaskRunner::run_in_parallel(
130+
void ComputationThreadPool::run_in_parallel(
132131
const std::function<void(size_t index)>& func,
133132
size_t start, size_t end,
134133
size_t block_size
@@ -143,7 +142,7 @@ void ParallelTaskRunner::run_in_parallel(
143142

144143

145144

146-
ParallelTaskRunnerCore::ParallelTaskRunnerCore(
145+
ComputationThreadPoolCore::ComputationThreadPoolCore(
147146
std::function<void()>&& new_thread_callback,
148147
size_t starting_threads,
149148
size_t max_threads
@@ -157,7 +156,7 @@ ParallelTaskRunnerCore::ParallelTaskRunnerCore(
157156
spawn_thread();
158157
}
159158
}
160-
ParallelTaskRunnerCore::~ParallelTaskRunnerCore(){
159+
ComputationThreadPoolCore::~ComputationThreadPoolCore(){
161160
{
162161
std::lock_guard<std::mutex> lg(m_lock);
163162
m_stopping = true;
@@ -172,7 +171,7 @@ ParallelTaskRunnerCore::~ParallelTaskRunnerCore(){
172171
}
173172
}
174173

175-
WallDuration ParallelTaskRunnerCore::cpu_time() const{
174+
WallDuration ComputationThreadPoolCore::cpu_time() const{
176175
// TODO: Don't lock the entire queue.
177176
WallDuration ret = WallDuration::zero();
178177
std::lock_guard<std::mutex> lg(m_lock);
@@ -184,22 +183,22 @@ WallDuration ParallelTaskRunnerCore::cpu_time() const{
184183
}
185184

186185

187-
void ParallelTaskRunnerCore::ensure_threads(size_t threads){
186+
void ComputationThreadPoolCore::ensure_threads(size_t threads){
188187
std::lock_guard<std::mutex> lg(m_lock);
189188
while (m_threads.size() < threads){
190189
spawn_thread();
191190
}
192191
}
193192
#if 0
194-
void ParallelTaskRunnerCore::wait_for_everything(){
193+
void ComputationThreadPoolCore::wait_for_everything(){
195194
std::unique_lock<std::mutex> lg(m_lock);
196195
m_dispatch_cv.wait(lg, [this]{
197196
return m_queue.size() + m_busy_count == 0;
198197
});
199198
}
200199
#endif
201200

202-
std::unique_ptr<AsyncTask> ParallelTaskRunnerCore::blocking_dispatch(std::function<void()>&& func){
201+
std::unique_ptr<AsyncTask> ComputationThreadPoolCore::blocking_dispatch(std::function<void()>&& func){
203202
std::unique_ptr<AsyncTask> task(new AsyncTask(std::move(func)));
204203

205204
std::unique_lock<std::mutex> lg(m_lock);
@@ -219,7 +218,7 @@ std::unique_ptr<AsyncTask> ParallelTaskRunnerCore::blocking_dispatch(std::functi
219218

220219
return task;
221220
}
222-
std::unique_ptr<AsyncTask> ParallelTaskRunnerCore::try_dispatch(std::function<void()>& func){
221+
std::unique_ptr<AsyncTask> ComputationThreadPoolCore::try_dispatch(std::function<void()>& func){
223222
std::lock_guard<std::mutex> lg(m_lock);
224223

225224
if (m_queue.size() + m_busy_count >= m_max_threads){
@@ -241,7 +240,7 @@ std::unique_ptr<AsyncTask> ParallelTaskRunnerCore::try_dispatch(std::function<vo
241240
}
242241

243242

244-
void ParallelTaskRunnerCore::run_in_parallel(
243+
void ComputationThreadPoolCore::run_in_parallel(
245244
const std::function<void(size_t index)>& func,
246245
size_t start, size_t end,
247246
size_t block_size
@@ -261,6 +260,7 @@ void ParallelTaskRunnerCore::run_in_parallel(
261260
size_t blocks = (total + block_size - 1) / block_size;
262261

263262
std::vector<std::unique_ptr<AsyncTask>> tasks;
263+
tasks.reserve(blocks);
264264
for (size_t c = 0; c < blocks; c++){
265265
tasks.emplace_back(blocking_dispatch([=, &func]{
266266
size_t s = start + c * block_size;
@@ -279,15 +279,15 @@ void ParallelTaskRunnerCore::run_in_parallel(
279279

280280

281281

282-
void ParallelTaskRunnerCore::spawn_thread(){
282+
void ComputationThreadPoolCore::spawn_thread(){
283283
ThreadData& handle = m_threads.emplace_back();
284284
handle.thread = std::thread(
285285
run_with_catch,
286286
"ParallelTaskRunner::thread_loop()",
287287
[&, this]{ thread_loop(handle); }
288288
);
289289
}
290-
void ParallelTaskRunnerCore::thread_loop(ThreadData& data){
290+
void ComputationThreadPoolCore::thread_loop(ThreadData& data){
291291
data.handle = current_thread_handle();
292292

293293
if (m_new_thread_callback){
@@ -336,7 +336,7 @@ void ParallelTaskRunnerCore::thread_loop(ThreadData& data){
336336

337337

338338

339-
//template class Pimpl<ParallelTaskRunnerCore>;
339+
//template class Pimpl<ComputationThreadPoolCore>;
340340

341341

342342

0 commit comments

Comments
 (0)