Commit 870b9b0d authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

generic layout methods for elementwise kernels. (#2126)

* Pass layouts through where feasible for unary and binary elementwise ops

* compilation fix for leaky relu
parent 88c9a3e7
......@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph;
op::BoundedRelu::BoundedRelu(shared_ptr<Node> arg, float alpha)
: Op("BoundedRelu", check_single_output_args({arg}))
: UnaryElementwiseArithmetic("BoundedRelu", {arg})
, m_alpha(alpha)
{
constructor_validate_and_infer_types();
......
......@@ -18,7 +18,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
......@@ -26,7 +26,7 @@ namespace ngraph
{
/// \brief Elementwise Minimum(Relu(arg, 0), alpha) operation.
///
class BoundedRelu : public Op
class BoundedRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
/// \brief Constructs a BoundedRelu operation.
......
......@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph;
op::LeakyRelu::LeakyRelu(shared_ptr<Node> arg, float alpha)
: Op("LeakyRelu", check_single_output_args({arg}))
: UnaryElementwiseArithmetic("LeakyRelu", {arg})
, m_alpha(alpha)
{
constructor_validate_and_infer_types();
......
......@@ -18,7 +18,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
......@@ -27,7 +27,7 @@ namespace ngraph
/// \brief Elementwise Maximum(arg, arg * alpha) operation
/// alpha > 0
///
class LeakyRelu : public Op
class LeakyRelu : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
/// \brief Constructs a LeakyRelu operation.
......
This diff is collapsed.
......@@ -53,16 +53,6 @@ namespace ngraph
private:
CPU_ExternalFunction* m_external_function;
static std::shared_ptr<Node> insert_input_conversions(
CPU_ExternalFunction* external_function,
std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::desc>& required_mds);
static void
set_output_layouts(std::shared_ptr<Node>& node,
const std::vector<mkldnn::memory::desc>& output_mds);
static void set_native_layouts(CPU_ExternalFunction* external_function,
std::shared_ptr<Node> node,
bool use_replace);
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment