Commit f8941a12 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

More efficient sum for some cases (#1251)

* hacking to support dot of 3 by 2 inputs with gemm_batch.

* clean up.

* testing inplace reshape.

* fixed a compile error.

* added comments on todo.

* check for output.

* check for annotation.

* more optimizations WIP.

* sum simd.

* moved parallel for

* testing sum vectorization.

* fixed merge errors.

* sum wip.

* more logic.

* sum refactor and clean up.

* clean up.

* removed unrelated changes.

* removed related changes from merge.

* fixed clang compile errors.
parent 92adea38
......@@ -181,8 +181,9 @@ void codegen::CompilerCore::initialize()
args.push_back("-inline-threshold=1000000");
if (m_enable_pass_report)
{
args.push_back("-Rpass-analysis=loop-vectorize");
args.push_back("-Rpass=loop-vectorize");
args.push_back("-Rpass-analysis=.*");
args.push_back("-Rpass=.*");
args.push_back("-Rpass-missed=.*");
}
// Prevent Eigen from using any LGPL3 code
args.push_back("-DEIGEN_MPL2_ONLY");
......
......@@ -19,6 +19,7 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_utils.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
......@@ -290,6 +291,178 @@ void ngraph::runtime::cpu::kernel::emit_reshape(codegen::CodeWriter& writer,
close_for_loops(writer, index_vars);
}
struct SumHeuristic
{
SumHeuristic(const Shape& in_shape, const AxisSet& reduction_axes, const string& output_var)
: m_in_shape(in_shape)
, m_reduction_axes(reduction_axes)
, m_output_var(output_var)
{
analyze();
}
string get_thread_safe_dest() const { return m_thread_safe_dest; }
void emit_omp(codegen::CodeWriter& writer, const size_t loop_index) const
{
if (!m_skip_parallel_for && loop_index == m_parallel_for_index)
{
writer << m_omp_parallel_for;
if (m_loop_reduction)
{
writer << " reduction(+:" + m_thread_safe_dest + ")";
}
writer << "\n";
}
else if (loop_index == m_simd_index && !m_fused_parallel_simd)
{
writer << m_omp_simd;
if (m_simd_reduction)
{
writer << " reduction(+:" + m_thread_safe_dest + ")";
}
writer << "\n";
}
}
void emit_thread_local(codegen::CodeWriter& writer,
const size_t loop_index,
const vector<string>& out_indexes,
const Shape& out_shape,
const std::string& element_type)
{
if (!m_skip_parallel_for && !m_thread_safe && loop_index == m_parallel_for_index)
{
string thread_local_buff = writer.generate_temporary_name("thread_local_buff");
m_thread_safe_dest = writer.generate_temporary_name("thread_safe_dest");
// generate thread local buffer
writer << "std::vector<" << element_type << "> " << thread_local_buff << "("
<< shape_size(out_shape) << ", 0);\n";
string bracketed_shape = emit_bracketed_string(out_shape);
writer << element_type << "(&" << m_thread_safe_dest << ")" << bracketed_shape
<< " = *reinterpret_cast<" << element_type << "(*)" << bracketed_shape << ">(&"
<< thread_local_buff << "[0]);\n";
m_thread_safe_dest += emit_bracketed_string(out_indexes);
}
}
void emit_thread_local_finalize(codegen::CodeWriter& writer,
const size_t loop_index,
const vector<string>& index_vars,
const vector<string>& out_indexes,
const Shape& out_shape)
{
// generate global reduction loop inside the parallel for
if (!m_skip_parallel_for && !m_thread_safe && loop_index == m_parallel_for_index)
{
auto out_brackets = emit_bracketed_string(out_indexes);
// iterate from parallel_for_index to remaining indexes
// emit loops in out_indexs
size_t j = 0;
size_t emit_count = 0;
for (size_t k = 0; k < m_in_shape.size() && j < out_shape.size(); ++k)
{
// found a match
if (index_vars[k] == out_indexes[j])
{
// emit loop
if (k > m_parallel_for_index)
{
string index_var = out_indexes[j];
writer << ngraph::runtime::cpu::kernel::start_index_loop(
index_var, 0, out_shape[j], false);
writer.indent++;
emit_count++;
}
// move to next out index
++j;
}
}
writer << "#pragma omp atomic\n";
writer << m_output_var << " += " << m_thread_safe_dest << ";\n";
for (size_t k = 0; k < emit_count; k++)
{
writer.indent--;
writer << "}\n";
}
}
}
private:
void analyze()
{
// Heuristics
// set simd_index to inner most loop
m_simd_index = m_in_shape.size() - 1;
// for inference we may have batch size 1 in the outer
// loop, skip those to get better thread level parallelism
for (size_t i = 0; i < m_in_shape.size(); i++)
{
if (m_in_shape[i] > 1)
{
m_parallel_for_index = i;
break;
}
}
// check if there's any varying (non-reduction) indexes starting from
// parallel_for_index, if not, we can do reduction for parallel for
m_loop_reduction = true;
for (size_t i = m_parallel_for_index; i < m_in_shape.size(); ++i)
{
if (m_reduction_axes.count(i) == 0)
{
m_loop_reduction = false;
break;
}
}
// use simd reduction if simd_index is a reduction axis assuming
// it's the inner most loop
m_simd_reduction =
(m_reduction_axes.count(m_simd_index) != 0) && (m_simd_index == m_in_shape.size() - 1);
// if output var has parallel_for_index, it's thread safe
// due to parallel for partition or if loop_reduction is true
m_thread_safe = (m_reduction_axes.count(m_parallel_for_index) == 0) || m_loop_reduction;
// if we have two level of nested loop between parallel for and simd,
// and output is not thread safe, we skip parallel for and use just simd
if ((m_simd_index - m_parallel_for_index) == 1 && !m_thread_safe)
{
m_skip_parallel_for = true;
}
// use output variable directly if thread safe, otherwise a thread safe
// temp output will be generated in emit_thread_local()
if (m_thread_safe || m_skip_parallel_for)
{
m_thread_safe_dest = m_output_var;
}
// parallel_for_index matches simd_index so fuse them
if (m_simd_index == m_parallel_for_index)
{
m_fused_parallel_simd = true;
m_omp_parallel_for = "#pragma omp parallel for simd";
}
}
Shape m_in_shape;
AxisSet m_reduction_axes;
string m_output_var;
// use simd for the inner most loop
size_t m_simd_index{0};
// set parallel for to inner loop unless we find something better below
size_t m_parallel_for_index{m_simd_index};
// Optimization heuristics
// parallel for and simd for the same loop
bool m_fused_parallel_simd{false};
// global sum is thread safe in parallel for
bool m_thread_safe{false};
// global sum can use parallel for reduction clause
bool m_loop_reduction{false};
bool m_simd_reduction{false};
string m_omp_parallel_for{"#pragma omp parallel for"};
string m_omp_simd{"#pragma omp simd"};
string m_thread_safe_dest;
bool m_skip_parallel_for{false};
};
void ngraph::runtime::cpu::kernel::emit_sum(codegen::CodeWriter& writer,
const string& element_type,
const string& arg0, // replacement context
......@@ -306,17 +479,19 @@ void ngraph::runtime::cpu::kernel::emit_sum(codegen::CodeWriter& writer,
if (out_shape.size() == 0)
{
writer << dest_nd_name << " = 0;\n";
writer << element_type << " residual = 0;\n";
}
else
{
writer << element_type << " residual" << emit_bracketed_string(out_shape) << ";\n";
auto output_vars = open_for_loops(writer, out_shape);
writer << dest_nd_name << emit_bracketed_string(output_vars) << " = 0;\n";
writer << "residual" << emit_bracketed_string(output_vars) << " = 0;\n";
close_for_loops(writer, output_vars);
auto out_array =
recast_tmp_var(writer, element_type, out, Shape(1, shape_size(out_shape)), "out_array");
writer << "#pragma omp parallel for simd\n";
size_t s = shape_size(out_shape);
string index_var = writer.generate_temporary_name("i");
writer << "for(size_t " << index_var << " = 0; " << index_var << " < " << s << "; "
<< index_var << "++)\n";
writer.block_begin();
writer << out_array << "[" << index_var << "] = 0;\n";
writer.block_end();
}
// If we don't have a zero index in the input, perform the sum
......@@ -324,6 +499,7 @@ void ngraph::runtime::cpu::kernel::emit_sum(codegen::CodeWriter& writer,
{
// create the the interation variables without writing the for loops
vector<string> index_vars;
for (size_t i = 0; i < arg0_shape.size(); i++)
{
string index_var = writer.generate_temporary_name("i");
......@@ -332,45 +508,38 @@ void ngraph::runtime::cpu::kernel::emit_sum(codegen::CodeWriter& writer,
// calculate the output indexes based on what's being reduced
vector<string> out_indexes;
size_t outer_arg_index = -1;
for (size_t i = 0; i < index_vars.size(); ++i)
{
if (reduction_axes.count(i) == 0)
{
if (out_indexes.size() == 0)
{
outer_arg_index = i;
}
out_indexes.push_back(index_vars[i]);
}
}
// make the first output shape our outer loop, optimize with openmp
if (outer_arg_index != -1)
auto out_brackets = emit_bracketed_string(out_indexes);
auto dst = dest_nd_name + out_brackets;
auto src = source_nd_name + emit_bracketed_string(index_vars);
SumHeuristic heuristic(arg0_shape, reduction_axes, dst);
// emit the for loops
for (size_t i = 0; i < arg0_shape.size(); i++)
{
writer << start_index_loop(
index_vars[outer_arg_index], 0, arg0_shape[outer_arg_index], true);
string index_var = index_vars[i];
heuristic.emit_omp(writer, i);
writer << start_index_loop(index_var, 0, arg0_shape[i], false);
writer.indent++;
heuristic.emit_thread_local(writer, i, out_indexes, out_shape, element_type);
}
// create the rest of the loops, don't parallelize.
for (size_t i = 0; i < arg0_shape.size(); i++)
// thread local reduction
writer << heuristic.get_thread_safe_dest() << " += " << src << ";\n";
// close the loops
for (size_t i = arg0_shape.size(); i > 0; i--)
{
if (i != outer_arg_index)
{
string index_var = index_vars[i];
writer << start_index_loop(index_var, 0, arg0_shape[i], false);
writer.indent++;
}
heuristic.emit_thread_local_finalize(writer, i - 1, index_vars, out_indexes, out_shape);
writer.indent--;
writer << "}\n";
}
auto out_brackets = emit_bracketed_string(out_indexes);
auto dst = dest_nd_name + out_brackets;
auto src = source_nd_name + emit_bracketed_string(index_vars);
writer << element_type << " y = " << src << " - residual" << out_brackets << ";\n";
writer << element_type << " t = " << dst << " + y;\n";
writer << "residual" << out_brackets << " = (t - " << dst << ") - y;\n";
writer << dst << " = t;\n";
close_for_loops(writer, index_vars);
}
}
void ngraph::runtime::cpu::kernel::emit_reduce(codegen::CodeWriter& writer,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment