Commit 870b9b0d authored by Jayaram Bobba, committed by Scott Cyphers

generic layout methods for elementwise kernels. (#2126)

* Pass layouts through where feasible for unary and binary elementwise ops

* compilation fix for leaky relu
parent 88c9a3e7
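In short: the per-op CPU layout specializations for simple elementwise ops (Relu, Sigmoid, Add, BoundedRelu, LeakyRelu) are deleted, BoundedRelu and LeakyRelu are rebased onto UnaryElementwiseArithmetic, and the layout pass gains two generic handlers that it falls back to based on the op's base class. A condensed (not verbatim) sketch of the new dispatch, using only names that appear in the diff below:

    // If an op has a registered handler, its specialized layout logic still runs.
    // Otherwise, any unary/binary elementwise arithmetic op now gets the generic
    // layout treatment; everything else keeps the native (default) layouts.
    if (handler != s_dispatcher.end())
    {
        handler->second(m_external_function, node);
    }
    else if (dynamic_pointer_cast<ngraph::op::util::UnaryElementwiseArithmetic>(node))
    {
        set_layouts_unaryeltwise(m_external_function, node); // pass the input layout through
    }
    else if (dynamic_pointer_cast<ngraph::op::util::BinaryElementwiseArithmetic>(node))
    {
        set_layouts_binaryeltwise(m_external_function, node); // pick one argument's layout
    }
    else
    {
        set_native_layouts(m_external_function, node);
    }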
@@ -20,7 +20,7 @@ using namespace std;
 using namespace ngraph;
 
 op::BoundedRelu::BoundedRelu(shared_ptr<Node> arg, float alpha)
-    : Op("BoundedRelu", check_single_output_args({arg}))
+    : UnaryElementwiseArithmetic("BoundedRelu", {arg})
     , m_alpha(alpha)
 {
     constructor_validate_and_infer_types();
...
@@ -18,7 +18,7 @@
 #include "ngraph/node.hpp"
 #include "ngraph/op/op.hpp"
-#include "ngraph/op/op.hpp"
+#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
 
 namespace ngraph
 {
@@ -26,7 +26,7 @@ namespace ngraph
 {
     /// \brief Elementwise Minimum(Relu(arg, 0), alpha) operation.
     ///
-    class BoundedRelu : public Op
+    class BoundedRelu : public ngraph::op::util::UnaryElementwiseArithmetic
     {
     public:
         /// \brief Constructs a BoundedRelu operation.
...
@@ -20,7 +20,7 @@ using namespace std;
 using namespace ngraph;
 
 op::LeakyRelu::LeakyRelu(shared_ptr<Node> arg, float alpha)
-    : Op("LeakyRelu", check_single_output_args({arg}))
+    : UnaryElementwiseArithmetic("LeakyRelu", {arg})
     , m_alpha(alpha)
 {
     constructor_validate_and_infer_types();
...
@@ -18,7 +18,7 @@
 #include "ngraph/node.hpp"
 #include "ngraph/op/op.hpp"
-#include "ngraph/op/op.hpp"
+#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
 
 namespace ngraph
 {
@@ -27,7 +27,7 @@ namespace ngraph
     /// \brief Elementwise Maximum(arg, arg * alpha) operation
     ///        alpha > 0
     ///
-    class LeakyRelu : public Op
+    class LeakyRelu : public ngraph::op::util::UnaryElementwiseArithmetic
     {
     public:
         /// \brief Constructs a LeakyRelu operation.
...
@@ -72,10 +72,10 @@ using namespace ngraph::runtime::cpu;
 
 // Check if the input layout matches the layout requested in `required_mds`
 // If not, insert a layout conversion node between the input tensor and
 // the `node`. For now, only MKLDNN nodes/kernels can request specific layouts
-shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
-    runtime::cpu::CPU_ExternalFunction* external_function,
-    shared_ptr<Node>& node,
-    const vector<memory::desc>& required_mds)
+static shared_ptr<Node>
+    insert_input_conversions(runtime::cpu::CPU_ExternalFunction* external_function,
+                             shared_ptr<Node>& node,
+                             const vector<memory::desc>& required_mds)
 {
     vector<shared_ptr<Node>> new_args;
     bool replace_node = false;
@@ -153,8 +153,7 @@ shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
     return node;
 }
 
-void runtime::cpu::pass::CPULayout::set_output_layouts(shared_ptr<Node>& node,
-                                                       const vector<memory::desc>& output_mds)
+static void set_output_layouts(shared_ptr<Node>& node, const vector<memory::desc>& output_mds)
 {
     for (size_t i = 0; i < node->get_output_size(); ++i)
     {
@@ -174,10 +173,9 @@ void runtime::cpu::pass::CPULayout::set_output_layouts(shared_ptr<Node>& node,
     }
 }
 
-void runtime::cpu::pass::CPULayout::set_native_layouts(
-    runtime::cpu::CPU_ExternalFunction* external_function,
-    std::shared_ptr<Node> node,
-    bool use_replace = true)
+static void set_native_layouts(runtime::cpu::CPU_ExternalFunction* external_function,
+                               std::shared_ptr<Node> node,
+                               bool use_replace = true)
 {
     std::vector<shared_ptr<Node>> new_args;
     bool replace_node = false;
@@ -268,6 +266,59 @@ void runtime::cpu::pass::CPULayout::set_native_layouts(
     }
 }
+
+static void set_layouts_unaryeltwise(ngraph::runtime::cpu::CPU_ExternalFunction* external_function,
+                                     std::shared_ptr<ngraph::Node> node)
+{
+    auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
+    // Non-MKLDNN kernels can handle MKLDNN layouts as long as they are not padded
+    bool md_check = input_md.data.format != mkldnn_format_undef &&
+                    !mkldnn_utils::is_mkldnn_padded_layout(
+                        input_md, ngraph::get_default_order(node->get_input_shape(0)));
+    if (mkldnn_utils::use_mkldnn_kernel(node.get()) || md_check)
+    {
+        vector<memory::desc> o_mds;
+        o_mds.push_back(input_md);
+        set_output_layouts(node, o_mds);
+    }
+    else
+    {
+        set_native_layouts(external_function, node);
+    }
+}
+
+void set_layouts_binaryeltwise(ngraph::runtime::cpu::CPU_ExternalFunction* external_function,
+                               std::shared_ptr<ngraph::Node> node)
+{
+    std::vector<mkldnn::memory::desc> arg_mds{mkldnn_utils::get_input_mkldnn_md(node.get(), 0),
+                                              mkldnn_utils::get_input_mkldnn_md(node.get(), 1)};
+    bool md_check = arg_mds[0].data.format != mkldnn_format_undef &&
+                    arg_mds[1].data.format != mkldnn_format_undef &&
+                    !mkldnn_utils::is_mkldnn_padded_layout(
+                        arg_mds[0], ngraph::get_default_order(node->get_input_shape(0))) &&
+                    !mkldnn_utils::is_mkldnn_padded_layout(
+                        arg_mds[1], ngraph::get_default_order(node->get_input_shape(1)));
+    if (mkldnn_utils::use_mkldnn_kernel(node.get()) || md_check)
+    {
+        vector<memory::desc> i_mds;
+        vector<memory::desc> o_mds;
+        int select = 0;
+        if (std::getenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE") != nullptr)
+        {
+            const int user_select = std::atoi(std::getenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE"));
+            select = (user_select == 0 || user_select == 1) ? user_select : select;
+        }
+        i_mds.push_back(arg_mds[select]);
+        i_mds.push_back(arg_mds[select]);
+        o_mds.push_back(arg_mds[select]);
+        node = insert_input_conversions(external_function, node, i_mds);
+        set_output_layouts(node, o_mds);
+    }
+    else
+    {
+        set_native_layouts(external_function, node);
+    }
+}
 
 namespace ngraph
 {
     namespace runtime
@@ -1519,33 +1570,6 @@ namespace ngraph
                 }
             }
 
-            template <>
-            void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
-            {
-                if (mkldnn_utils::use_mkldnn_kernel(node.get()))
-                {
-                    auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                    vector<memory::desc> o_mds;
-                    o_mds.push_back(input_md);
-                    set_output_layouts(node, o_mds);
-                }
-                else
-                {
-                    if (mkldnn_utils::get_input_mkldnn_md(node.get(), 0).data.format ==
-                        mkldnn_format_undef)
-                    {
-                        set_native_layouts(external_function, node);
-                    }
-                    else
-                    {
-                        auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                        vector<memory::desc> o_mds;
-                        o_mds.push_back(input_md);
-                        set_output_layouts(node, o_mds);
-                    }
-                }
-            }
-
             template <>
             void CPULayout::LAYOUT_DECL(ngraph::op::LRN)
             {
@@ -1562,22 +1586,6 @@ namespace ngraph
                 }
             }
 
-            template <>
-            void CPULayout::LAYOUT_DECL(ngraph::op::Sigmoid)
-            {
-                if (mkldnn_utils::use_mkldnn_kernel(node.get()))
-                {
-                    auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                    vector<memory::desc> o_mds;
-                    o_mds.push_back(input_md);
-                    set_output_layouts(node, o_mds);
-                }
-                else
-                {
-                    set_native_layouts(external_function, node);
-                }
-            }
-
             template <>
             void CPULayout::LAYOUT_DECL(ngraph::op::SigmoidBackprop)
             {
@@ -1784,27 +1792,6 @@ namespace ngraph
                 }
             }
 
-            template <>
-            void CPULayout::LAYOUT_DECL(ngraph::op::Add)
-            {
-                if (mkldnn_utils::use_mkldnn_kernel(node.get()))
-                {
-                    auto input0_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                    vector<memory::desc> i_mds;
-                    vector<memory::desc> o_mds;
-                    i_mds.push_back(input0_md);
-                    i_mds.push_back(input0_md);
-                    o_mds.push_back(input0_md);
-                    node = insert_input_conversions(external_function, node, i_mds);
-                    set_output_layouts(node, o_mds);
-                }
-                else
-                {
-                    set_native_layouts(external_function, node);
-                }
-            }
-
             template <>
             void CPULayout::LAYOUT_DECL(ngraph::op::Slice)
             {
@@ -1934,38 +1921,8 @@ namespace ngraph
             template <>
             void CPULayout::LAYOUT_DECL(ngraph::op::Softmax)
             {
-                if (mkldnn_utils::use_mkldnn_kernel(node.get()))
-                {
-                    auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                    vector<memory::desc> o_mds;
-                    o_mds.push_back(input_md);
-                    set_output_layouts(node, o_mds);
-                }
-                else
-                {
-                    set_native_layouts(external_function, node);
-                }
-            }
-
-            template <>
-            void CPULayout::LAYOUT_DECL(ngraph::op::BoundedRelu)
-            {
-                if (mkldnn_utils::use_mkldnn_kernel(node.get()))
-                {
-                    auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                    vector<memory::desc> o_mds;
-                    o_mds.push_back(input_md);
-                    set_output_layouts(node, o_mds);
-                }
-                else
-                {
-                    set_native_layouts(external_function, node);
-                }
-            }
-
-            template <>
-            void CPULayout::LAYOUT_DECL(ngraph::op::LeakyRelu)
-            {
+                // Softmax cannot use the default unary layout method since the kernels
+                // need to know the reduction dimension
                 if (mkldnn_utils::use_mkldnn_kernel(node.get()))
                 {
                     auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
@@ -1986,7 +1943,6 @@ namespace ngraph
 #define TI(x) type_index(typeid(x))
 
 static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
-    {TI(ngraph::op::Add), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Add>},
     {TI(ngraph::op::Concat), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Concat>},
     {TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
     {TI(ngraph::op::AvgPoolBackprop),
@@ -2032,19 +1988,15 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
     {TI(ngraph::op::GetOutputElement),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
     {TI(ngraph::op::LRN), &runtime::cpu::pass::CPULayout::layout<ngraph::op::LRN>},
-    {TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
     {TI(ngraph::op::Reshape), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Reshape>},
     {TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
     {TI(ngraph::op::ReluBackprop),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::ReluBackprop>},
-    {TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>},
     {TI(ngraph::op::SigmoidBackprop),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
     {TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
     {TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
     {TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
-    {TI(ngraph::op::BoundedRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BoundedRelu>},
-    {TI(ngraph::op::LeakyRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::LeakyRelu>},
     {TI(ngraph::op::ConvolutionAdd),
      &runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionAdd>},
     {TI(ngraph::op::Slice), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Slice>},
@@ -2070,6 +2022,16 @@ bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::share
         {
             handler->second(m_external_function, node);
         }
+        else if (dynamic_pointer_cast<ngraph::op::util::UnaryElementwiseArithmetic>(node) !=
+                 nullptr)
+        {
+            set_layouts_unaryeltwise(m_external_function, node);
+        }
+        else if (dynamic_pointer_cast<ngraph::op::util::BinaryElementwiseArithmetic>(node) !=
+                 nullptr)
+        {
+            set_layouts_binaryeltwise(m_external_function, node);
+        }
         else
         {
             set_native_layouts(m_external_function, node);
...
@@ -53,16 +53,6 @@ namespace ngraph
         private:
             CPU_ExternalFunction* m_external_function;
 
-            static std::shared_ptr<Node> insert_input_conversions(
-                CPU_ExternalFunction* external_function,
-                std::shared_ptr<Node>& node,
-                const std::vector<mkldnn::memory::desc>& required_mds);
-            static void
-                set_output_layouts(std::shared_ptr<Node>& node,
-                                   const std::vector<mkldnn::memory::desc>& output_mds);
-            static void set_native_layouts(CPU_ExternalFunction* external_function,
-                                           std::shared_ptr<Node> node,
-                                           bool use_replace);
         };
     }
 }
...
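A behavioural note on the change above: set_layouts_binaryeltwise reads the NGRAPH_PASS_CPU_LAYOUT_ELTWISE environment variable with std::getenv/std::atoi; a value of 0 or 1 selects which argument's MKLDNN layout is propagated to both inputs and the output of a binary elementwise op, and any other value falls back to argument 0. A minimal, hypothetical way to exercise both choices from a test driver (setenv/unsetenv are POSIX; the helper name is illustrative and not part of this commit):

    #include <cstdlib>

    // Hypothetical test-driver helper: run the same graph twice, once propagating
    // the layout of argument 0 and once of argument 1, to compare the layout
    // decisions made by the CPU layout pass.
    void run_with_eltwise_layout_choice(int select) // 0 or 1, as parsed by the pass
    {
        setenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE", select == 1 ? "1" : "0", /*overwrite=*/1);
        // ... compile and run the function on the CPU backend here ...
        unsetenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE");
    }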