Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit d22d464

Browse files
author
Ivan Butygin
authored
[MLIR] tbb fixes (#189)
* fix plier parallel * use loop to init reduce vars * transform to tbb parallel if the loop has the parallel attr or is the outermost loop * we don't need fix_tls_observer, also do not recreate task arena each time
1 parent ee7dfd7 commit d22d464

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
5050
{
5151
return mlir::failure();
5252
}
53-
if (!op->hasAttr(plier::attributes::getParallelName()))
53+
bool need_parallel = op->hasAttr(plier::attributes::getParallelName()) ||
54+
!op->getParentOfType<mlir::scf::ParallelOp>();
55+
if (!need_parallel)
5456
{
5557
return mlir::failure();
5658
}
@@ -85,31 +87,46 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
8587
auto reduce = rewriter.create<mlir::AllocaOp>(loc, reduce_type);
8688
auto index = static_cast<unsigned>(it.index());
8789
reduce_vars[index] = reduce;
88-
auto zero = getZeroVal(rewriter, loc, type);
89-
mapping.map(op.initVals()[index], zero);
90-
for (unsigned i = 0; i < max_concurrency; ++i)
90+
}
91+
92+
auto reduce_init_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args)
93+
{
94+
assert(args.empty());
95+
(void)args;
96+
for (auto it : llvm::enumerate(reduce_vars))
9197
{
92-
mlir::Value index = rewriter.create<mlir::ConstantIndexOp>(loc, i);
93-
rewriter.create<mlir::StoreOp>(loc, zero, reduce, index);
98+
auto reduce = it.value();
99+
auto type = op.getResultTypes()[it.index()];
100+
auto zero = getZeroVal(rewriter, loc, type);
101+
builder.create<mlir::StoreOp>(loc, zero, reduce, index);
94102
}
95-
}
103+
builder.create<mlir::scf::YieldOp>(loc);
104+
};
105+
106+
auto reduce_lower_bound = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
107+
auto reduce_upper_bound = rewriter.create<mlir::ConstantIndexOp>(loc, max_concurrency);
108+
auto reduce_step = rewriter.create<mlir::ConstantIndexOp>(loc, 1);
109+
rewriter.create<mlir::scf::ForOp>(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, llvm::None, reduce_init_body_builder);
96110

97111
auto& old_body = op.getLoopBody().front();
98112
auto orig_lower_bound = op.lowerBound().front();
99113
auto orig_upper_bound = op.upperBound().front();
100114
auto orig_step = op.step().front();
101115
auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index)
102116
{
117+
llvm::SmallVector<mlir::Value, 8> initVals(op.initVals().size());
103118
for (auto it : llvm::enumerate(op.initVals()))
104119
{
105120
auto reduce_var = reduce_vars[it.index()];
106121
auto val = builder.create<mlir::LoadOp>(loc, reduce_var, thread_index);
107-
mapping.map(it.value(), val);
122+
initVals[it.index()] = val;
108123
}
109124
auto new_op = mlir::cast<mlir::scf::ParallelOp>(builder.clone(*op, mapping));
125+
new_op->removeAttr(plier::attributes::getParallelName());
110126
assert(new_op->getNumResults() == reduce_vars.size());
111127
new_op.lowerBoundMutable().assign(lower_bound);
112128
new_op.upperBoundMutable().assign(upper_bound);
129+
new_op.initValsMutable().assign(initVals);
113130
for (auto it : llvm::enumerate(new_op->getResults()))
114131
{
115132
auto reduce_var = reduce_vars[it.index()];
@@ -119,10 +136,6 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
119136

120137
rewriter.create<plier::ParallelOp>(loc, orig_lower_bound, orig_upper_bound, orig_step, body_builder);
121138

122-
auto reduce_lower_bound = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
123-
auto reduce_upper_bound = rewriter.create<mlir::ConstantIndexOp>(loc, max_concurrency);
124-
auto reduce_step = rewriter.create<mlir::ConstantIndexOp>(loc, 1);
125-
126139
auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args)
127140
{
128141
assert(args.size() == reduce_vars.size());

numba/np/ufunc/tbbpool.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@ Implement parallel vectorize workqueue on top of Intel TBB.
4242

4343
static tbb::task_group *tg = NULL;
4444
static tbb::task_scheduler_init *tsi = NULL;
45+
46+
namespace
47+
{
48+
struct ThreadContext
49+
{
50+
ThreadContext(int n_threads):
51+
num_threads(n_threads),
52+
arena(n_threads) {}
53+
54+
int num_threads = 0;
55+
tbb::task_arena arena;
56+
};
57+
static ThreadContext* thread_context = nullptr;
58+
}
59+
4560
static int tsi_count = 0;
4661

4762
#ifdef _MSC_VER
@@ -209,15 +224,15 @@ parallel_for(void *fn, char **args, size_t *dimensions, size_t *steps, void *dat
209224
using parallel_for2_fptr = void(*)(size_t, size_t, size_t, void*);
210225
static void parallel_for2(size_t lower_bound, size_t upper_bound, size_t step, parallel_for2_fptr func, void* ctx)
211226
{
212-
auto num_threads = get_num_threads();
227+
auto context = thread_context;
228+
assert(nullptr != context);
229+
auto num_threads = context->num_threads;
213230
if(_DEBUG)
214231
{
215232
printf("parallel_for2 %d %d %d %d\n", (int)lower_bound, (int)upper_bound, (int)step, (int)num_threads);
216233
}
217-
tbb::task_arena limited(num_threads);
218-
fix_tls_observer observer(limited, num_threads);
219234

220-
limited.execute([&]
235+
context->arena.execute([&]
221236
{
222237
size_t count = (upper_bound - lower_bound - 1) / step + 1;
223238
size_t grain = std::max(size_t(1), std::min(count / num_threads / 2, size_t(64)));
@@ -284,6 +299,8 @@ static void unload_tbb(void)
284299
tbb::set_assertion_handler(orig);
285300
delete tsi;
286301
tsi = NULL;
302+
delete thread_context;
303+
thread_context = nullptr;
287304
}
288305
}
289306
#endif
@@ -300,6 +317,8 @@ static void launch_threads(int count)
300317
tg = new tbb::task_group;
301318
tg->run([] {}); // start creating threads asynchronously
302319

320+
thread_context = new ThreadContext(count);
321+
303322
_INIT_NUM_THREADS = count;
304323

305324
#ifndef _MSC_VER

0 commit comments

Comments
 (0)