Commit 870b9b0d authored by Jayaram Bobba, committed by Scott Cyphers

Generic layout methods for elementwise kernels. (#2126)

* Pass layouts through where feasible for unary and binary elementwise ops

* compilation fix for leaky relu
parent 88c9a3e7
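
The gist of the change: ops derived from the elementwise-arithmetic base classes no longer need per-op LAYOUT_DECL specializations or s_dispatcher entries, because the CPU layout pass now picks them up through generic unary/binary handlers. A minimal sketch of what that means for a new op, mirroring the BoundedRelu/LeakyRelu constructor changes below (NewUnaryOp is a hypothetical name, not part of this commit):

// Hypothetical unary op: deriving from UnaryElementwiseArithmetic is enough for
// the generic set_layouts_unaryeltwise path below to propagate input layouts;
// no CPULayout::LAYOUT_DECL specialization or s_dispatcher entry is needed.
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"

namespace ngraph
{
    namespace op
    {
        class NewUnaryOp : public util::UnaryElementwiseArithmetic
        {
        public:
            NewUnaryOp(std::shared_ptr<Node> arg)
                : UnaryElementwiseArithmetic("NewUnaryOp", {arg})
            {
                constructor_validate_and_infer_types();
            }

            std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override
            {
                return std::make_shared<NewUnaryOp>(new_args.at(0));
            }
        };
    }
}
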
@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph;
op::BoundedRelu::BoundedRelu(shared_ptr<Node> arg, float alpha)
: Op("BoundedRelu", check_single_output_args({arg}))
: UnaryElementwiseArithmetic("BoundedRelu", {arg})
, m_alpha(alpha)
{
constructor_validate_and_infer_types();
......
@@ -18,7 +18,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
@@ -26,7 +26,7 @@ namespace ngraph
{
/// \brief Elementwise Minimum(Relu(arg, 0), alpha) operation.
///
class BoundedRelu : public Op
class BoundedRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
/// \brief Constructs a BoundedRelu operation.
......
@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph;
op::LeakyRelu::LeakyRelu(shared_ptr<Node> arg, float alpha)
: Op("LeakyRelu", check_single_output_args({arg}))
: UnaryElementwiseArithmetic("LeakyRelu", {arg})
, m_alpha(alpha)
{
constructor_validate_and_infer_types();
......
@@ -18,7 +18,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
@@ -27,7 +27,7 @@ namespace ngraph
/// \brief Elementwise Maximum(arg, arg * alpha) operation
/// alpha > 0
///
class LeakyRelu : public Op
class LeakyRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
/// \brief Constructs a LeakyRelu operation.
......
@@ -72,10 +72,10 @@ using namespace ngraph::runtime::cpu;
// Check if the input layout matches the layout requested in `required_mds`
// If not, insert a layout conversion node between the input tensor and
// the `node`. For now, only MKLDNN nodes/kernels can request specific layouts
shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
runtime::cpu::CPU_ExternalFunction* external_function,
shared_ptr<Node>& node,
const vector<memory::desc>& required_mds)
static shared_ptr<Node>
insert_input_conversions(runtime::cpu::CPU_ExternalFunction* external_function,
shared_ptr<Node>& node,
const vector<memory::desc>& required_mds)
{
vector<shared_ptr<Node>> new_args;
bool replace_node = false;
@@ -153,8 +153,7 @@ shared_ptr<Node> runtime::cpu::pass::CPULayout::insert_input_conversions(
return node;
}
void runtime::cpu::pass::CPULayout::set_output_layouts(shared_ptr<Node>& node,
const vector<memory::desc>& output_mds)
static void set_output_layouts(shared_ptr<Node>& node, const vector<memory::desc>& output_mds)
{
for (size_t i = 0; i < node->get_output_size(); ++i)
{
@@ -174,10 +173,9 @@ void runtime::cpu::pass::CPULayout::set_output_layouts(shared_ptr<Node>& node,
}
}
void runtime::cpu::pass::CPULayout::set_native_layouts(
runtime::cpu::CPU_ExternalFunction* external_function,
std::shared_ptr<Node> node,
bool use_replace = true)
static void set_native_layouts(runtime::cpu::CPU_ExternalFunction* external_function,
std::shared_ptr<Node> node,
bool use_replace = true)
{
std::vector<shared_ptr<Node>> new_args;
bool replace_node = false;
@@ -268,6 +266,59 @@ void runtime::cpu::pass::CPULayout::set_native_layouts(
}
}
static void set_layouts_unaryeltwise(ngraph::runtime::cpu::CPU_ExternalFunction* external_function,
std::shared_ptr<ngraph::Node> node)
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
// Non-MKLDNN kernels can handle MKLDNN layouts as long as they are not padded
bool md_check = input_md.data.format != mkldnn_format_undef &&
!mkldnn_utils::is_mkldnn_padded_layout(
input_md, ngraph::get_default_order(node->get_input_shape(0)));
if (mkldnn_utils::use_mkldnn_kernel(node.get()) || md_check)
{
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
void set_layouts_binaryeltwise(ngraph::runtime::cpu::CPU_ExternalFunction* external_function,
std::shared_ptr<ngraph::Node> node)
{
std::vector<mkldnn::memory::desc> arg_mds{mkldnn_utils::get_input_mkldnn_md(node.get(), 0),
mkldnn_utils::get_input_mkldnn_md(node.get(), 1)};
bool md_check = arg_mds[0].data.format != mkldnn_format_undef &&
arg_mds[1].data.format != mkldnn_format_undef &&
!mkldnn_utils::is_mkldnn_padded_layout(
arg_mds[0], ngraph::get_default_order(node->get_input_shape(0))) &&
!mkldnn_utils::is_mkldnn_padded_layout(
arg_mds[1], ngraph::get_default_order(node->get_input_shape(1)));
if (mkldnn_utils::use_mkldnn_kernel(node.get()) || md_check)
{
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
int select = 0;
if (std::getenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE") != nullptr)
{
const int user_select = std::atoi(std::getenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE"));
select = (user_select == 0 || user_select == 1) ? user_select : select;
}
i_mds.push_back(arg_mds[select]);
i_mds.push_back(arg_mds[select]);
o_mds.push_back(arg_mds[select]);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
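
The binary handler above also lets a user pick which argument's layout is propagated via the NGRAPH_PASS_CPU_LAYOUT_ELTWISE environment variable; any value other than 0 or 1 falls back to argument 0, and std::atoi maps non-numeric strings to 0 as well. A standalone sketch of just that selection rule, isolated from the pass for illustration (not part of the commit):

// Illustration only: the same selection logic as in set_layouts_binaryeltwise,
// where the chosen argument's memory descriptor is then pushed for both inputs
// and for the output.
#include <cstdlib>
#include <iostream>

static int selected_eltwise_arg()
{
    int select = 0; // default: propagate argument 0's layout
    if (const char* env = std::getenv("NGRAPH_PASS_CPU_LAYOUT_ELTWISE"))
    {
        const int user_select = std::atoi(env);
        select = (user_select == 0 || user_select == 1) ? user_select : select;
    }
    return select;
}

int main()
{
    std::cout << "propagating the layout of argument " << selected_eltwise_arg() << std::endl;
    return 0;
}
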
namespace ngraph
{
namespace runtime
@@ -1519,33 +1570,6 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Relu)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
else
{
if (mkldnn_utils::get_input_mkldnn_md(node.get(), 0).data.format ==
mkldnn_format_undef)
{
set_native_layouts(external_function, node);
}
else
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::LRN)
{
@@ -1562,22 +1586,6 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Sigmoid)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::SigmoidBackprop)
{
@@ -1784,27 +1792,6 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input0_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> i_mds;
vector<memory::desc> o_mds;
i_mds.push_back(input0_md);
i_mds.push_back(input0_md);
o_mds.push_back(input0_md);
node = insert_input_conversions(external_function, node, i_mds);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Slice)
{
@@ -1934,38 +1921,8 @@ namespace ngraph
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Softmax)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BoundedRelu)
{
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
vector<memory::desc> o_mds;
o_mds.push_back(input_md);
set_output_layouts(node, o_mds);
}
else
{
set_native_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::LeakyRelu)
{
// Softmax cannot use the default unary layout method since the kernels
// need to know the reduction dimension
if (mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
@@ -1986,7 +1943,6 @@ namespace ngraph
#define TI(x) type_index(typeid(x))
static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Add>},
{TI(ngraph::op::Concat), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Concat>},
{TI(ngraph::op::AvgPool), &runtime::cpu::pass::CPULayout::layout<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop),
@@ -2032,19 +1988,15 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::LRN), &runtime::cpu::pass::CPULayout::layout<ngraph::op::LRN>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Reshape), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Reshape>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
{TI(ngraph::op::BoundedRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BoundedRelu>},
{TI(ngraph::op::LeakyRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::LeakyRelu>},
{TI(ngraph::op::ConvolutionAdd),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionAdd>},
{TI(ngraph::op::Slice), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Slice>},
@@ -2070,6 +2022,16 @@ bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::share
{
handler->second(m_external_function, node);
}
else if (dynamic_pointer_cast<ngraph::op::util::UnaryElementwiseArithmetic>(node) !=
nullptr)
{
set_layouts_unaryeltwise(m_external_function, node);
}
else if (dynamic_pointer_cast<ngraph::op::util::BinaryElementwiseArithmetic>(node) !=
nullptr)
{
set_layouts_binaryeltwise(m_external_function, node);
}
else
{
set_native_layouts(m_external_function, node);
......
@@ -53,16 +53,6 @@ namespace ngraph
private:
CPU_ExternalFunction* m_external_function;
static std::shared_ptr<Node> insert_input_conversions(
CPU_ExternalFunction* external_function,
std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::desc>& required_mds);
static void
set_output_layouts(std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::desc>& output_mds);
static void set_native_layouts(CPU_ExternalFunction* external_function,
std::shared_ptr<Node> node,
bool use_replace);
};
}
}
......