Unverified commit 2bf2214f, authored by Robert Kimball and committed by GitHub

use codegen block_begin/end methods to clean up code a little (#738)

parent b6240057
......@@ -133,8 +133,7 @@ namespace ngraph
{
// TODO: Audit all uses of Add and fix this to use
// the right alignment instead of Eigen::Unaligned
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << "Eigen::Map<Eigen::Array<" << out[0].get_element_type().c_type_string()
<< ", " << out[0].get_size() << ", 1>, Eigen::Unaligned> out("
......@@ -189,14 +188,13 @@ namespace ngraph
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] + " << args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] + "
<< args[1].get_name() << "[i];\n";
writer.block_end();
}
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
#ifdef NGRAPH_DISTRIBUTED
......@@ -215,13 +213,11 @@ namespace ngraph
data_type = "MPI_DOUBLE";
}
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "MPI_Allreduce(" << args[0].get_name() << ", " << out[0].get_name()
<< ", " << out[0].get_size() << ", " << data_type
<< ", MPI_SUM, MPI_COMM_WORLD);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
#endif
......@@ -259,8 +255,7 @@ namespace ngraph
n = arg1_shape[0];
}
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
const char* cbeta = "0.0f";
......@@ -364,8 +359,7 @@ namespace ngraph
}
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -374,8 +368,7 @@ namespace ngraph
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
writer.indent++;
writer << "{\n";
writer.block_begin();
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
......@@ -477,8 +470,7 @@ namespace ngraph
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -487,8 +479,7 @@ namespace ngraph
const ngraph::op::BatchNormBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormBackprop*>(node);
writer.indent++;
writer << "{\n";
writer.block_begin();
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
......@@ -554,8 +545,7 @@ namespace ngraph
writer << "memcpy(" << out[2].get_name() << ", &bn_dweights[0]+"
<< args[0].get_size() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -570,34 +560,28 @@ namespace ngraph
auto& first = (arg0_shape.empty() ? args[0] : args[1]);
auto& second = (arg0_shape.empty() ? args[1] : args[0]);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << "\n = ";
writer << first.get_name() << "[0]\n * " << emit_vector(second) << ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if ((arg0_shape.size() == 1) && (arg1_shape.size() == 1) &&
dot->get_reduction_axes_count() == 1)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " << \n"
<< " " << emit_vector(args[0]) << ".dot(" << emit_vector(args[1])
<< ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) &&
dot->get_reduction_axes_count() == 1)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << " * " << emit_vector(args[1])
<< ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) &&
dot->get_reduction_axes_count() == 1)
......@@ -605,8 +589,7 @@ namespace ngraph
// Emit an MKL SGEMM call if possible
if (args[0].get_element_type() == element::f32)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, "
<< "cblas::Transpose::None, "
......@@ -617,18 +600,15 @@ namespace ngraph
<< max(1UL, arg1_shape[1]) << ", 0.0f,\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg1_shape[1])
<< ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_matrix(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << " * " << emit_matrix(args[1])
<< ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
else
......@@ -646,8 +626,7 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Multiply)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " *\n"
......@@ -655,13 +634,12 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] * "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] * "
<< args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -669,20 +647,17 @@ namespace ngraph
{
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", "
<< args[get_tuple_element->get_n()].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Abs)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n";
writer << "Eigen::abs(" << emit_array1d(args[0]) << ");\n";
......@@ -693,14 +668,13 @@ namespace ngraph
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name()
writer.block_begin();
writer << out[0].get_name()
<< "[i] = " << (result_element_type.is_signed() ? "std::abs" : "") << "("
<< args[0].get_name() << "[i]);\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -711,8 +685,7 @@ namespace ngraph
#if PREFER_EIGEN == 1
if (result_shape.size() == 1)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0], "out_vector") << ";\n";
size_t concat_pos = 0;
......@@ -723,16 +696,14 @@ namespace ngraph
<< ";\n";
concat_pos += args[i].get_shape().at(0);
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (result_shape.size() == 2)
{
auto axis =
(dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_matrix(out[0], "out_matrix") << ";\n";
size_t concat_pos[2]{0, 0};
......@@ -747,8 +718,7 @@ namespace ngraph
concat_pos[axis] += arg_shape.at(axis);
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
......@@ -822,17 +792,16 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Divide)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
if (node->get_element_type().is_real() == false)
{
// Check for divide by zero for integer types only
size_t element_count = args[1].get_size();
writer << "for (size_t i=0; i<" << element_count << "; i++)\n";
writer << "{\n";
writer << " if (" << args.at(1).get_name()
writer.block_begin();
writer << "if (" << args.at(1).get_name()
<< "[i] == 0) throw std::runtime_error(\"integer divide by zero\");\n";
writer << "}\n";
writer.block_end();
}
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
......@@ -841,20 +810,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] / "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] / "
<< args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Equal)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " ==\n"
......@@ -862,20 +829,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] == " << args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Greater)
{
writer << "{ // " << node->get_name() << " xxx\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " >\n"
......@@ -883,20 +848,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
<< args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::GreaterEq)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " >=\n"
......@@ -904,20 +867,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] >= " << args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Less)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " <\n"
......@@ -925,20 +886,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
<< args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::LessEq)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " <=\n"
......@@ -946,40 +905,35 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] <= " << args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Emitter specialization for ngraph::op::Log.
// Writes code into `writer` that assigns log() of each element of args[0] to
// the corresponding element of out[0], wrapped in a single generated block.
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and to adjust the indent level -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Log)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " Eigen::log(" << emit_array1d(args[0]) << ");\n";
#else
    // Loop header and body are each emitted exactly once, so the generated
    // assignment is always inside the `for` that declares `i`.
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    writer.block_begin();
    writer << out[0].get_name() << "[i] = log(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Maximum)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".max(\n"
......@@ -987,21 +941,19 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
<< args[1].get_name() << "[i] ? " << args[0].get_name()
<< "[i] : " << args[1].get_name() << "[i] ;\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Minimum)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".min(\n"
......@@ -1009,41 +961,36 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
<< args[1].get_name() << "[i] ? " << args[0].get_name()
<< "[i] : " << args[1].get_name() << "[i] ;\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Emitter specialization for ngraph::op::Negative.
// Writes code into `writer` that assigns the negation of each element of
// args[0] to the corresponding element of out[0], wrapped in a single
// generated block.
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and to adjust the indent level -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Negative)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " -" << emit_array1d(args[0]) << ";\n";
#else
    // Loop header and body are each emitted exactly once, so the generated
    // assignment is always inside the `for` that declares `i`.
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    writer.block_begin();
    writer << out[0].get_name() << "[i] = -" << args[0].get_name() << "[i];\n";
    writer.block_end();
#endif
    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::NotEqual)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " !=\n"
