Unverified commit ce8fef72, authored by Matthew Brookhart, committed by GitHub

expand when mkldnn relu can be used, add faster default kernels (#592)

parent ca06e6c3
......@@ -3151,11 +3151,12 @@ namespace ngraph
}
else
{
writer << "kernel::relu_backprop<" << out[0].get_type() << ">("
<< args[0].get_name() << ",\n";
writer << " " << args[1].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[0].get_size() << ");\n";
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] > 0 ? " << args[1].get_name() << "[i] : 0;\n";
writer << "}\n";
}
}
......@@ -3183,10 +3184,12 @@ namespace ngraph
}
else
{
writer << "kernel::relu<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " " << out[0].get_size() << ");\n";
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] > 0 ? " << args[0].get_name() << "[i] : 0;\n";
writer << "}\n";
}
}
......
......@@ -198,7 +198,8 @@ namespace ngraph
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if (arg0_rank == 4 && node->get_input_element_type(0) == element::f32)
if ((arg0_rank == 4 || arg0_rank == 2) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment