@@ -50,7 +50,9 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
5050 {
5151 return mlir::failure ();
5252 }
53- if (!op->hasAttr (plier::attributes::getParallelName ()))
53+ bool need_parallel = op->hasAttr (plier::attributes::getParallelName ()) ||
54+ !op->getParentOfType <mlir::scf::ParallelOp>();
55+ if (!need_parallel)
5456 {
5557 return mlir::failure ();
5658 }
@@ -85,31 +87,46 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
8587 auto reduce = rewriter.create <mlir::AllocaOp>(loc, reduce_type);
8688 auto index = static_cast <unsigned >(it.index ());
8789 reduce_vars[index] = reduce;
88- auto zero = getZeroVal (rewriter, loc, type);
89- mapping.map (op.initVals ()[index], zero);
90- for (unsigned i = 0 ; i < max_concurrency; ++i)
90+ }
91+
92+ auto reduce_init_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args)
93+ {
94+ assert (args.empty ());
95+ (void )args;
96+ for (auto it : llvm::enumerate (reduce_vars))
9197 {
92- mlir::Value index = rewriter.create <mlir::ConstantIndexOp>(loc, i);
93- rewriter.create <mlir::StoreOp>(loc, zero, reduce, index);
98+ auto reduce = it.value ();
99+ auto type = op.getResultTypes ()[it.index ()];
100+ auto zero = getZeroVal (rewriter, loc, type);
101+ builder.create <mlir::StoreOp>(loc, zero, reduce, index);
94102 }
95- }
103+ builder.create <mlir::scf::YieldOp>(loc);
104+ };
105+
106+ auto reduce_lower_bound = rewriter.create <mlir::ConstantIndexOp>(loc, 0 );
107+ auto reduce_upper_bound = rewriter.create <mlir::ConstantIndexOp>(loc, max_concurrency);
108+ auto reduce_step = rewriter.create <mlir::ConstantIndexOp>(loc, 1 );
109+ rewriter.create <mlir::scf::ForOp>(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, llvm::None, reduce_init_body_builder);
96110
97111 auto & old_body = op.getLoopBody ().front ();
98112 auto orig_lower_bound = op.lowerBound ().front ();
99113 auto orig_upper_bound = op.upperBound ().front ();
100114 auto orig_step = op.step ().front ();
101115 auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index)
102116 {
117+ llvm::SmallVector<mlir::Value, 8 > initVals (op.initVals ().size ());
103118 for (auto it : llvm::enumerate (op.initVals ()))
104119 {
105120 auto reduce_var = reduce_vars[it.index ()];
106121 auto val = builder.create <mlir::LoadOp>(loc, reduce_var, thread_index);
107- mapping. map ( it.value (), val) ;
122+ initVals[ it.index ()] = val;
108123 }
109124 auto new_op = mlir::cast<mlir::scf::ParallelOp>(builder.clone (*op, mapping));
125+ new_op->removeAttr (plier::attributes::getParallelName ());
110126 assert (new_op->getNumResults () == reduce_vars.size ());
111127 new_op.lowerBoundMutable ().assign (lower_bound);
112128 new_op.upperBoundMutable ().assign (upper_bound);
129+ new_op.initValsMutable ().assign (initVals);
113130 for (auto it : llvm::enumerate (new_op->getResults ()))
114131 {
115132 auto reduce_var = reduce_vars[it.index ()];
@@ -119,10 +136,6 @@ struct ParallelToTbb : public mlir::OpRewritePattern<mlir::scf::ParallelOp>
119136
120137 rewriter.create <plier::ParallelOp>(loc, orig_lower_bound, orig_upper_bound, orig_step, body_builder);
121138
122- auto reduce_lower_bound = rewriter.create <mlir::ConstantIndexOp>(loc, 0 );
123- auto reduce_upper_bound = rewriter.create <mlir::ConstantIndexOp>(loc, max_concurrency);
124- auto reduce_step = rewriter.create <mlir::ConstantIndexOp>(loc, 1 );
125-
126139 auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args)
127140 {
128141 assert (args.size () == reduce_vars.size ());
0 commit comments