Commit 556179a2 authored by Jayaram Bobba, committed by Scott Cyphers

Remove old Eigen code in codegen and misc bug fixes (#2189)

* Remove old Eigen code

* Bug fixes to unordered map checks
parent dbf3703a
@@ -134,9 +134,6 @@
using namespace std;
using namespace ngraph;
// Enables old unoptimized Eigen code paths
#define USE_EIGEN_CORE_INLINE 0
static bool s_use_ref_kernels = (std::getenv("NGRAPH_CPU_USE_REF_KERNELS") != nullptr);
static string eigen_vector_format(const runtime::cpu::TensorViewWrapper& tvi)
@@ -160,22 +157,7 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Add)
{
// TODO: Audit all uses of Add and fix this to use
// the right alignment instead of Eigen::Unaligned
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << "Eigen::Map<Eigen::Array<" << out[0].get_element_type().c_type_string()
<< ", " << out[0].get_size() << ", 1>, Eigen::Unaligned> out("
<< out[0].get_name() << ");\n";
writer << "Eigen::Map<Eigen::Array<" << args[0].get_element_type().c_type_string()
<< ", " << args[0].get_size() << ", 1>, Eigen::Unaligned> arg0("
<< args[0].get_name() << ");\n";
writer << "Eigen::Map<Eigen::Array<" << args[1].get_element_type().c_type_string()
<< ", " << args[1].get_size() << ", 1>, Eigen::Unaligned> arg1("
<< args[1].get_name() << ");\n";
writer << "out = arg0 + arg1;\n";
#else
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
std::vector<float> scale_vector(2, 1);
@@ -213,7 +195,6 @@ namespace ngraph
<< args[1].get_name() << "[i];\n";
writer.block_end();
}
#endif
writer.block_end();
}
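For context: with the Eigen branch gone, Add either dispatches to an MKL-DNN sum primitive or emits a plain OpenMP loop. A minimal compilable sketch of the emitted scalar path, assuming float tensors; the wrapper function and the names out0/arg0/arg1/n are illustrative, not taken from the commit:

#include <cstddef>

void emitted_add(float* out0, const float* arg0, const float* arg1, size_t n)
{
#pragma omp parallel for
    for (size_t i = 0; i < n; i++)
    {
        out0[i] = arg0[i] + arg1[i];
    }
}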
@@ -988,18 +969,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Multiply)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " *\n"
<< " " << emit_array1d(args[1]) << ";\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] * "
<< args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1019,10 +994,6 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Abs)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n";
writer << "Eigen::abs(" << emit_array1d(args[0]) << ");\n";
#else
// Some C++ implementations don't like it when we call std::abs on unsigned types, so we will
// avoid doing so here.
auto& result_element_type = out[0].get_element_type();
@@ -1034,7 +1005,6 @@ namespace ngraph
<< "[i] = " << (result_element_type.is_signed() ? "std::abs" : "") << "("
<< args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
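As the comment above notes, std::abs is only emitted for signed element types; unsigned inputs are copied through unchanged. A sketch of the signed case (wrapper and names are illustrative):

#include <cstddef>
#include <cstdlib>

void emitted_abs_signed(int* out0, const int* arg0, size_t n)
{
#pragma omp parallel for
    for (size_t i = 0; i < n; i++)
    {
        // For unsigned element types the emitter omits std::abs entirely.
        out0[i] = std::abs(arg0[i]);
    }
}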
@@ -1066,92 +1036,6 @@ namespace ngraph
}
auto result_shape = out[0].get_shape();
#if USE_EIGEN_CORE_INLINE == 1
if (result_shape.size() == 1)
{
writer.block_begin();
writer << emit_vector(out[0], "out_vector") << ";\n";
size_t concat_pos = 0;
for (size_t i = 0; i < args.size(); i++)
{
writer << "out_vector.segment(" << concat_pos << ", "
<< args[i].get_shape().at(0) << ") << " << emit_vector(args[i])
<< ";\n";
concat_pos += args[i].get_shape().at(0);
}
writer.block_end();
}
else if (result_shape.size() == 2)
{
auto axis =
(dynamic_cast<const ngraph::op::Concat*>(node))->get_concatenation_axis();
writer.block_begin();
writer << emit_matrix(out[0], "out_matrix") << ";\n";
size_t concat_pos[2]{0, 0};
for (size_t i = 0; i < args.size(); i++)
{
auto& arg_shape = args[i].get_shape();
writer << "out_matrix.block(" << concat_pos[0] << ", " << concat_pos[1]
<< ", " << arg_shape.at(0) << ", " << arg_shape.at(1) << ") << "
<< emit_matrix(args[i]) << ";\n";
concat_pos[axis] += arg_shape.at(axis);
}
writer.block_end();
}
else
{
if (s_use_ref_kernels)
{
auto axis = (dynamic_cast<const ngraph::op::Concat*>(node))
->get_concatenation_axis();
std::vector<std::string> arg_names;
std::vector<std::string> arg_shape_strings;
for (auto arg : args)
{
arg_names.push_back(arg.get_name());
arg_shape_strings.push_back("{" + join(arg.get_shape()) + "}");
}
writer << "reference::concat<" << out[0].get_type() << ">({"
<< join(arg_names) << "},\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(arg_shape_strings) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " " << axis << ");\n";
}
else
{
auto axis = (dynamic_cast<const ngraph::op::Concat*>(node))
->get_concatenation_axis();
std::vector<std::string> arg_names;
std::vector<Shape> arg_shapes;
for (auto arg : args)
{
arg_names.push_back(arg.get_name());
arg_shapes.push_back(arg.get_shape());
}
kernel::emit_concat(writer,
args[0].get_element_type().c_type_string(),
arg_names,
out[0].get_name(),
arg_shapes,
result_shape,
axis);
}
}
#else
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
@@ -1203,7 +1087,6 @@ namespace ngraph
result_shape,
axis);
}
#endif
}
template <>
@@ -1220,18 +1103,12 @@ namespace ngraph
<< "[i] == 0) throw std::runtime_error(\"integer divide by zero\");\n";
writer.block_end();
}
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " /\n"
<< " " << emit_array1d(args[1]) << ";\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] / "
<< args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1239,18 +1116,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Equal)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " ==\n"
<< " " << emit_array1d(args[1]) << ").template cast<char>();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] == " << args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
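Note that the comparison ops write into a char-typed output tensor (the removed Eigen branch made this explicit with .template cast<char>()); in the emitted loop, the bool result of == converts implicitly to 0 or 1. A sketch with an illustrative wrapper:

#include <cstddef>

void emitted_equal(char* out0, const float* arg0, const float* arg1, size_t n)
{
#pragma omp parallel for
    for (size_t i = 0; i < n; i++)
    {
        out0[i] = arg0[i] == arg1[i]; // bool -> char: 1 if equal, else 0
    }
}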
@@ -1258,18 +1129,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Greater)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " >\n"
<< " " << emit_array1d(args[1]) << ").template cast<char>();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] > "
<< args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1277,18 +1142,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::GreaterEq)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " >=\n"
<< " " << emit_array1d(args[1]) << ").template cast<char>();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] >= " << args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1296,18 +1155,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Less)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " <\n"
<< " " << emit_array1d(args[1]) << ").template cast<char>();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
<< args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1315,18 +1168,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::LessEq)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " <=\n"
<< " " << emit_array1d(args[1]) << ").template cast<char>();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] <= " << args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1377,16 +1224,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Log)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " Eigen::log(" << emit_array1d(args[0]) << ");\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = log(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1394,11 +1236,6 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Maximum)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".max(\n"
<< " " << emit_array1d(args[1]) << ");\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
@@ -1406,7 +1243,6 @@ namespace ngraph
<< args[1].get_name() << "[i] ? " << args[0].get_name()
<< "[i] : " << args[1].get_name() << "[i] ;\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1414,11 +1250,6 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Minimum)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".min(\n"
<< " " << emit_array1d(args[1]) << ");\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
@@ -1426,7 +1257,6 @@ namespace ngraph
<< args[1].get_name() << "[i] ? " << args[0].get_name()
<< "[i] : " << args[1].get_name() << "[i] ;\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1434,16 +1264,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Negative)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " -" << emit_array1d(args[0]) << ";\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = -" << args[0].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1451,18 +1276,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::NotEqual)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " (" << emit_array1d(args[0]) << " !=\n"
<< " " << emit_array1d(args[1]) << ").template cast<char>();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] != " << args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1470,19 +1289,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Select)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "\n"
<< " .select(" << emit_array1d(args[1]) << ",\n"
<< " " << emit_array1d(args[2]) << ");\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] ? "
<< args[1].get_name() << "[i] : " << args[2].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1490,18 +1302,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Subtract)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << " -\n"
<< " " << emit_array1d(args[1]) << ";\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] - "
<< args[1].get_name() << "[i];\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1511,73 +1317,6 @@ namespace ngraph
auto broadcast = static_cast<const ngraph::op::Broadcast*>(node);
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape();
if (broadcast->get_broadcast_axes().empty())
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
else if (arg_shape.size() == 0)
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "(0, 0);\n";
writer.block_end();
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
writer.block_begin();
writer << emit_matrix(out[0]) << ".colwise() =\n"
<< " " << emit_vector(args[0]) << ";\n";
writer.block_end();
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
writer.block_begin();
writer << "Eigen::Map<Eigen::Matrix<"
<< out[0].get_element_type().c_type_string() << ", "
<< join(out[0].get_shape())
<< ", Eigen::RowMajor>, Eigen::Aligned64, Eigen::Stride<"
<< join(out[0].get_strides()) << ">> out(" << out[0].get_name()
<< ");\n";
writer << "Eigen::Map<Eigen::Matrix<"
<< args[0].get_element_type().c_type_string() << ", 1, "
<< args[0].get_size()
<< ", Eigen::RowMajor>, Eigen::Aligned64, Eigen::Stride<"
<< args[0].get_size() << ", 1>> arg0(" << args[0].get_name()
<< ");\n";
writer << "out = arg0.replicate<" << out[0].get_shape().at(0)
<< ", 1>();\n";
writer.block_end();
}
else
{
throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} "
"nor "
"{1}");
}
}
else
{
writer << "reference::broadcast<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(arg_shape) << "},\n";
writer << " {" << join(result_shape) << "},\n";
writer << " {" << join(broadcast->get_broadcast_axes())
<< "});\n";
}
#else
kernel::emit_broadcast(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
@@ -1585,7 +1324,6 @@ namespace ngraph
args[0].get_shape(),
out[0].get_shape(),
broadcast->get_broadcast_axes());
#endif
writer.block_end();
}
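All shapes now funnel through kernel::emit_broadcast instead of the rank-specialized Eigen code. For intuition, the vector-to-matrix case the removed code handled with replicate<>() / .colwise() amounts to the following; this is a hand-written sketch of the semantics, not the generated code:

#include <cstddef>

// Broadcasting a shape-{2} vector to shape-{3,2} along axis 0: every
// row of the output is a copy of the input vector.
void broadcast_axis0(const float in[2], float out[3][2])
{
    for (size_t r = 0; r < 3; r++)
    {
        for (size_t c = 0; c < 2; c++)
        {
            out[r][c] = in[c];
        }
    }
}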
@@ -1595,18 +1333,12 @@ namespace ngraph
auto& result_element_type = out[0].get_element_type();
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << "\n"
<< " .template cast<" << result_element_type.c_type_string() << ">();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = (" << result_element_type.c_type_string()
<< ")(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -1660,68 +1392,6 @@ namespace ngraph
}
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape();
auto& result_element_type = out[0].get_element_type();
auto input_order = reshape->get_input_order();
bool same_layout = is_sorted(input_order.begin(), input_order.end());
size_t result_shape_product = 1;
for (auto i : result_shape)
{
result_shape_product *= i;
}
// If there is no layout change or we are just going from 1^n to 1^m or a zero-size tensor,
// we can just copy.
if (same_layout || result_shape_product < 2)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// If there *is* a layout change in the 2D case, we transpose the input.
else if (arg_rank == 2)
{
// Emit an MKL transpose call if possible
if (result_element_type == ngraph::element::f32)
{
writer.block_begin();
writer << "mkl::MKL_Somatcopy('R', 'T', " << to_string(arg_shape[0])
<< ",\n"
<< " " << to_string(arg_shape[1]) << ", 1.0f,\n"
<< " " << args[0].get_name() << ", "
<< to_string(arg_shape[1]) << ",\n"
<< " " << out[0].get_name() << ", "
<< to_string(arg_shape[0]) << ");\n";
writer.block_end();
}
else
{
writer.block_begin();
writer << emit_matrix(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".transpose();\n";
writer.block_end();
}
}
// Other cases
else
{
writer << "reference::reshape<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(reshape->get_input_order()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "}\n";
writer << " );\n";
}
#else
if (args[0].get_element_type() == element::f32 && args[0].get_shape().size() == 3 &&
out[0].get_shape().size() == 3)
{
@@ -1753,7 +1423,6 @@ namespace ngraph
reshape->get_input_order());
}
#endif
writer.block_end();
}
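The removed branch special-cased rank-2 layout changes because a reshape with input order {1,0} on row-major data is exactly a matrix transpose (hence the mkl::MKL_Somatcopy call for f32). A hand-rolled sketch of that equivalence, with illustrative names:

#include <cstddef>

// Reshaping {rows, cols} -> {cols, rows} with input order {1, 0} moves
// element (r, c) to (c, r), i.e. transposes the matrix.
void transpose2d(const float* in, float* out, size_t rows, size_t cols)
{
    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            out[c * rows + r] = in[r * cols + c];
        }
    }
}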
@@ -1813,156 +1482,6 @@ namespace ngraph
auto& f_result_element_type = out[0].get_element_type();
auto result_shape = out[0].get_shape();
#if USE_EIGEN_CORE_INLINE == 1
auto& reduction_axes = reduce->get_reduction_axes();
// Trivial case: no reduction axes (this includes the scalar-reductee case).
if (reduction_axes.empty())
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// Behavior for zero-size axes bears some explanation here. XLA's reduce
// operator provides a "base" element (usually, but not necessarily,
// an identity element) that it apparently *may* choose to insert anywhere
// in the reduction any number of times. For example, given:
//
// reduce({1,2,3},b,+)
//
// any of the following are valid reductions (I think!):
//
// b+(b+1+2)+3
// b+(1+(2+3))
// (1+2)+3 (I think!)
//
// etc. Here we will choose never to instantiate the base element, which
// works well with Eigen's default behavior for non-zero-length axes. The
// exceptional case is when we reduce on a zero-length axis. In this case,
// Eigen's default behavior is to put a zero in the output, which is not
// what we want, so we detect that case here and override with a copy
// instruction (for reduce-to-scalar) or a broadcast (for reduce-to-vector)
// from the base element.
//
// What I'm actually not sure about is whether the identity element is
// required to appear at least once. If so, this will need to be reworked,
// assuming we actually want to mimic XLA's semantics that closely, which
// we may not.
else if ((reductee_shape.size() == 1 && reduction_axes == AxisSet{0}) ||
(reductee_shape.size() == 2 && reduction_axes == AxisSet{0, 1}))
{
if (reductee_shape.at(0) == 0 ||
(reductee_shape.size() == 2 && reductee_shape.at(1) == 0))
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name()
<< ", " << out[0].get_size() * out[0].get_element_type().size()
<< ");\n";
writer.block_end();
}
else
{
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< " {\n";
writer.indent++;
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
writer << reduction_function->get_name() << "(args, out, ctx);\n";
writer << "return result;\n";
writer.indent--;
writer << "};\n";
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".redux(f);\n";
writer.block_end();
}
}
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1})
{
if (reductee_shape.at(1) == 0)
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[1]) << "(0, 0);\n";
writer.block_end();
}
else
{
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< " {\n";
writer.indent++;
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
writer << reduction_function->get_name() << "(args, out, ctx);\n";
writer << "return result;\n";
writer.indent--;
writer << "};\n";
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().redux(f);\n";
writer.block_end();
}
}
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0})
{
if (reductee_shape.at(0) == 0)
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[1]) << "(0, 0);\n";
writer.block_end();
}
else
{
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< " {\n";
writer.indent++;
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
writer << reduction_function->get_name() << "(args, out, ctx);\n";
writer << "return result;\n";
writer.indent--;
writer << "};\n";
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().redux(f);\n";
writer.block_end();
}
}
else
{
writer.block_begin();
string type = f_result_element_type.c_type_string();
writer << "auto f = [&](" << type << " x, " << type << " y) -> " << type
<< " {\n";
writer.indent++;
writer << type << " result;\n";
writer << "void* args[] = {&x, &y};\n";
writer << "void* out[] = {&result};\n";
writer << reduction_function->get_name() << "(args, out, ctx);\n";
writer << "return result;\n";
writer.indent--;
writer << "};\n";
writer << "reference::reduce<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(reduce->get_reduction_axes()) << "},\n";
writer << " f);\n";
writer.block_end();
}
#else
writer.block_begin();
string type = f_result_element_type.c_type_string();
@@ -1987,24 +1506,18 @@ namespace ngraph
reduce->get_reduction_axes());
writer.block_end();
#endif
}
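Both the removed and retained paths wrap the compiled reduction function in a lambda with the signature the reduction kernel expects: two scalars in, one scalar out, marshalled through untyped pointer arrays. A self-contained sketch of that calling convention; the sum body and all names here are stand-ins:

// Stand-in for a compiled reduction function using the emitter's
// convention: void* arrays for inputs/outputs plus a context pointer.
static void reduction_function(void** args, void** out, void* /*ctx*/)
{
    float x = *static_cast<float*>(args[0]);
    float y = *static_cast<float*>(args[1]);
    *static_cast<float*>(out[0]) = x + y; // e.g. a sum reduction
}

// Shape of the wrapper the emitter generates around it.
static float reduce_pair(float x, float y, void* ctx)
{
    float result;
    void* args[] = {&x, &y};
    void* out[] = {&result};
    reduction_function(args, out, ctx);
    return result;
}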
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sign)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sign();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = (0 < " << args[0].get_name() << "[i]) - ("
<< args[0].get_name() << "[i] < 0);\n";
writer.block_end();
#endif
writer.block_end();
}
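The emitted expression (0 < x) - (x < 0) is the usual branchless sign: for x = -3 it evaluates to 0 - 1 = -1, for x = 5 to 1 - 0 = 1, and for x = 0 to 0. Unlike an Eigen .sign() call, it is also well defined for unsigned types, where the second comparison is always false. Sketch with an illustrative wrapper:

#include <cstddef>

void emitted_sign(int* out0, const int* arg0, size_t n)
{
#pragma omp parallel for
    for (size_t i = 0; i < n; i++)
    {
        out0[i] = (0 < arg0[i]) - (arg0[i] < 0);
    }
}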
@@ -2065,65 +1578,6 @@ namespace ngraph
}
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
size_t arg_rank = args[0].get_shape().size();
const Coordinate& lower_bounds = slice->get_lower_bounds();
const Coordinate& upper_bounds = slice->get_upper_bounds();
bool strided = false;
for (size_t stride : slice->get_strides())
{
if (stride != 1)
{
strided = true;
break;
}
}
// Scalar slice is necessarily just a copy.
if (!strided && arg_rank == 0)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
else if (!strided && arg_rank == 1)
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_vector(args[0]) << ".segment(\n"
<< " " << to_string(lower_bounds[0]) << ", "
<< to_string(upper_bounds[0] - lower_bounds[0]) << ");\n";
writer.block_end();
}
else if (!strided && arg_rank == 2)
{
writer.block_begin();
writer << emit_matrix(out[0]) << " = \n"
<< " " << emit_matrix(args[0]) << ".block("
<< to_string(lower_bounds[0]) << ", " << to_string(lower_bounds[1])
<< ",\n"
<< " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n"
<< " " << to_string(upper_bounds[1] - lower_bounds[1]) << ");\n";
writer.block_end();
}
// Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else
{
writer << "reference::slice<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(slice->get_lower_bounds())
<< "},\n";
writer << " {" << join(slice->get_upper_bounds())
<< "},\n";
writer << " {" << join(slice->get_strides()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "});\n";
}
#else
kernel::emit_slice(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
@@ -2133,7 +1587,6 @@ namespace ngraph
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides());
#endif
writer.block_end();
}
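kernel::emit_slice now covers every case the Eigen .segment()/.block() branches used to handle. The unstrided 1-D case reduces to copying upper - lower contiguous elements starting at lower; a sketch of those semantics (not the generated code):

#include <cstddef>

void slice1d(const float* in, float* out, size_t lower, size_t upper)
{
    for (size_t i = lower; i < upper; i++)
    {
        out[i - lower] = in[i];
    }
}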
@@ -2142,53 +1595,6 @@ namespace ngraph
{
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
const AxisSet& reduction_axes = sum->get_reduction_axes();
// Trivial case: no reduction axes.
if (reduction_axes.size() == 0)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// Full reduction? Then sum to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sum();\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().sum();\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().sum();\n";
writer.block_end();
}
else
{
writer << "reference::sum<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(sum->get_reduction_axes())
<< "});\n";
}
#else
if (args[0].get_element_type() == element::f32 && args[0].get_shape().size() == 1 &&
sum->get_reduction_axes().size() == 1)
{
@@ -2246,7 +1652,6 @@ namespace ngraph
out[0].get_shape(),
sum->get_reduction_axes());
}
#endif
writer.block_end();
}
@@ -2254,16 +1659,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Exp)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".exp();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = exp(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2289,16 +1689,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sin)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sin();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = sin(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2306,16 +1701,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sinh)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".sinh();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = sinh(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2323,16 +1713,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Cos)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".cos();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = cos(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2340,16 +1725,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Cosh)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".cosh();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = cosh(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2357,16 +1737,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Tan)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".tan();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = tan(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2378,9 +1753,7 @@ namespace ngraph
// TODO: Implement our own internal fast/approximate tanh if this actually gets used
// by models
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 0
writer << "#pragma omp parallel for\n";
#endif
writer << "for (size_t i=0; i<" << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = tanh(" << args[0].get_name() << "[i]);\n";
@@ -2392,16 +1765,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Asin)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".asin();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = asin(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2409,16 +1777,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Acos)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".acos();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = acos(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2426,16 +1789,11 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Atan)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".atan();\n";
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = atan(" << args[0].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2504,20 +1862,12 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::Power)
{
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
writer << emit_array1d(out[0]) << " = \n";
writer.indent++;
writer << emit_array1d(args[0]) << ".pow(\n ";
writer << emit_array1d(args[1]) << ");\n";
writer.indent--;
#else
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = pow(" << args[0].get_name() << "[i], "
<< args[1].get_name() << "[i]);\n";
writer.block_end();
#endif
writer.block_end();
}
@@ -2569,71 +1919,6 @@ namespace ngraph
{
auto replace_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
size_t arg0_rank = args[0].get_shape().size();
auto& lower_bounds = replace_slice->get_lower_bounds();
auto& upper_bounds = replace_slice->get_upper_bounds();
bool strided = false;
for (size_t stride : replace_slice->get_strides())
{
if (stride != 1)
{
strided = true;
break;
}
}
// Scalar slice is necessarily just a copy.
if (!strided && arg0_rank == 0)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[1].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
else if (!strided && arg0_rank == 1)
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_vector(args[0]) << ";\n"
<< emit_vector(out[0]) << ".segment(\n"
<< " " << to_string(lower_bounds[0]) << ", "
<< to_string(upper_bounds[0] - lower_bounds[0]) << ") =\n"
<< " " << emit_vector(args[1]) << ";\n";
writer.block_end();
}
else if (!strided && arg0_rank == 2)
{
writer.block_begin();
writer << emit_matrix(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ";\n"
<< emit_matrix(out[0]) << ".block(\n"
<< " " << to_string(lower_bounds[0]) << ",\n"
<< " " << to_string(lower_bounds[1]) << ",\n"
<< " " << to_string(upper_bounds[0] - lower_bounds[0]) << ",\n"
<< " " << to_string(upper_bounds[1] - lower_bounds[1]) << ") =\n"
<< " " << emit_matrix(args[1]) << ";\n";
writer.block_end();
}
// Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else
{
writer << "reference::replace_slice<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[1].get_shape()) << "},\n";
writer << " {"
<< join(replace_slice->get_lower_bounds()) << "},\n";
writer << " {"
<< join(replace_slice->get_upper_bounds()) << "},\n";
writer << " {" << join(replace_slice->get_strides())
<< "},\n";
writer << " {" << join(out[0].get_shape()) << "});\n";
}
#else
if (args[0].get_name() != out[0].get_name())
{
kernel::emit_replace_slice(writer,
@@ -2659,7 +1944,6 @@ namespace ngraph
replace_slice->get_upper_bounds(),
replace_slice->get_strides());
}
#endif
writer.block_end();
}
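Note the new aliasing check on buffer names: kernel::emit_replace_slice is only invoked when the output buffer is distinct from the input. Semantically, ReplaceSlice is a whole-tensor copy followed by overwriting the slice region, as in this hand-rolled 1-D unstrided sketch (names illustrative):

#include <cstddef>

void replace_slice1d(const float* arg0, const float* arg1, float* out,
                     size_t n, size_t lower, size_t upper)
{
    for (size_t i = 0; i < n; i++)
    {
        out[i] = arg0[i]; // copy the base tensor
    }
    for (size_t i = lower; i < upper; i++)
    {
        out[i] = arg1[i - lower]; // overwrite the slice region
    }
}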
@@ -2752,9 +2036,7 @@ namespace ngraph
{
writer.block_begin();
size_t element_count = out[0].get_size();
#if USE_EIGEN_CORE_INLINE == 0
writer << "#pragma omp parallel for\n";
#endif
writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = ceil(" << args[0].get_name() << "[i]);\n";
@@ -2767,9 +2049,7 @@ namespace ngraph
{
writer.block_begin();
size_t element_count = out[0].get_size();
#if USE_EIGEN_CORE_INLINE == 0
writer << "#pragma omp parallel for\n";
#endif
writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = floor(" << args[0].get_name() << "[i]);\n";
@@ -2782,9 +2062,7 @@ namespace ngraph
{
writer.block_begin();
size_t element_count = out[0].get_size();
#if USE_EIGEN_CORE_INLINE == 0
writer << "#pragma omp parallel for\n";
#endif
writer << "for (size_t i = 0; i < " << element_count << "; i++)\n";
writer.block_begin();
writer << out[0].get_name() << "[i] = sqrt(" << args[0].get_name() << "[i]);\n";
@@ -3914,53 +3192,6 @@ namespace ngraph
{
const ngraph::op::Product* product = static_cast<const ngraph::op::Product*>(node);
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
const AxisSet& reduction_axes = product->get_reduction_axes();
// Trivial case: no reduction axes.
if (reduction_axes.size() == 0)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// Full reduction? Then reduce to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".prod();\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().prod();\n";
writer.block_end();
}
else if (arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().prod();\n";
writer.block_end();
}
else
{
writer << "reference::product<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(product->get_reduction_axes())
<< "});\n";
}
#else
// TODO: add an emitter akin to the emit_sum
writer << "reference::product<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
@@ -3969,7 +3200,6 @@ namespace ngraph
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(product->get_reduction_axes())
<< "});\n";
#endif
writer.block_end();
}
@@ -3978,59 +3208,6 @@ namespace ngraph
{
const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node);
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
const AxisSet& reduction_axes = max->get_reduction_axes();
bool zero_sized = false;
for (size_t s : arg_shape)
{
zero_sized |= (s == 0);
}
// Trivial case: no reduction axes.
if (!zero_sized && reduction_axes.size() == 0)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// Full reduction? Then reduce to scalar.
else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1})))
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".maxCoeff();\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().maxCoeff();\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().maxCoeff();\n";
writer.block_end();
}
else
{
writer << "reference::max<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(max->get_reduction_axes())
<< "});\n";
}
#else
if (args[0].get_element_type() == element::f32 && args[0].get_shape().size() == 2 &&
max->get_reduction_axes().size() == 1)
{
@@ -4051,7 +3228,6 @@ namespace ngraph
writer << " {" << join(max->get_reduction_axes())
<< "});\n";
}
#endif
writer.block_end();
}
@@ -4060,59 +3236,6 @@ namespace ngraph
{
const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node);
writer.block_begin();
#if USE_EIGEN_CORE_INLINE == 1
const Shape& arg_shape = args[0].get_shape();
size_t arg_rank = arg_shape.size();
const AxisSet& reduction_axes = min->get_reduction_axes();
bool zero_sized = false;
for (size_t s : arg_shape)
{
zero_sized |= (s == 0);
}
// Trivial case: no reduction axes.
if (!zero_sized && reduction_axes.size() == 0)
{
writer.block_begin();
writer << "memcpy(" << out[0].get_name() << ", " << args[0].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// Full reduction? Then reduce to scalar.
else if (!zero_sized && ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
(arg_rank == 2 && reduction_axes == AxisSet{0, 1})))
{
writer.block_begin();
writer << emit_array1d(out[0]) << " =\n"
<< " " << emit_array1d(args[0]) << ".minCoeff();\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{1})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".rowwise().minCoeff();\n";
writer.block_end();
}
else if (!zero_sized && arg_rank == 2 && reduction_axes == AxisSet{0})
{
writer.block_begin();
writer << emit_vector(out[0]) << " =\n"
<< " " << emit_matrix(args[0]) << ".colwise().minCoeff();\n";
writer.block_end();
}
else
{
writer << "reference::min<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(min->get_reduction_axes())
<< "});\n";
}
#else
// TODO: add an emitter akin to the emit_sum
writer << "reference::min<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
@@ -4121,7 +3244,6 @@ namespace ngraph
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(min->get_reduction_axes())
<< "});\n";
#endif
writer.block_end();
}