......@@ -1051,20 +998,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] != " << args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Select)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "\n"
......@@ -1073,20 +1018,18 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] ? "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] ? "
<< args[1].get_name() << "[i] : " << args[2].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Subtract)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " -\n"
......@@ -1094,13 +1037,12 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] - "
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] - "
<< args[1].get_name() << "[i];\n";
writer << "}\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -1108,45 +1050,37 @@ namespace ngraph
{
auto broadcast = static_cast<const ngraph::op::Broadcast*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape();
if (broadcast->get_broadcast_axes().empty())
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_shape.size() == 0)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "(0, 0);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_matrix(out[0]) << ".colwise() =\n"
<< " " << emit_vector(args[0]) << ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "Eigen::Map<Eigen::Matrix<"
<< out[0].get_element_type().c_type_string() << ", "
......@@ -1163,8 +1097,7 @@ namespace ngraph
writer << "out = arg0.replicate<" << out[0].get_shape().at(0)
<< ", 1>();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
......@@ -1193,8 +1126,7 @@ namespace ngraph
out[0].get_shape(),
broadcast->get_broadcast_axes());
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -1202,8 +1134,7 @@ namespace ngraph
{
auto& result_element_type = out[0].get_element_type();
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "\n"
......@@ -1211,14 +1142,12 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = ("
<< result_element_type.c_type_string() << ")(" << args[0].get_name()
<< "[i]);\n";
writer << "}\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = (" << result_element_type.c_type_string()
<< ")(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -1242,8 +1171,7 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Reshape)
{
auto reshape = static_cast<const ngraph::op::Reshape*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
......@@ -1265,12 +1193,10 @@ namespace ngraph
// we can just copy.
if (same_layout || result_shape_product < 2)
{
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// If there *is* a layout change in the 2D case, we transpose the input.
else if (arg_rank == 2)
......@@ -1278,8 +1204,7 @@ namespace ngraph
// Emit an MKL transpose call if possible
if (result_element_type == ngraph::element::f32)
{
writer << "{ // " << node->get_name() << " 2\n";
writer.indent++;
writer.block_begin();
writer << "mkl::MKL_Somatcopy('R', 'T', " << to_string(arg_shape[0])
<< ",\n"
<< " " << to_string(arg_shape[1]) << ", 1.0f,\n"
......@@ -1287,17 +1212,14 @@ namespace ngraph
<< to_string(arg_shape[1]) << ",\n"
<< " " << out[0].get_name() << ", "
<< to_string(arg_shape[0]) << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
writer << "{ // " << node->get_name() << " 3\n";
writer.indent++;
writer.block_begin();
writer << emit_matrix(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".transpose();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
// Other cases
......@@ -1320,8 +1242,7 @@ namespace ngraph
out[0].get_shape(),
reshape->get_input_order());
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -1330,8 +1251,7 @@ namespace ngraph
auto function_call = static_cast<const ngraph::op::FunctionCall*>(node);
shared_ptr<Function> function = function_call->get_functions()[0];
writer << "{ // Call " << function->get_name() << "\n";
writer.indent++;
writer.block_begin();
{
vector<string> input_names;
vector<string> output_names;
......@@ -1346,23 +1266,22 @@ namespace ngraph
output_names.push_back(output.get_name());
}
writer << "void* args[] =\n{";
writer.indent++;
writer << "void* args[] =\n";
writer.block_begin();
writer << "\n" << join(input_names, ",\n");
writer.indent--;
writer << "\n};\n";
writer.block_end();
writer << ";\n";
writer << "void* out[] =\n{";
writer.indent++;
writer << "void* out[] =\n";
writer.block_begin();
writer << "\n" << join(output_names, ",\n");
writer.indent--;
writer << "\n};\n";
writer.block_end();
writer << ";\n";
writer << "\n";
writer << function->get_name() << "(args, out, ctx);\n";
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
// TODO: This and other ops include comments/notes that
......@@ -1387,12 +1306,10 @@ namespace ngraph
// Trivial case: no reduction axes (this includes the scalar-reductee case).
if (reduction_axes.empty())
{
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Behavior for zero-size axes bears some explanation here. XLA's reduce
// operator provides an "base" element (usually, but not necessarily,
......@@ -1425,23 +1342,19 @@ namespace ngraph
if (reductee_shape.at(0) == 0 ||
(reductee_shape.size() == 2 && reductee_shape.at(1) == 0))
{
writer << "{ // " << node->get_name() << " 2\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name()
<< ", " << out[0].get_size() * out[0].get_element_type().size()
<< ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
writer << "{ // " << node->get_name() << " 3\n";
writer.indent++;
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{";
<< " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -1451,30 +1364,25 @@ namespace ngraph
writer << "};\n";
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".redux(f);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1})
{
if (reductee_shape.at(1) == 0)
{
writer << "{ // " << node->get_name() << " 4\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[1]) << "(0, 0);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
writer << "{ // " << node->get_name() << " 5\n";
writer.indent++;
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{";
<< " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -1484,30 +1392,25 @@ namespace ngraph
writer << "};\n";
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().redux(f);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0})
{
if (reductee_shape.at(0) == 0)
{
writer << "{ // " << node->get_name() << " 6\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[1]) << "(0, 0);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
writer << "{ // " << node->get_name() << " 7\n";
writer.indent++;
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{";
<< " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -1517,20 +1420,17 @@ namespace ngraph
writer << "};\n";
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().redux(f);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
else
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{";
<< " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -1548,18 +1448,15 @@ namespace ngraph
writer << " {" << join(reduce->get_reduction_axes()) << "},\n";
writer << " f);\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
#else
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << "\n{";
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -1577,29 +1474,26 @@ namespace ngraph
out[0].get_shape(),
reduce->get_reduction_axes());
writer.indent--;
writer << "}\n";
writer.block_end();
#endif
}
// Emitter specialization for ngraph::op::Sign.
// Writes code into `writer` that stores, for each element x of args[0], the
// value (0 < x) - (x < 0) into the corresponding element of out[0]
// (the Eigen path emits `.sign()` instead), wrapped in a single generated block.
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and to adjust the indent level -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sign)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".sign();\n";
#else
    // Loop header and body are each emitted exactly once, so the generated
    // assignment is always inside the `for` that declares `i`.
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    writer.block_begin();
    writer << out[0].get_name() << "[i] = (0 < " << args[0].get_name() << "[i]) - ("
           << args[0].get_name() << "[i] < 0);\n";
    writer.block_end();
#endif
    writer.block_end();
}
template <>
......@@ -1607,8 +1501,7 @@ namespace ngraph
{
const ngraph::op::Slice* slice = static_cast<const ngraph::op::Slice*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
size_t arg_rank = args[0].get_shape().size();
......@@ -1628,36 +1521,30 @@ namespace ngraph
// Scalar slice is necessarily just a copy.
if (!strided && arg_rank == 0)
{
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!strided && arg_rank == 1)
{
writer << "{ // " << node->get_name() << " 2\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_vector(args[0]) << ".segment(\n"
<< " " << to_string(lower_bounds[0]) << ", "
<< to_string(upper_bounds[0] - lower_bounds[0]) << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!strided && arg_rank == 2)
{
writer << "{ // " << node->get_name() << " 3\n";
writer.indent++;
writer.block_begin();
writer << emit_matrix(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << ".block("
<< to_string(lower_bounds[0]) << ", " << to_string(lower_bounds[1])
<< ",\n"
<< " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n"
<< " " << to_string(upper_bounds[1] - lower_bounds[1]) << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else
......@@ -1684,16 +1571,14 @@ namespace ngraph
slice->get_upper_bounds(),
slice->get_strides());
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sum)
{
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
......@@ -1702,41 +1587,33 @@ namespace ngraph
// Trivial case: no reduction axes.
if (reduction_axes.size() == 0)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Full reduction? Then sum to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sum();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().sum();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().sum();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
......@@ -1757,128 +1634,109 @@ namespace ngraph
out[0].get_shape(),
sum->get_reduction_axes());
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Emitter for ngraph::op::Exp: writes code to `writer` that stores the
// element-wise exp() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .exp();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Exp)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".exp();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = exp(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Sin: writes code to `writer` that stores the
// element-wise sin() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .sin();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sin)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".sin();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = sin(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Sinh: writes code to `writer` that stores the
// element-wise sinh() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .sinh();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sinh)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".sinh();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = sinh(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Cos: writes code to `writer` that stores the
// element-wise cos() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .cos();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Cos)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".cos();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = cos(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Cosh: writes code to `writer` that stores the
// element-wise cosh() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .cosh();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Cosh)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".cosh();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = cosh(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Tan: writes code to `writer` that stores the
// element-wise tan() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .tan();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Tan)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".tan();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = tan(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
template <>
......@@ -1888,85 +1746,72 @@ namespace ngraph
// so we fall-back to tanh
// TODO: Implement our own internal fast/approximate tanh if this actually gets used
// by models
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 0
writer << "#pragma omp parallel for\n";
#endif
writer << "for (size_t i=0; i<" << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = tanh(" << args[0].get_name()
<< "[i]);\n";
writer << "}\n";
writer.indent--;
writer << "}\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = tanh(" << args[0].get_name() << "[i]);\n";
writer.block_end();
writer.block_end();
}
// Emitter for ngraph::op::Asin: writes code to `writer` that stores the
// element-wise asin() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .asin();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Asin)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".asin();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = asin(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Acos: writes code to `writer` that stores the
// element-wise acos() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .acos();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Acos)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".acos();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = acos(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
// Emitter for ngraph::op::Atan: writes code to `writer` that stores the
// element-wise atan() of args[0] into out[0].
//
// The emitted code is wrapped in one block_begin()/block_end() pair. With
// PREFER_EIGEN == 1 the body is a single Eigen array expression using .atan();
// otherwise it is an OpenMP-parallel loop over out[0].get_size() elements
// whose body gets its own block_begin()/block_end() pair.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Atan)
{
    writer.block_begin();
#if PREFER_EIGEN == 1
    writer << emit_array1d(out[0]) << " =\n"
           << " " << emit_array1d(args[0]) << ".atan();\n";
#else
    writer << "#pragma omp parallel for\n";
    writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = atan(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
#endif
    writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
writer << emit_array1d(out[0]) << " = \n";
writer.indent++;
......@@ -1976,21 +1821,19 @@ namespace ngraph
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = pow(" << args[0].get_name()
<< "[i], " << args[1].get_name() << "[i]);\n";
writer << "}\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = pow(" << args[0].get_name() << "[i], "
<< args[1].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::ReplaceSlice)
{
auto replace_slice = static_cast<const ngraph::op::Slice*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
size_t arg0_rank = args[0].get_shape().size();
......@@ -2010,30 +1853,25 @@ namespace ngraph
// Scalar slice is necessarily just a copy.
if (!strided && arg0_rank == 0)
{
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!strided && arg0_rank == 1)
{
writer << "{ // " << node->get_name() << " 2\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_vector(args[0]) << ";\n"
<< emit_vector(out[0]) << ".segment(\n"
<< " " << to_string(lower_bounds[0]) << ", "
<< to_string(upper_bounds[0] - lower_bounds[0]) << ") =\n"
<< " " << emit_vector(args[1]) << ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!strided && arg0_rank == 2)
{
writer << "{ // " << node->get_name() << " 3\n";
writer.indent++;
writer.block_begin();
writer << emit_matrix(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ";\n"
<< emit_matrix(out[0]) << ".block(\n"
......@@ -2042,8 +1880,7 @@ namespace ngraph
<< " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n"
<< " " << to_string(upper_bounds[1] - lower_bounds[1]) << ") =\n"
<< " " << emit_matrix(args[1]) << ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else
......@@ -2073,8 +1910,7 @@ namespace ngraph
replace_slice->get_upper_bounds(),
replace_slice->get_strides());
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -2088,80 +1924,66 @@ namespace ngraph
if (arg_rank == 0)
{
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0], "out_vector") << ";\n";
writer << "out_vector.setZero();\n"
<< ""
<< "auto pos_raw = " << emit_vector(args[0]) << "(0, 0);\n"
<< "if (floor(pos_raw) != pos_raw)\n"
<< "{\n";
writer.indent++;
<< "if (floor(pos_raw) != pos_raw)\n";
writer.block_begin();
writer
<< "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
writer.indent--;
writer << "}\n";
writer.block_end();
writer << "size_t pos = pos_raw;\n"
<< "if (pos >= " << bounds << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
writer << "throw(std::range_error(\"One-hot: value is out of category "
"range\"));\n";
writer.indent--;
writer << "}\n";
writer.block_end();
writer << "out_vector(pos, 0) = 1;\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_rank == 1)
{
writer << "{ // " << node->get_name() << " 1\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(args[0], "arg_vector") << ";\n";
writer << emit_matrix(out[0], "out_vector") << ";\n";
writer << "out_vector.setZero();\n";
writer << "for (size_t i = 0; i < " << args[0].get_shape()[0] << "; i++)\n"
<< "{\n";
writer.indent++;
writer << "for (size_t i = 0; i < " << args[0].get_shape()[0] << "; i++)\n";
writer.block_begin();
writer << "auto pos_raw = arg_vector(i, 0);\n";
writer << "if (floor(pos_raw) != pos_raw)\n"
<< "{\n";
writer.indent++;
writer << "if (floor(pos_raw) != pos_raw)\n";
writer.block_begin();
writer
<< "throw(std::range_error(\"One-hot: non-integral value in input\"));\n";
writer.indent--;
writer << "}\n";
writer.block_end();
writer << "size_t pos = pos_raw;\n";
writer << "bool found = false;\n";
writer << "if (pos >= " << bounds << ")\n"
<< "{\n";
writer.indent++;
writer << "if (pos >= " << bounds << ")\n";
writer.block_begin();
writer << "throw(std::range_error(\"One-hot: value is out of category "
"range\"));\n";
writer.indent--;
writer << "}\n";
writer.block_end();
writer << "out_vector"
<< (oh->get_one_hot_axis() == 0 ? "(pos, i)" : "(i, pos)") << " = 1;\n";
writer.indent--;
writer << "}\n";
writer.block_end();
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Other cases are not handled yet.
else
......@@ -2178,55 +2000,46 @@ namespace ngraph
// Emitter for ngraph::op::Ceiling: writes code to `writer` that stores
// ceil(args[0][i]) into out[0][i] for every i in [0, out[0].get_size()).
//
// The emitted loop sits inside one block_begin()/block_end() pair and its
// body inside a second pair. The "#pragma omp parallel for" line is emitted
// only when PREFER_EIGEN == 0.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Ceiling)
{
    writer.block_begin();
    size_t element_count = out[0].get_size();
#if PREFER_EIGEN == 0
    writer << "#pragma omp parallel for\n";
#endif
    writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = ceil(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
    writer.block_end();
}
// Emitter for ngraph::op::Floor: writes code to `writer` that stores
// floor(args[0][i]) into out[0][i] for every i in [0, out[0].get_size()).
//
// The emitted loop sits inside one block_begin()/block_end() pair and its
// body inside a second pair. The "#pragma omp parallel for" line is emitted
// only when PREFER_EIGEN == 0.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Floor)
{
    writer.block_begin();
    size_t element_count = out[0].get_size();
#if PREFER_EIGEN == 0
    writer << "#pragma omp parallel for\n";
#endif
    writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = floor(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
    writer.block_end();
}
// Emitter for ngraph::op::Sqrt: writes code to `writer` that stores
// sqrt(args[0][i]) into out[0][i] for every i in [0, out[0].get_size()).
//
// The emitted loop sits inside one block_begin()/block_end() pair and its
// body inside a second pair. The "#pragma omp parallel for" line is emitted
// only when PREFER_EIGEN == 0.
//
// NOTE(review): block_begin()/block_end() are assumed to emit the opening and
// closing brace and adjust the indent themselves, so no manual "{"/"}" or
// indent++/-- may be emitted alongside them -- confirm against the writer class.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sqrt)
{
    writer.block_begin();
    size_t element_count = out[0].get_size();
#if PREFER_EIGEN == 0
    writer << "#pragma omp parallel for\n";
#endif
    writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
    // The assignment must live inside the loop's block: `i` only exists there.
    writer.block_begin();
    writer << out[0].get_name() << "[i] = sqrt(" << args[0].get_name() << "[i]);\n";
    writer.block_end();
    writer.block_end();
}
template <>
......@@ -2748,13 +2561,11 @@ namespace ngraph
auto reduction_function = reduce_window->get_functions()[0];
auto& f_result_element_type = out[0].get_element_type();
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << "\n{";
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type << " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -2775,8 +2586,7 @@ namespace ngraph
writer << " {"
<< join(reduce_window->get_window_movement_strides()) << "});\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -2790,14 +2600,12 @@ namespace ngraph
auto arg1_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
string type = node->get_output_element_type(0).c_type_string();
writer << "auto f_select = [&](" << type << " x, " << type << " y) -> char\n{";
writer << "auto f_select = [&](" << type << " x, " << type << " y) -> char {\n";
writer.indent++;
writer << "\n";
writer << "char result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -2807,9 +2615,8 @@ namespace ngraph
writer << "};\n";
writer << "auto f_scatter = [&](" << type << " x, " << type << " y) -> " << type
<< "\n{";
<< " {\n";
writer.indent++;
writer << "\n";
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
......@@ -2833,8 +2640,7 @@ namespace ngraph
writer << " {"
<< join(select_and_scatter->get_window_movement_strides()) << "});\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -3057,8 +2863,7 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Product)
{
const ngraph::op::Product* product = static_cast<const ngraph::op::Product*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
......@@ -3067,41 +2872,33 @@ namespace ngraph
// Trivial case: no reduction axes.
if (reduction_axes.size() == 0)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Full reduction? Then reduce to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".prod();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().prod();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().prod();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
......@@ -3123,16 +2920,14 @@ namespace ngraph
writer << " {" << join(product->get_reduction_axes())
<< "});\n";
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Max)
{
const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
......@@ -3147,41 +2942,33 @@ namespace ngraph
// Trivial case: no reduction axes.
if (!zero_sized && reduction_axes.size() == 0)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Full reduction? Then reduce to scalar.
else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1})))
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".maxCoeff();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().maxCoeff();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().maxCoeff();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
......@@ -3203,16 +2990,14 @@ namespace ngraph
writer << " {" << join(max->get_reduction_axes())
<< "});\n";
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Min)
{
const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node);
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
#if PREFER_EIGEN == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
......@@ -3227,41 +3012,33 @@ namespace ngraph
// Trivial case: no reduction axes.
if (!zero_sized && reduction_axes.size() == 0)
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
// Full reduction? Then reduce to scalar.
else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1})))
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".minCoeff();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().minCoeff();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().minCoeff();\n";
writer.indent--;
writer << "}\n";
writer.block_end();
}
else
{
......@@ -3283,8 +3060,7 @@ namespace ngraph
writer << " {" << join(min->get_reduction_axes())
<< "});\n";
#endif
writer.indent--;
writer << "}\n";
writer.block_end();
}
template <>
......@@ -3358,10 +3134,10 @@ namespace ngraph
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] > 0 ? " << args[1].get_name() << "[i] : 0;\n";
writer << "}\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > 0 ? "
<< args[1].get_name() << "[i] : 0;\n";
writer.block_end();
}
}
......@@ -3391,10 +3167,10 @@ namespace ngraph
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] > 0 ? " << args[0].get_name() << "[i] : 0;\n";
writer << "}\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > 0 ? "
<< args[0].get_name() << "[i] : 0;\n";
writer.block_end();
}
}
......@@ -3518,8 +3294,7 @@ namespace ngraph
writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
}
}
......@@ -3532,25 +3307,21 @@ namespace ngraph
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
}
}
writer << "if (arg" << index << " > m)\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
writer << "m = arg" << index << ";\n";
writer.indent--;
writer << "}\n";
writer.block_end();
// end max inner loop(s)
for (size_t d = 0; d < dims; ++d)
{
if (axes.find(d) != axes.end())
{
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
......@@ -3561,8 +3332,7 @@ namespace ngraph
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
}
}
......@@ -3573,8 +3343,7 @@ namespace ngraph
{
if (axes.find(d) != axes.end())
{
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
......@@ -3583,8 +3352,7 @@ namespace ngraph
{
if (axes.find(d) == axes.end())
{
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
......@@ -3597,8 +3365,7 @@ namespace ngraph
writer << "#pragma omp parallel for\n";
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
}
}
......@@ -3611,8 +3378,7 @@ namespace ngraph
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
}
}
......@@ -3623,8 +3389,7 @@ namespace ngraph
{
if (axes.find(d) != axes.end())
{
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
......@@ -3637,8 +3402,7 @@ namespace ngraph
{
writer << "for (size_t i" << d << " = 0; i" << d << " < " << shape[d]
<< "; ++i" << d << ")\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
}
}
......@@ -3649,8 +3413,7 @@ namespace ngraph
{
if (axes.find(d) != axes.end())
{
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
......@@ -3659,8 +3422,7 @@ namespace ngraph
{
if (axes.find(d) == axes.end())
{
writer.indent--;
writer << "}\n";
writer.block_end();
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment