Unverified Commit 2bf2214f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

use codegen block_begin/end methods to clean up code a little (#738)

parent b6240057
...@@ -133,8 +133,7 @@ namespace ngraph ...@@ -133,8 +133,7 @@ namespace ngraph
{ {
// TODO: Audit all uses of Add and fix this to use // TODO: Audit all uses of Add and fix this to use
// the right alignment instead of Eigen::Unaligned // the right alignment instead of Eigen::Unaligned
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << "Eigen::Map<Eigen::Array<" << out[0].get_element_type().c_type_string() writer << "Eigen::Map<Eigen::Array<" << out[0].get_element_type().c_type_string()
<< ", " << out[0].get_size() << ", 1>, Eigen::Unaligned> out(" << ", " << out[0].get_size() << ", 1>, Eigen::Unaligned> out("
...@@ -189,14 +188,13 @@ namespace ngraph ...@@ -189,14 +188,13 @@ namespace ngraph
{ {
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] + "
<< "[i] + " << args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
} }
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
...@@ -215,13 +213,11 @@ namespace ngraph ...@@ -215,13 +213,11 @@ namespace ngraph
data_type = "MPI_DOUBLE"; data_type = "MPI_DOUBLE";
} }
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "MPI_Allreduce(" << args[0].get_name() << ", " << out[0].get_name() writer << "MPI_Allreduce(" << args[0].get_name() << ", " << out[0].get_name()
<< ", " << out[0].get_size() << ", " << data_type << ", " << out[0].get_size() << ", " << data_type
<< ", MPI_SUM, MPI_COMM_WORLD);\n"; << ", MPI_SUM, MPI_COMM_WORLD);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
#endif #endif
...@@ -259,8 +255,7 @@ namespace ngraph ...@@ -259,8 +255,7 @@ namespace ngraph
n = arg1_shape[0]; n = arg1_shape[0];
} }
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
const char* cbeta = "0.0f"; const char* cbeta = "0.0f";
...@@ -364,8 +359,7 @@ namespace ngraph ...@@ -364,8 +359,7 @@ namespace ngraph
} }
} }
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -374,8 +368,7 @@ namespace ngraph ...@@ -374,8 +368,7 @@ namespace ngraph
const ngraph::op::BatchNorm* batchnorm = const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node); static_cast<const ngraph::op::BatchNorm*>(node);
writer.indent++; writer.block_begin();
writer << "{\n";
// define weights // define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string() writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n"; << ">bn_weights(2*" << args[0].get_size() << ");\n";
...@@ -477,8 +470,7 @@ namespace ngraph ...@@ -477,8 +470,7 @@ namespace ngraph
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n"; << to_string(batchnorm_index) << ");\n";
} }
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -487,8 +479,7 @@ namespace ngraph ...@@ -487,8 +479,7 @@ namespace ngraph
const ngraph::op::BatchNormBackprop* batchnorm = const ngraph::op::BatchNormBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormBackprop*>(node); static_cast<const ngraph::op::BatchNormBackprop*>(node);
writer.indent++; writer.block_begin();
writer << "{\n";
// define weights // define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string() writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n"; << ">bn_weights(2*" << args[0].get_size() << ");\n";
...@@ -554,8 +545,7 @@ namespace ngraph ...@@ -554,8 +545,7 @@ namespace ngraph
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+" writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
<< args[0].get_size() << ", " << args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n"; << args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -570,34 +560,28 @@ namespace ngraph ...@@ -570,34 +560,28 @@ namespace ngraph
auto& first = (arg0_shape.empty() ? args[0] : args[1]); auto& first = (arg0_shape.empty() ? args[0] : args[1]);
auto& second = (arg0_shape.empty() ? args[1] : args[0]); auto& second = (arg0_shape.empty() ? args[1] : args[0]);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << "\n = "; writer << emit_vector(out[0]) << "\n = ";
writer << first.get_name() << "[0]\n * " << emit_vector(second) << ";\n"; writer << first.get_name() << "[0]\n * " << emit_vector(second) << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) && else if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) &&
dot->get_reduction_axes_count() == 1) dot->get_reduction_axes_count() == 1)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " << \n" writer << emit_vector(out[0]) << " << \n"
<< " " << emit_vector(args[0]) << ".dot(" << emit_vector(args[1]) << " " << emit_vector(args[0]) << ".dot(" << emit_vector(args[1])
<< ");\n"; << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
dot->get_reduction_axes_count() == 1) dot->get_reduction_axes_count() == 1)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " = \n" writer << emit_vector(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << " * " << emit_vector(args[1]) << " " << emit_matrix(args[0]) << " * " << emit_vector(args[1])
<< ";\n"; << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) &&
dot->get_reduction_axes_count() == 1) dot->get_reduction_axes_count() == 1)
...@@ -605,8 +589,7 @@ namespace ngraph ...@@ -605,8 +589,7 @@ namespace ngraph
// Emit an MKL SGEMM call if possible // Emit an MKL SGEMM call if possible
if (args[0].get_element_type() == element::f32) if (args[0].get_element_type() == element::f32)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "cblas::cblas_sgemm(" writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << "cblas::Layout::RowMajor, "
<< "cblas::Transpose::None, " << "cblas::Transpose::None, "
...@@ -617,18 +600,15 @@ namespace ngraph ...@@ -617,18 +600,15 @@ namespace ngraph
<< max(1UL, arg1_shape[1]) << ", 0.0f,\n" << max(1UL, arg1_shape[1]) << ", 0.0f,\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg1_shape[1]) << " " << out[0].get_name() << ", " << max(1UL, arg1_shape[1])
<< ");\n"; << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_matrix(out[0]) << " = \n" writer << emit_matrix(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << " * " << emit_matrix(args[1]) << " " << emit_matrix(args[0]) << " * " << emit_matrix(args[1])
<< ";\n"; << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
else else
...@@ -646,8 +626,7 @@ namespace ngraph ...@@ -646,8 +626,7 @@ namespace ngraph
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Multiply) void CPU_Emitter::EMITTER_DECL(ngraph::op::Multiply)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " *\n" << " " << emit_array1d(args[0]) << " *\n"
...@@ -655,13 +634,12 @@ namespace ngraph ...@@ -655,13 +634,12 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] * " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] * "
<< args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -669,20 +647,17 @@ namespace ngraph ...@@ -669,20 +647,17 @@ namespace ngraph
{ {
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node); auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", "
<< args[get_tuple_element->get_n()].get_name() << ", " << args[get_tuple_element->get_n()].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Abs) void CPU_Emitter::EMITTER_DECL(ngraph::op::Abs)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"; writer << emit_array1d(out[0]) << " =\n";
writer << "Eigen::abs(" << emit_array1d(args[0]) << ");\n"; writer << "Eigen::abs(" << emit_array1d(args[0]) << ");\n";
...@@ -693,14 +668,13 @@ namespace ngraph ...@@ -693,14 +668,13 @@ namespace ngraph
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() writer << out[0].get_name()
<< "[i] = " << (result_element_type.is_signed() ? "std::abs" : "") << "(" << "[i] = " << (result_element_type.is_signed() ? "std::abs" : "") << "("
<< args[0].get_name() << "[i]);\n"; << args[0].get_name() << "[i]);\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -711,8 +685,7 @@ namespace ngraph ...@@ -711,8 +685,7 @@ namespace ngraph
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
if (result_shape.size() == 1) if (result_shape.size() == 1)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0], "out_vector") << ";\n"; writer << emit_vector(out[0], "out_vector") << ";\n";
size_t concat_pos = 0; size_t concat_pos = 0;
...@@ -723,16 +696,14 @@ namespace ngraph ...@@ -723,16 +696,14 @@ namespace ngraph
<< ";\n"; << ";\n";
concat_pos += args[i].get_shape().at(0); concat_pos += args[i].get_shape().at(0);
} }
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (result_shape.size() == 2) else if (result_shape.size() == 2)
{ {
auto axis = auto axis =
(dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis(); (dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_matrix(out[0], "out_matrix") << ";\n"; writer << emit_matrix(out[0], "out_matrix") << ";\n";
size_t concat_pos[2]{0, 0}; size_t concat_pos[2]{0, 0};
...@@ -747,8 +718,7 @@ namespace ngraph ...@@ -747,8 +718,7 @@ namespace ngraph
concat_pos[axis] += arg_shape.at(axis); concat_pos[axis] += arg_shape.at(axis);
} }
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
...@@ -822,17 +792,16 @@ namespace ngraph ...@@ -822,17 +792,16 @@ namespace ngraph
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Divide) void CPU_Emitter::EMITTER_DECL(ngraph::op::Divide)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
if (node->get_element_type().is_real() == false) if (node->get_element_type().is_real() == false)
{ {
// Check for divide by zero for integer types only // Check for divide by zero for integer types only
size_t element_count = args[1].get_size(); size_t element_count = args[1].get_size();
writer << "for (size_t i=0; i<" << element_count << "; i++)\n"; writer << "for (size_t i=0; i<" << element_count << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " if (" << args.at(1).get_name() writer << "if (" << args.at(1).get_name()
<< "[i] == 0) throw std::runtime_error(\"integer divide by zero\");\n"; << "[i] == 0) throw std::runtime_error(\"integer divide by zero\");\n";
writer << "}\n"; writer.block_end();
} }
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
...@@ -841,20 +810,18 @@ namespace ngraph ...@@ -841,20 +810,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] / " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] / "
<< args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Equal) void CPU_Emitter::EMITTER_DECL(ngraph::op::Equal)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " ==\n" << " (" << emit_array1d(args[0]) << " ==\n"
...@@ -862,20 +829,18 @@ namespace ngraph ...@@ -862,20 +829,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] == " << args[1].get_name() << "[i];\n"; << "[i] == " << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Greater) void CPU_Emitter::EMITTER_DECL(ngraph::op::Greater)
{ {
writer << "{ // " << node->get_name() << " xxx\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " >\n" << " (" << emit_array1d(args[0]) << " >\n"
...@@ -883,20 +848,18 @@ namespace ngraph ...@@ -883,20 +848,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
<< args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::GreaterEq) void CPU_Emitter::EMITTER_DECL(ngraph::op::GreaterEq)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " >=\n" << " (" << emit_array1d(args[0]) << " >=\n"
...@@ -904,20 +867,18 @@ namespace ngraph ...@@ -904,20 +867,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] >= " << args[1].get_name() << "[i];\n"; << "[i] >= " << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Less) void CPU_Emitter::EMITTER_DECL(ngraph::op::Less)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " <\n" << " (" << emit_array1d(args[0]) << " <\n"
...@@ -925,20 +886,18 @@ namespace ngraph ...@@ -925,20 +886,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
<< args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::LessEq) void CPU_Emitter::EMITTER_DECL(ngraph::op::LessEq)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " <=\n" << " (" << emit_array1d(args[0]) << " <=\n"
...@@ -946,40 +905,35 @@ namespace ngraph ...@@ -946,40 +905,35 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] <= " << args[1].get_name() << "[i];\n"; << "[i] <= " << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Log) void CPU_Emitter::EMITTER_DECL(ngraph::op::Log)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " Eigen::log(" << emit_array1d(args[0]) << ");\n"; << " Eigen::log(" << emit_array1d(args[0]) << ");\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = log(" << args[0].get_name() writer << out[0].get_name() << "[i] = log(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Maximum) void CPU_Emitter::EMITTER_DECL(ngraph::op::Maximum)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".max(\n" << " " << emit_array1d(args[0]) << ".max(\n"
...@@ -987,21 +941,19 @@ namespace ngraph ...@@ -987,21 +941,19 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
<< args[1].get_name() << "[i] ? " << args[0].get_name() << args[1].get_name() << "[i] ? " << args[0].get_name()
<< "[i] : " << args[1].get_name() << "[i] ;\n"; << "[i] : " << args[1].get_name() << "[i] ;\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Minimum) void CPU_Emitter::EMITTER_DECL(ngraph::op::Minimum)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".min(\n" << " " << emit_array1d(args[0]) << ".min(\n"
...@@ -1009,41 +961,36 @@ namespace ngraph ...@@ -1009,41 +961,36 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
<< args[1].get_name() << "[i] ? " << args[0].get_name() << args[1].get_name() << "[i] ? " << args[0].get_name()
<< "[i] : " << args[1].get_name() << "[i] ;\n"; << "[i] : " << args[1].get_name() << "[i] ;\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Negative) void CPU_Emitter::EMITTER_DECL(ngraph::op::Negative)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " -" << emit_array1d(args[0]) << ";\n"; << " -" << emit_array1d(args[0]) << ";\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = -" << args[0].get_name() writer << out[0].get_name() << "[i] = -" << args[0].get_name() << "[i];\n";
<< "[i];\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::NotEqual) void CPU_Emitter::EMITTER_DECL(ngraph::op::NotEqual)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " !=\n" << " (" << emit_array1d(args[0]) << " !=\n"
...@@ -1051,20 +998,18 @@ namespace ngraph ...@@ -1051,20 +998,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] != " << args[1].get_name() << "[i];\n"; << "[i] != " << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Select) void CPU_Emitter::EMITTER_DECL(ngraph::op::Select)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "\n" << " " << emit_array1d(args[0]) << "\n"
...@@ -1073,20 +1018,18 @@ namespace ngraph ...@@ -1073,20 +1018,18 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] ? " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] ? "
<< args[1].get_name() << "[i] : " << args[2].get_name() << "[i];\n"; << args[1].get_name() << "[i] : " << args[2].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Subtract) void CPU_Emitter::EMITTER_DECL(ngraph::op::Subtract)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " -\n" << " " << emit_array1d(args[0]) << " -\n"
...@@ -1094,13 +1037,12 @@ namespace ngraph ...@@ -1094,13 +1037,12 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] - " writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] - "
<< args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -1108,45 +1050,37 @@ namespace ngraph ...@@ -1108,45 +1050,37 @@ namespace ngraph
{ {
auto broadcast = static_cast<const ngraph::op::Broadcast*>(node); auto broadcast = static_cast<const ngraph::op::Broadcast*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
if (broadcast->get_broadcast_axes().empty()) if (broadcast->get_broadcast_axes().empty())
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_shape.size() == 0) else if (arg_shape.size() == 0)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "(0, 0);\n"; << " " << emit_array1d(args[0]) << "(0, 0);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_shape.size() == 1 && result_shape.size() == 2) else if (arg_shape.size() == 1 && result_shape.size() == 2)
{ {
if (broadcast->get_broadcast_axes() == AxisSet{1}) if (broadcast->get_broadcast_axes() == AxisSet{1})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_matrix(out[0]) << ".colwise() =\n" writer << emit_matrix(out[0]) << ".colwise() =\n"
<< " " << emit_vector(args[0]) << ";\n"; << " " << emit_vector(args[0]) << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (broadcast->get_broadcast_axes() == AxisSet{0}) else if (broadcast->get_broadcast_axes() == AxisSet{0})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "Eigen::Map<Eigen::Matrix<" writer << "Eigen::Map<Eigen::Matrix<"
<< out[0].get_element_type().c_type_string() << ", " << out[0].get_element_type().c_type_string() << ", "
...@@ -1163,8 +1097,7 @@ namespace ngraph ...@@ -1163,8 +1097,7 @@ namespace ngraph
writer << "out = arg0.replicate<" << out[0].get_shape().at(0) writer << "out = arg0.replicate<" << out[0].get_shape().at(0)
<< ", 1>();\n"; << ", 1>();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
...@@ -1193,8 +1126,7 @@ namespace ngraph ...@@ -1193,8 +1126,7 @@ namespace ngraph
out[0].get_shape(), out[0].get_shape(),
broadcast->get_broadcast_axes()); broadcast->get_broadcast_axes());
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -1202,8 +1134,7 @@ namespace ngraph ...@@ -1202,8 +1134,7 @@ namespace ngraph
{ {
auto& result_element_type = out[0].get_element_type(); auto& result_element_type = out[0].get_element_type();
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "\n" << " " << emit_array1d(args[0]) << "\n"
...@@ -1211,14 +1142,12 @@ namespace ngraph ...@@ -1211,14 +1142,12 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = (" writer << out[0].get_name() << "[i] = (" << result_element_type.c_type_string()
<< result_element_type.c_type_string() << ")(" << args[0].get_name() << ")(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -1242,8 +1171,7 @@ namespace ngraph ...@@ -1242,8 +1171,7 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Reshape) void CPU_Emitter::EMITTER_DECL(ngraph::op::Reshape)
{ {
auto reshape = static_cast<const ngraph::op::Reshape*>(node); auto reshape = static_cast<const ngraph::op::Reshape*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size(); auto arg_rank = arg_shape.size();
...@@ -1265,12 +1193,10 @@ namespace ngraph ...@@ -1265,12 +1193,10 @@ namespace ngraph
// we can just copy. // we can just copy.
if (same_layout || result_shape_product < 2) if (same_layout || result_shape_product < 2)
{ {
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// If there *is* a layout change in the 2D case, we transpose the input. // If there *is* a layout change in the 2D case, we transpose the input.
else if (arg_rank == 2) else if (arg_rank == 2)
...@@ -1278,8 +1204,7 @@ namespace ngraph ...@@ -1278,8 +1204,7 @@ namespace ngraph
// Emit an MKL transpose call if possible // Emit an MKL transpose call if possible
if (result_element_type == ngraph::element::f32) if (result_element_type == ngraph::element::f32)
{ {
writer << "{ // " << node->get_name() << " 2\n"; writer.block_begin();
writer.indent++;
writer << "mkl::MKL_Somatcopy('R', 'T', " << to_string(arg_shape[0]) writer << "mkl::MKL_Somatcopy('R', 'T', " << to_string(arg_shape[0])
<< ",\n" << ",\n"
<< " " << to_string(arg_shape[1]) << ", 1.0f,\n" << " " << to_string(arg_shape[1]) << ", 1.0f,\n"
...@@ -1287,17 +1212,14 @@ namespace ngraph ...@@ -1287,17 +1212,14 @@ namespace ngraph
<< to_string(arg_shape[1]) << ",\n" << to_string(arg_shape[1]) << ",\n"
<< " " << out[0].get_name() << ", " << " " << out[0].get_name() << ", "
<< to_string(arg_shape[0]) << ");\n"; << to_string(arg_shape[0]) << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
writer << "{ // " << node->get_name() << " 3\n"; writer.block_begin();
writer.indent++;
writer << emit_matrix(out[0]) << " =\n" writer << emit_matrix(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".transpose();\n"; << " " << emit_matrix(args[0]) << ".transpose();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
// Other cases // Other cases
...@@ -1320,8 +1242,7 @@ namespace ngraph ...@@ -1320,8 +1242,7 @@ namespace ngraph
out[0].get_shape(), out[0].get_shape(),
reshape->get_input_order()); reshape->get_input_order());
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -1330,8 +1251,7 @@ namespace ngraph ...@@ -1330,8 +1251,7 @@ namespace ngraph
auto function_call = static_cast<const ngraph::op::FunctionCall*>(node); auto function_call = static_cast<const ngraph::op::FunctionCall*>(node);
shared_ptr<Function> function = function_call->get_functions()[0]; shared_ptr<Function> function = function_call->get_functions()[0];
writer << "{ // Call " << function->get_name() << "\n"; writer.block_begin();
writer.indent++;
{ {
vector<string> input_names; vector<string> input_names;
vector<string> output_names; vector<string> output_names;
...@@ -1346,23 +1266,22 @@ namespace ngraph ...@@ -1346,23 +1266,22 @@ namespace ngraph
output_names.push_back(output.get_name()); output_names.push_back(output.get_name());
} }
writer << "void* args[] =\n{"; writer << "void* args[] =\n";
writer.indent++; writer.block_begin();
writer << "\n" << join(input_names, ",\n"); writer << "\n" << join(input_names, ",\n");
writer.indent--; writer.block_end();
writer << "\n};\n"; writer << ";\n";
writer << "void* out[] =\n{"; writer << "void* out[] =\n";
writer.indent++; writer.block_begin();
writer << "\n" << join(output_names, ",\n"); writer << "\n" << join(output_names, ",\n");
writer.indent--; writer.block_end();
writer << "\n};\n"; writer << ";\n";
writer << "\n"; writer << "\n";
writer << function->get_name() << "(args, out, ctx);\n"; writer << function->get_name() << "(args, out, ctx);\n";
} }
writer.indent--; writer.block_end();
writer << "}\n";
} }
// TODO: This and other ops include comments/notes that // TODO: This and other ops include comments/notes that
...@@ -1387,12 +1306,10 @@ namespace ngraph ...@@ -1387,12 +1306,10 @@ namespace ngraph
// Trivial case: no reduction axes (this includes the scalar-reductee case). // Trivial case: no reduction axes (this includes the scalar-reductee case).
if (reduction_axes.empty()) if (reduction_axes.empty())
{ {
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Behavior for zero-size axes bears some explanation here. XLA's reduce // Behavior for zero-size axes bears some explanation here. XLA's reduce
// operator provides an "base" element (usually, but not necessarily, // operator provides an "base" element (usually, but not necessarily,
...@@ -1425,23 +1342,19 @@ namespace ngraph ...@@ -1425,23 +1342,19 @@ namespace ngraph
if (reductee_shape.at(0) == 0 || if (reductee_shape.at(0) == 0 ||
(reductee_shape.size() == 2 && reductee_shape.at(1) == 0)) (reductee_shape.size() == 2 && reductee_shape.at(1) == 0))
{ {
writer << "{ // " << node->get_name() << " 2\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name() writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name()
<< ", " << out[0].get_size() * out[0].get_element_type().size() << ", " << out[0].get_size() * out[0].get_element_type().size()
<< ");\n"; << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
writer << "{ // " << node->get_name() << " 3\n"; writer.block_begin();
writer.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{"; << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -1451,30 +1364,25 @@ namespace ngraph ...@@ -1451,30 +1364,25 @@ namespace ngraph
writer << "};\n"; writer << "};\n";
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".redux(f);\n"; << " " << emit_array1d(args[0]) << ".redux(f);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1}) else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1})
{ {
if (reductee_shape.at(1) == 0) if (reductee_shape.at(1) == 0)
{ {
writer << "{ // " << node->get_name() << " 4\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[1]) << "(0, 0);\n"; << " " << emit_array1d(args[1]) << "(0, 0);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
writer << "{ // " << node->get_name() << " 5\n"; writer.block_begin();
writer.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{"; << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -1484,30 +1392,25 @@ namespace ngraph ...@@ -1484,30 +1392,25 @@ namespace ngraph
writer << "};\n"; writer << "};\n";
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().redux(f);\n"; << " " << emit_matrix(args[0]) << ".rowwise().redux(f);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0}) else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0})
{ {
if (reductee_shape.at(0) == 0) if (reductee_shape.at(0) == 0)
{ {
writer << "{ // " << node->get_name() << " 6\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[1]) << "(0, 0);\n"; << " " << emit_array1d(args[1]) << "(0, 0);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
writer << "{ // " << node->get_name() << " 7\n"; writer.block_begin();
writer.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{"; << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -1517,20 +1420,17 @@ namespace ngraph ...@@ -1517,20 +1420,17 @@ namespace ngraph
writer << "};\n"; writer << "};\n";
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().redux(f);\n"; << " " << emit_matrix(args[0]) << ".colwise().redux(f);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
else else
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{"; << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -1548,18 +1448,15 @@ namespace ngraph ...@@ -1548,18 +1448,15 @@ namespace ngraph
writer << " {" << join(reduce->get_reduction_axes()) << "},\n"; writer << " {" << join(reduce->get_reduction_axes()) << "},\n";
writer << " f);\n"; writer << " f);\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
#else #else
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << "\n{"; writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -1577,29 +1474,26 @@ namespace ngraph ...@@ -1577,29 +1474,26 @@ namespace ngraph
out[0].get_shape(), out[0].get_shape(),
reduce->get_reduction_axes()); reduce->get_reduction_axes());
writer.indent--; writer.block_end();
writer << "}\n";
#endif #endif
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sign) void CPU_Emitter::EMITTER_DECL(ngraph::op::Sign)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sign();\n"; << " " << emit_array1d(args[0]) << ".sign();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = (0 < " << args[0].get_name() writer << out[0].get_name() << "[i] = (0 < " << args[0].get_name() << "[i]) - ("
<< "[i]) - (" << args[0].get_name() << "[i] < 0);\n"; << args[0].get_name() << "[i] < 0);\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -1607,8 +1501,7 @@ namespace ngraph ...@@ -1607,8 +1501,7 @@ namespace ngraph
{ {
const ngraph::op::Slice* slice = static_cast<const ngraph::op::Slice*>(node); const ngraph::op::Slice* slice = static_cast<const ngraph::op::Slice*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
size_t arg_rank = args[0].get_shape().size(); size_t arg_rank = args[0].get_shape().size();
...@@ -1628,36 +1521,30 @@ namespace ngraph ...@@ -1628,36 +1521,30 @@ namespace ngraph
// Scalar slice is necessarily just a copy. // Scalar slice is necessarily just a copy.
if (!strided && arg_rank == 0) if (!strided && arg_rank == 0)
{ {
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!strided && arg_rank == 1) else if (!strided && arg_rank == 1)
{ {
writer << "{ // " << node->get_name() << " 2\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_vector(args[0]) << ".segment(\n" << " " << emit_vector(args[0]) << ".segment(\n"
<< " " << to_string(lower_bounds[0]) << ", " << " " << to_string(lower_bounds[0]) << ", "
<< to_string(upper_bounds[0] - lower_bounds[0]) << ");\n"; << to_string(upper_bounds[0] - lower_bounds[0]) << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!strided && arg_rank == 2) else if (!strided && arg_rank == 2)
{ {
writer << "{ // " << node->get_name() << " 3\n"; writer.block_begin();
writer.indent++;
writer << emit_matrix(out[0]) << " = \n" writer << emit_matrix(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << ".block(" << " " << emit_matrix(args[0]) << ".block("
<< to_string(lower_bounds[0]) << ", " << to_string(lower_bounds[1]) << to_string(lower_bounds[0]) << ", " << to_string(lower_bounds[1])
<< ",\n" << ",\n"
<< " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n" << " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n"
<< " " << to_string(upper_bounds[1] - lower_bounds[1]) << ");\n"; << " " << to_string(upper_bounds[1] - lower_bounds[1]) << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Other cases (reordering of axes for tensors with rank>2) are not handled yet. // Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else else
...@@ -1684,16 +1571,14 @@ namespace ngraph ...@@ -1684,16 +1571,14 @@ namespace ngraph
slice->get_upper_bounds(), slice->get_upper_bounds(),
slice->get_strides()); slice->get_strides());
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sum) void CPU_Emitter::EMITTER_DECL(ngraph::op::Sum)
{ {
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node); const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape(); const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size(); size_t arg_rank = arg_shape.size();
...@@ -1702,41 +1587,33 @@ namespace ngraph ...@@ -1702,41 +1587,33 @@ namespace ngraph
// Trivial case: no reduction axes. // Trivial case: no reduction axes.
if (reduction_axes.size() == 0) if (reduction_axes.size() == 0)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Full reduction? Then sum to scalar. // Full reduction? Then sum to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) || else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1})) (arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sum();\n"; << " " << emit_array1d(args[0]) << ".sum();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_rank == 2 && reduction_axes == AxisSet{1}) else if (arg_rank == 2 && reduction_axes == AxisSet{1})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().sum();\n"; << " " << emit_matrix(args[0]) << ".rowwise().sum();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_rank == 2 && reduction_axes == AxisSet{0}) else if (arg_rank == 2 && reduction_axes == AxisSet{0})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().sum();\n"; << " " << emit_matrix(args[0]) << ".colwise().sum();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
...@@ -1757,128 +1634,109 @@ namespace ngraph ...@@ -1757,128 +1634,109 @@ namespace ngraph
out[0].get_shape(), out[0].get_shape(),
sum->get_reduction_axes()); sum->get_reduction_axes());
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Exp) void CPU_Emitter::EMITTER_DECL(ngraph::op::Exp)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".exp();\n"; << " " << emit_array1d(args[0]) << ".exp();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = exp(" << args[0].get_name() writer << out[0].get_name() << "[i] = exp(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sin) void CPU_Emitter::EMITTER_DECL(ngraph::op::Sin)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sin();\n"; << " " << emit_array1d(args[0]) << ".sin();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = sin(" << args[0].get_name() writer << out[0].get_name() << "[i] = sin(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sinh) void CPU_Emitter::EMITTER_DECL(ngraph::op::Sinh)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sinh();\n"; << " " << emit_array1d(args[0]) << ".sinh();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = sinh(" << args[0].get_name() writer << out[0].get_name() << "[i] = sinh(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Cos) void CPU_Emitter::EMITTER_DECL(ngraph::op::Cos)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".cos();\n"; << " " << emit_array1d(args[0]) << ".cos();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = cos(" << args[0].get_name() writer << out[0].get_name() << "[i] = cos(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Cosh) void CPU_Emitter::EMITTER_DECL(ngraph::op::Cosh)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".cosh();\n"; << " " << emit_array1d(args[0]) << ".cosh();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = cosh(" << args[0].get_name() writer << out[0].get_name() << "[i] = cosh(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Tan) void CPU_Emitter::EMITTER_DECL(ngraph::op::Tan)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".tan();\n"; << " " << emit_array1d(args[0]) << ".tan();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = tan(" << args[0].get_name() writer << out[0].get_name() << "[i] = tan(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -1888,85 +1746,72 @@ namespace ngraph ...@@ -1888,85 +1746,72 @@ namespace ngraph
// so we fall-back to tanh // so we fall-back to tanh
// TODO: Implement our own internal fast/approximate tanh if this actually gets used // TODO: Implement our own internal fast/approximate tanh if this actually gets used
// by models // by models
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 0 #if PREFER_EIGEN == 0
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
#endif #endif
writer << "for (size_t i=0; i<" << out[0].get_size() << "; i++)\n"; writer << "for (size_t i=0; i<" << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = tanh(" << args[0].get_name() writer << out[0].get_name() << "[i] = tanh(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Asin) void CPU_Emitter::EMITTER_DECL(ngraph::op::Asin)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".asin();\n"; << " " << emit_array1d(args[0]) << ".asin();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = asin(" << args[0].get_name() writer << out[0].get_name() << "[i] = asin(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Acos) void CPU_Emitter::EMITTER_DECL(ngraph::op::Acos)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".acos();\n"; << " " << emit_array1d(args[0]) << ".acos();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = acos(" << args[0].get_name() writer << out[0].get_name() << "[i] = acos(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Atan) void CPU_Emitter::EMITTER_DECL(ngraph::op::Atan)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".atan();\n"; << " " << emit_array1d(args[0]) << ".atan();\n";
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = atan(" << args[0].get_name() writer << out[0].get_name() << "[i] = atan(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power) void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " = \n"; writer << emit_array1d(out[0]) << " = \n";
writer.indent++; writer.indent++;
...@@ -1976,21 +1821,19 @@ namespace ngraph ...@@ -1976,21 +1821,19 @@ namespace ngraph
#else #else
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = pow(" << args[0].get_name() writer << out[0].get_name() << "[i] = pow(" << args[0].get_name() << "[i], "
<< "[i], " << args[1].get_name() << "[i]);\n"; << args[1].get_name() << "[i]);\n";
writer << "}\n"; writer.block_end();
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ReplaceSlice) void CPU_Emitter::EMITTER_DECL(ngraph::op::ReplaceSlice)
{ {
auto replace_slice = static_cast<const ngraph::op::Slice*>(node); auto replace_slice = static_cast<const ngraph::op::Slice*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
size_t arg0_rank = args[0].get_shape().size(); size_t arg0_rank = args[0].get_shape().size();
...@@ -2010,30 +1853,25 @@ namespace ngraph ...@@ -2010,30 +1853,25 @@ namespace ngraph
// Scalar slice is necessarily just a copy. // Scalar slice is necessarily just a copy.
if (!strided && arg0_rank == 0) if (!strided && arg0_rank == 0)
{ {
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!strided && arg0_rank == 1) else if (!strided && arg0_rank == 1)
{ {
writer << "{ // " << node->get_name() << " 2\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_vector(args[0]) << ";\n" << " " << emit_vector(args[0]) << ";\n"
<< emit_vector(out[0]) << ".segment(\n" << emit_vector(out[0]) << ".segment(\n"
<< " " << to_string(lower_bounds[0]) << ", " << " " << to_string(lower_bounds[0]) << ", "
<< to_string(upper_bounds[0] - lower_bounds[0]) << ") =\n" << to_string(upper_bounds[0] - lower_bounds[0]) << ") =\n"
<< " " << emit_vector(args[1]) << ";\n"; << " " << emit_vector(args[1]) << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!strided && arg0_rank == 2) else if (!strided && arg0_rank == 2)
{ {
writer << "{ // " << node->get_name() << " 3\n"; writer.block_begin();
writer.indent++;
writer << emit_matrix(out[0]) << " =\n" writer << emit_matrix(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ";\n" << " " << emit_matrix(args[0]) << ";\n"
<< emit_matrix(out[0]) << ".block(\n" << emit_matrix(out[0]) << ".block(\n"
...@@ -2042,8 +1880,7 @@ namespace ngraph ...@@ -2042,8 +1880,7 @@ namespace ngraph
<< " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n" << " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n"
<< " " << to_string(upper_bounds[1] - lower_bounds[1]) << ") =\n" << " " << to_string(upper_bounds[1] - lower_bounds[1]) << ") =\n"
<< " " << emit_matrix(args[1]) << ";\n"; << " " << emit_matrix(args[1]) << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Other cases (reordering of axes for tensors with rank>2) are not handled yet. // Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else else
...@@ -2073,8 +1910,7 @@ namespace ngraph ...@@ -2073,8 +1910,7 @@ namespace ngraph
replace_slice->get_upper_bounds(), replace_slice->get_upper_bounds(),
replace_slice->get_strides()); replace_slice->get_strides());
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -2088,80 +1924,66 @@ namespace ngraph ...@@ -2088,80 +1924,66 @@ namespace ngraph
if (arg_rank == 0) if (arg_rank == 0)
{ {
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0], "out_vector") << ";\n"; writer << emit_vector(out[0], "out_vector") << ";\n";
writer << "out_vector.setZero();\n" writer << "out_vector.setZero();\n"
<< "" << ""
<< "auto pos_raw = " << emit_vector(args[0]) << "(0, 0);\n" << "auto pos_raw = " << emit_vector(args[0]) << "(0, 0);\n"
<< "if (floor(pos_raw) != pos_raw)\n" << "if (floor(pos_raw) != pos_raw)\n";
<< "{\n"; writer.block_begin();
writer.indent++;
writer writer
<< "throw(std::range_error(\"One-hot: non-integral value in input\"));\n"; << "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
writer.indent--; writer.block_end();
writer << "}\n";
writer << "size_t pos = pos_raw;\n" writer << "size_t pos = pos_raw;\n"
<< "if (pos >= " << bounds << ")\n"; << "if (pos >= " << bounds << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
writer << "throw(std::range_error(\"One-hot: value is out of category " writer << "throw(std::range_error(\"One-hot: value is out of category "
"range\"));\n"; "range\"));\n";
writer.indent--; writer.block_end();
writer << "}\n";
writer << "out_vector(pos, 0) = 1;\n"; writer << "out_vector(pos, 0) = 1;\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_rank == 1) else if (arg_rank == 1)
{ {
writer << "{ // " << node->get_name() << " 1\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(args[0], "arg_vector") << ";\n"; writer << emit_vector(args[0], "arg_vector") << ";\n";
writer << emit_matrix(out[0], "out_vector") << ";\n"; writer << emit_matrix(out[0], "out_vector") << ";\n";
writer << "out_vector.setZero();\n"; writer << "out_vector.setZero();\n";
writer << "for (size_t i = 0; i < " << args[0].get_shape()[0] << "; i++)\n" writer << "for (size_t i = 0; i < " << args[0].get_shape()[0] << "; i++)\n";
<< "{\n"; writer.block_begin();
writer.indent++;
writer << "auto pos_raw = arg_vector(i, 0);\n"; writer << "auto pos_raw = arg_vector(i, 0);\n";
writer << "if (floor(pos_raw) != pos_raw)\n" writer << "if (floor(pos_raw) != pos_raw)\n";
<< "{\n"; writer.block_begin();
writer.indent++;
writer writer
<< "throw(std::range_error(\"One-hot: non-integral value in input\"));\n"; << "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
writer.indent--; writer.block_end();
writer << "}\n";
writer << "size_t pos = pos_raw;\n"; writer << "size_t pos = pos_raw;\n";
writer << "bool found = false;\n"; writer << "bool found = false;\n";
writer << "if (pos >= " << bounds << ")\n" writer << "if (pos >= " << bounds << ")\n";
<< "{\n"; writer.block_begin();
writer.indent++;
writer << "throw(std::range_error(\"One-hot: value is out of category " writer << "throw(std::range_error(\"One-hot: value is out of category "
"range\"));\n"; "range\"));\n";
writer.indent--; writer.block_end();
writer << "}\n";
writer << "out_vector" writer << "out_vector"
<< (oh->get_one_hot_axis() == 0 ? "(pos, i)" : "(i, pos)") << " = 1;\n"; << (oh->get_one_hot_axis() == 0 ? "(pos, i)" : "(i, pos)") << " = 1;\n";
writer.indent--; writer.block_end();
writer << "}\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Other cases are not handled yet. // Other cases are not handled yet.
else else
...@@ -2178,55 +2000,46 @@ namespace ngraph ...@@ -2178,55 +2000,46 @@ namespace ngraph
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Ceiling) void CPU_Emitter::EMITTER_DECL(ngraph::op::Ceiling)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
size_t element_count = out[0].get_size(); size_t element_count = out[0].get_size();
#if PREFER_EIGEN == 0 #if PREFER_EIGEN == 0
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
#endif #endif
writer << "for (size_t i = 0; i < " << element_count << "; i++)\n"; writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = ceil(" << args[0].get_name() writer << out[0].get_name() << "[i] = ceil(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Floor) void CPU_Emitter::EMITTER_DECL(ngraph::op::Floor)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
size_t element_count = out[0].get_size(); size_t element_count = out[0].get_size();
#if PREFER_EIGEN == 0 #if PREFER_EIGEN == 0
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
#endif #endif
writer << "for (size_t i = 0; i < " << element_count << "; i++)\n"; writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = floor(" << args[0].get_name() writer << out[0].get_name() << "[i] = floor(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sqrt) void CPU_Emitter::EMITTER_DECL(ngraph::op::Sqrt)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
size_t element_count = out[0].get_size(); size_t element_count = out[0].get_size();
#if PREFER_EIGEN == 0 #if PREFER_EIGEN == 0
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
#endif #endif
writer << "for (size_t i = 0; i < " << element_count << "; i++)\n"; writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = sqrt(" << args[0].get_name() writer << out[0].get_name() << "[i] = sqrt(" << args[0].get_name() << "[i]);\n";
<< "[i]);\n"; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
template <> template <>
...@@ -2748,13 +2561,11 @@ namespace ngraph ...@@ -2748,13 +2561,11 @@ namespace ngraph
auto reduction_function = reduce_window->get_functions()[0]; auto reduction_function = reduce_window->get_functions()[0];
auto& f_result_element_type = out[0].get_element_type(); auto& f_result_element_type = out[0].get_element_type();
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << "\n{"; writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -2775,8 +2586,7 @@ namespace ngraph ...@@ -2775,8 +2586,7 @@ namespace ngraph
writer << " {" writer << " {"
<< join(reduce_window->get_window_movement_strides()) << "});\n"; << join(reduce_window->get_window_movement_strides()) << "});\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -2790,14 +2600,12 @@ namespace ngraph ...@@ -2790,14 +2600,12 @@ namespace ngraph
auto arg1_shape = args[1].get_shape(); auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
string type = node->get_output_element_type(0).c_type_string(); string type = node->get_output_element_type(0).c_type_string();
writer << "auto f_select = [&](" << type << " x, " << type << " y) -> char\n{"; writer << "auto f_select = [&](" << type << " x, " << type << " y) -> char {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << "char result;\n"; writer << "char result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -2807,9 +2615,8 @@ namespace ngraph ...@@ -2807,9 +2615,8 @@ namespace ngraph
writer << "};\n"; writer << "};\n";
writer << "auto f_scatter = [&](" << type << " x, " << type << " y) -> " << type writer << "auto f_scatter = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{"; << " {\n";
writer.indent++; writer.indent++;
writer << "\n";
writer << type << " result;\n"; writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n"; writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n"; writer << "void* out[] = {&result};\n";
...@@ -2833,8 +2640,7 @@ namespace ngraph ...@@ -2833,8 +2640,7 @@ namespace ngraph
writer << " {" writer << " {"
<< join(select_and_scatter->get_window_movement_strides()) << "});\n"; << join(select_and_scatter->get_window_movement_strides()) << "});\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -3057,8 +2863,7 @@ namespace ngraph ...@@ -3057,8 +2863,7 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Product) void CPU_Emitter::EMITTER_DECL(ngraph::op::Product)
{ {
const ngraph::op::Product* product = static_cast<const ngraph::op::Product*>(node); const ngraph::op::Product* product = static_cast<const ngraph::op::Product*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape(); const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size(); size_t arg_rank = arg_shape.size();
...@@ -3067,41 +2872,33 @@ namespace ngraph ...@@ -3067,41 +2872,33 @@ namespace ngraph
// Trivial case: no reduction axes. // Trivial case: no reduction axes.
if (reduction_axes.size() == 0) if (reduction_axes.size() == 0)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Full reduction? Then reduce to scalar. // Full reduction? Then reduce to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) || else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1})) (arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".prod();\n"; << " " << emit_array1d(args[0]) << ".prod();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_rank == 2 && reduction_axes == AxisSet{1}) else if (arg_rank == 2 && reduction_axes == AxisSet{1})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().prod();\n"; << " " << emit_matrix(args[0]) << ".rowwise().prod();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (arg_rank == 2 && reduction_axes == AxisSet{0}) else if (arg_rank == 2 && reduction_axes == AxisSet{0})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().prod();\n"; << " " << emit_matrix(args[0]) << ".colwise().prod();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
...@@ -3123,16 +2920,14 @@ namespace ngraph ...@@ -3123,16 +2920,14 @@ namespace ngraph
writer << " {" << join(product->get_reduction_axes()) writer << " {" << join(product->get_reduction_axes())
<< "});\n"; << "});\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Max) void CPU_Emitter::EMITTER_DECL(ngraph::op::Max)
{ {
const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node); const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape(); const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size(); size_t arg_rank = arg_shape.size();
...@@ -3147,41 +2942,33 @@ namespace ngraph ...@@ -3147,41 +2942,33 @@ namespace ngraph
// Trivial case: no reduction axes. // Trivial case: no reduction axes.
if (!zero_sized && reduction_axes.size() == 0) if (!zero_sized && reduction_axes.size() == 0)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Full reduction? Then reduce to scalar. // Full reduction? Then reduce to scalar.
else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) || else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1}))) (arg_rank == 2 && reduction_axes == AxisSet{0, 1})))
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".maxCoeff();\n"; << " " << emit_array1d(args[0]) << ".maxCoeff();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1}) else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().maxCoeff();\n"; << " " << emit_matrix(args[0]) << ".rowwise().maxCoeff();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0}) else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().maxCoeff();\n"; << " " << emit_matrix(args[0]) << ".colwise().maxCoeff();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
...@@ -3203,16 +2990,14 @@ namespace ngraph ...@@ -3203,16 +2990,14 @@ namespace ngraph
writer << " {" << join(max->get_reduction_axes()) writer << " {" << join(max->get_reduction_axes())
<< "});\n"; << "});\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Min) void CPU_Emitter::EMITTER_DECL(ngraph::op::Min)
{ {
const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node); const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node);
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
#if PREFER_EIGEN == 1 #if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape(); const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size(); size_t arg_rank = arg_shape.size();
...@@ -3227,41 +3012,33 @@ namespace ngraph ...@@ -3227,41 +3012,33 @@ namespace ngraph
// Trivial case: no reduction axes. // Trivial case: no reduction axes.
if (!zero_sized && reduction_axes.size() == 0) if (!zero_sized && reduction_axes.size() == 0)
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", " writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n"; << out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
// Full reduction? Then reduce to scalar. // Full reduction? Then reduce to scalar.
else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) || else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1}))) (arg_rank == 2 && reduction_axes == AxisSet{0, 1})))
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_array1d(out[0]) << " =\n" writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".minCoeff();\n"; << " " << emit_array1d(args[0]) << ".minCoeff();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1}) else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().minCoeff();\n"; << " " << emit_matrix(args[0]) << ".rowwise().minCoeff();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0}) else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0})
{ {
writer << "{ // " << node->get_name() << "\n"; writer.block_begin();
writer.indent++;
writer << emit_vector(out[0]) << " =\n" writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().minCoeff();\n"; << " " << emit_matrix(args[0]) << ".colwise().minCoeff();\n";
writer.indent--; writer.block_end();
writer << "}\n";
} }
else else
{ {
...@@ -3283,8 +3060,7 @@ namespace ngraph ...@@ -3283,8 +3060,7 @@ namespace ngraph
writer << " {" << join(min->get_reduction_axes()) writer << " {" << join(min->get_reduction_axes())
<< "});\n"; << "});\n";
#endif #endif
writer.indent--; writer.block_end();
writer << "}\n";
} }
template <> template <>
...@@ -3358,10 +3134,10 @@ namespace ngraph ...@@ -3358,10 +3134,10 @@ namespace ngraph
{ {
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > 0 ? "
<< "[i] > 0 ? " << args[1].get_name() << "[i] : 0;\n"; << args[1].get_name() << "[i] : 0;\n";
writer << "}\n"; writer.block_end();
} }
} }
...@@ -3391,10 +3167,10 @@ namespace ngraph ...@@ -3391,10 +3167,10 @@ namespace ngraph
{ {
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer.block_begin();
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > 0 ? "
<< "[i] > 0 ? " << args[0].get_name() << "[i] : 0;\n"; << args[0].get_name() << "[i] : 0;\n";
writer << "}\n"; writer.block_end();
} }
} }
...@@ -3518,8 +3294,7 @@ namespace ngraph ...@@ -3518,8 +3294,7 @@ namespace ngraph
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d] writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n"; << "; ++i" << d << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
} }
} }
...@@ -3532,25 +3307,21 @@ namespace ngraph ...@@ -3532,25 +3307,21 @@ namespace ngraph
{ {
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d] writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n"; << "; ++i" << d << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
} }
} }
writer << "if (arg" << index << " > m)\n"; writer << "if (arg" << index << " > m)\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
writer << "m = arg" << index << ";\n"; writer << "m = arg" << index << ";\n";
writer.indent--; writer.block_end();
writer << "}\n";
// end max inner loop(s) // end max inner loop(s)
for (size_t d = 0; d < dims; ++d) for (size_t d = 0; d < dims; ++d)
{ {
if (axes.find(d) != axes.end()) if (axes.find(d) != axes.end())
{ {
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
...@@ -3561,8 +3332,7 @@ namespace ngraph ...@@ -3561,8 +3332,7 @@ namespace ngraph
{ {
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d] writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n"; << "; ++i" << d << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
} }
} }
...@@ -3573,8 +3343,7 @@ namespace ngraph ...@@ -3573,8 +3343,7 @@ namespace ngraph
{ {
if (axes.find(d) != axes.end()) if (axes.find(d) != axes.end())
{ {
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
...@@ -3583,8 +3352,7 @@ namespace ngraph ...@@ -3583,8 +3352,7 @@ namespace ngraph
{ {
if (axes.find(d) == axes.end()) if (axes.find(d) == axes.end())
{ {
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
...@@ -3597,8 +3365,7 @@ namespace ngraph ...@@ -3597,8 +3365,7 @@ namespace ngraph
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d] writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n"; << "; ++i" << d << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
} }
} }
...@@ -3611,8 +3378,7 @@ namespace ngraph ...@@ -3611,8 +3378,7 @@ namespace ngraph
{ {
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d] writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n"; << "; ++i" << d << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
} }
} }
...@@ -3623,8 +3389,7 @@ namespace ngraph ...@@ -3623,8 +3389,7 @@ namespace ngraph
{ {
if (axes.find(d) != axes.end()) if (axes.find(d) != axes.end())
{ {
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
...@@ -3637,8 +3402,7 @@ namespace ngraph ...@@ -3637,8 +3402,7 @@ namespace ngraph
{ {
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d] writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n"; << "; ++i" << d << ")\n";
writer << "{\n"; writer.block_begin();
writer.indent++;
} }
} }
...@@ -3649,8 +3413,7 @@ namespace ngraph ...@@ -3649,8 +3413,7 @@ namespace ngraph
{ {
if (axes.find(d) != axes.end()) if (axes.find(d) != axes.end())
{ {
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
...@@ -3659,8 +3422,7 @@ namespace ngraph ...@@ -3659,8 +3422,7 @@ namespace ngraph
{ {
if (axes.find(d) == axes.end()) if (axes.find(d) == axes.end())
{ {
writer.indent--; writer.block_end();
writer << "}\n";
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment