Commit 7775d49d authored by Jayaram Bobba, committed by Scott Cyphers

Adding support for fused ops that are decomposable to core ngraph ops (#2688)

* Initial support for specification of fused ops and type inference

* Added FusedOpDecomposition pass and execution test cases

* Serializer support

* style fix

* Add FusedOpDecomposition to GPU and IGPU backends

* Addressed PR feedback

* Fix comment

* Addressed PR feedback
parent 8c081092
......@@ -254,6 +254,8 @@ set (SRC
op/tanh.hpp
op/topk.cpp
op/topk.hpp
op/fused/prelu.cpp
op/fused/prelu.hpp
op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp
op/util/binary_elementwise_arithmetic.cpp
......@@ -262,6 +264,10 @@ set (SRC
op/util/binary_elementwise_comparison.hpp
op/util/binary_elementwise_logical.cpp
op/util/binary_elementwise_logical.hpp
op/util/broadcasting.cpp
op/util/broadcasting.hpp
op/util/fused_op.cpp
op/util/fused_op.hpp
op/util/index_reduction.cpp
op/util/index_reduction.hpp
op/util/logical_reduction.cpp
......@@ -284,6 +290,8 @@ set (SRC
pass/cse.hpp
pass/dump_sorted.cpp
pass/dump_sorted.hpp
pass/fused_op_decomposition.cpp
pass/fused_op_decomposition.hpp
pass/get_output_element_elimination.cpp
pass/get_output_element_elimination.hpp
pass/graph_rewrite.cpp
......
......@@ -65,18 +65,16 @@ void ngraph::traverse_nodes(const Function* p,
traverse_nodes(nodes, f, include_control_deps);
}
// This version traverses directly from input/output nodes to perform functions on
// graphs that are not wrapped by functions. Most useful for finding parameters of a graph
// directly from the result nodes, not from function parameters.
void ngraph::traverse_nodes(const NodeVector& io_nodes,
void ngraph::traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps,
NodeVector stop_nodes)
const NodeVector& subgraph_params)
{
std::unordered_set<std::shared_ptr<Node>> instances_seen(stop_nodes.begin(), stop_nodes.end());
std::unordered_set<std::shared_ptr<Node>> instances_seen{subgraph_params.begin(),
subgraph_params.end()};
std::deque<std::shared_ptr<Node>> stack;
for (auto r : io_nodes)
for (auto r : subgraph_results)
{
stack.push_front(r);
}
......@@ -484,6 +482,13 @@ NodeVector ngraph::get_subgraph_outputs(const NodeVector& nodes,
return outputs;
}
NodeVector ngraph::extract_subgraph(const NodeVector& results, const NodeVector& args)
{
NodeVector subgraph;
traverse_nodes(results, [&](std::shared_ptr<Node> n) { subgraph.push_back(n); }, true, args);
return subgraph;
}
bool ngraph::is_used(Node* node)
{
std::unordered_set<Node*> instances_seen;
......
......@@ -45,14 +45,29 @@ namespace ngraph
void traverse_nodes(const std::shared_ptr<const Function> p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps = false);
void traverse_nodes(const Function* p,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps);
void traverse_nodes(const NodeVector& io_nodes,
/// \brief Visit each node in a sub-graph of the entire graph
/// \param subgraph_results The output nodes of the sub-graph
/// \param f Function to execute at each node in the traversal
/// \param include_control_deps Whether to include control deps
/// while traversing the sub-graph
/// \param subgraph_params Input nodes of the sub-graph (optional)
///
/// Traverses a sub-graph starting from subgraph_results and moving up
/// towards parameter nodes. Traversal stops if it hits a node in
/// subgraph_params.
///
/// Most useful for finding the parameters of a graph directly from the
/// result nodes rather than from function parameters, or for extracting
/// the sub-graph relevant to the computation of certain outputs.
void traverse_nodes(const NodeVector& subgraph_results,
std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps,
NodeVector stop_nodes = {});
const NodeVector& subgraph_params = {});
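A minimal usage sketch of the new overload (the nodes here are illustrative, not part of this change, and assume the operator overload from ngraph/op/add.hpp):

auto x = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2});
auto y = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2});
auto sum = x + y;
ngraph::NodeVector visited;
// Visits the Add node and y, but stops at (and excludes) x.
ngraph::traverse_nodes(ngraph::NodeVector{sum},
                       [&](std::shared_ptr<ngraph::Node> n) { visited.push_back(n); },
                       false /* include_control_deps */,
                       ngraph::NodeVector{x});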
void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f);
......@@ -125,6 +140,7 @@ namespace ngraph
return result_list;
}
// For cases where `nodes` is a subset of the entire graph
template <typename T>
std::list<std::shared_ptr<Node>> subgraph_topological_sort(const T& nodes,
bool include_control_deps = false)
......@@ -205,7 +221,7 @@ namespace ngraph
template <typename T>
void validate_nodes_and_infer_types(const T& nodes)
{
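// `nodes` may be an arbitrary sub-graph (e.g. a fused op's decomposition)
// rather than a complete function graph, so the sub-graph-aware sort is used.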
for (auto node : topological_sort(nodes))
for (auto node : subgraph_topological_sort(nodes))
{
node->delayed_validate_and_infer_types();
}
......@@ -296,6 +312,10 @@ namespace ngraph
const NodeVector& exclusions,
bool ignore_unused = false);
// Extracts the sub-graph that computes `results`. Backward traversal stops at either a
// Parameter node or a node that belongs to `args`
NodeVector extract_subgraph(const NodeVector& results, const NodeVector& args);
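A hedged usage sketch (nodes are illustrative; assumes the operator overloads from ngraph/op/add.hpp and ngraph/op/multiply.hpp):

// Build mul = (a + b) * a, then extract the nodes needed to compute it
// from {a, b}: the result holds only the interior Add and Multiply nodes,
// since backward traversal stops at (and excludes) a and b.
auto a = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2});
auto b = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2});
auto mul = (a + b) * a;
ngraph::NodeVector interior = ngraph::extract_subgraph({mul}, {a, b});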
bool is_one(std::shared_ptr<Node> reduce_constant);
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
......
......@@ -95,6 +95,7 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::PRelu::PRelu(const shared_ptr<Node>& data, const shared_ptr<Node>& slope)
: FusedOp("PRelu", {data, slope})
{
constructor_validate_and_infer_types();
}
NodeVector op::PRelu::decompose_op() const
{
auto data = get_argument(0);
auto data_shape = data->get_shape();
auto slope = get_argument(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else if (data_shape != slope_shape)
{
slope = numpy_style_broadcast({slope, data})[0];
}
// x < 0 => f(x) = x * slope
// x >= 0 => f(x) = x
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Less>(data, zero_node), data->get_element_type());
std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, zero_node), data->get_element_type());
slope = negative_map * slope + positive_map;
return {data * slope};
}
shared_ptr<Node> op::PRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<PRelu>(new_args.at(0), new_args.at(1));
}
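A worked trace of the decomposition above, with values chosen for illustration:

// For data = {-2, 0, 3} and a shared slope of 0.5 (broadcast to data's shape):
//   negative_map = Convert(Less(data, 0))    = {1, 0, 0}
//   positive_map = Convert(Greater(data, 0)) = {0, 0, 1}
//   slope'       = negative_map * slope + positive_map = {0.5, 0, 1}
//   result       = data * slope' = {-1, 0, 3}
// x == 0 falls into neither map, so slope' is 0 there and f(0) = 0 is preserved.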
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Parameterized ReLU
/// x < 0 => f(x) = x * slope
/// x >= 0 => f(x) = x
///
class PRelu : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs a PRelu operation.
///
/// \param data Input tensor
/// \param slope Multipliers for negative values
PRelu(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& slope);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// This collection contains one entry for each fused op.
//
NGRAPH_OP(PRelu, ngraph::op)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace op
{
/// \brief Cast the shapes of all input nodes for an element-wise operation that requires shape compatibility
///
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(const NodeVector& inputs);
/// \brief Cast the shapes of two nodes to make them compatible for an element-wise binary operation.
///
/// If necessary, the right-hand-side argument will be broadcast to match the shape
/// of the left-hand-side argument. The start of the mutually equal part of the shapes
/// is specified by the argument "start_match_axis"; if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contains the left-hand-side input of the binary op.
/// \param right Node which contains the right-hand-side input of the binary op.
/// \param start_match_axis Position in the shape denoting the start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis);
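For example (shapes chosen for illustration):

// left shape {2, 3, 4, 5}, right shape {3, 4}, start_match_axis = 1:
// the right-hand side matches axes 1..2 of the left and is broadcast to
// {2, 3, 4, 5} along axes 0 and 3; the left-hand side is returned unchanged.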
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
/// \note This function reflects the broadcasting behaviour of NumPy's `matmul` operation
/// \link https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html
/// This means that only the \"stack of matrices\" axes are bidirectionally broadcast.
/// The last two dimensions are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix multiplication.
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector
numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
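For example (shapes chosen for illustration):

// left {3, 1, 4, 5} and right {2, 5, 6} are returned as left {3, 2, 4, 5}
// and right {3, 2, 5, 6}: only the "stack" axes {3, 1} and {2} are
// bidirectionally broadcast; the trailing matrix axes stay untouched.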
/// \brief Generate a list of broadcast axes.
///
/// \details Informally, a broadcast "adds" axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// This function calculates which of the output axes are added in this way.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
/// \param start_match_axis The starting axis position (0-based) in the
/// output shape at which the input shape begins to
/// match the desired new shape.
///
/// \return The indices of added axes.
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis);
/// \brief Generate a list of broadcast axes.
///
/// \details Broadcast "adds" elements along axes to the input tensor, replicating
/// elements from the input tensor as needed to fill the new dimensions.
/// Function calculate which of the output axes are added in this way.
///
/// This function will attempt to match shapes, assuming the current shape
/// matches the rightmost positions of the desired new shape. This behaviour
/// is similar to NumPy's broadcasting.
///
/// \param output_shape The new shape for the output tensor.
/// \param input_shape The shape of input tensor.
///
/// \return The indices of added axes.
inline AxisSet calculate_broadcast_axes(const Shape& output_shape, const Shape& input_shape)
{
return calculate_broadcast_axes(
output_shape, input_shape, output_shape.size() - input_shape.size());
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node, ngraph::Shape new_shape)
{
return std::make_shared<ngraph::op::Broadcast>(
node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
ngraph::Shape new_shape,
std::size_t start_match_axis)
{
return std::make_shared<ngraph::op::Broadcast>(
node,
new_shape,
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
}
} // namespace op
} // namespace ngraph
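A small sketch of the helpers above (the vec parameter is illustrative):

// Broadcast a vector of shape {4} up to shape {2, 3, 4}.
auto vec = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4});
auto bcast = ngraph::op::make_broadcast_node(vec, ngraph::Shape{2, 3, 4});
// Internally, calculate_broadcast_axes({2, 3, 4}, {4}) == AxisSet{0, 1}:
// axes 0 and 1 are the ones added by the broadcast.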
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/graph_util.hpp"
using namespace ngraph;
op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
: Op(node_type, args)
{
}
void op::util::FusedOp::validate_and_infer_types()
{
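// Type inference is delegated to the decomposition: materialize the core-op
// subgraph, validate and infer types over just that subgraph, then copy the
// inferred element types and shapes onto this fused op's outputs.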
auto subgraph_outputs = decompose_op();
auto subgraph = extract_subgraph(subgraph_outputs, get_arguments());
validate_nodes_and_infer_types(subgraph);
size_t i = 0;
for (auto output_node : subgraph_outputs)
{
for (size_t j = 0; j < output_node->get_output_size(); j++, i++)
{
set_output_type(
i, output_node->get_output_element_type(j), output_node->get_output_shape(j));
}
}
}
void op::util::FusedOp::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
// TODO
throw ngraph_error("Autodiff on fused ops not supported yet");
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for fused ops, i.e. ops that can be broken down into core ngraph ops
///
class FusedOp : public Op
{
public:
/// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
///
/// \return A vector of nodes comprising the sub-graph. The order of the
/// output tensors must match the output tensors of the FusedOp.
virtual NodeVector decompose_op() const = 0;
void validate_and_infer_types() override;
void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected:
/// \brief Constructs a FusedOp
///
/// \param args Nodes that produce the input tensors for the fused op
FusedOp(const std::string& node_type, const NodeVector& args);
};
}
}
}
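As a hedged illustration of the subclassing contract (the Square op below is hypothetical, not part of this commit), a minimal fused op that decomposes to a single Multiply:

#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/fused_op.hpp"

// Hypothetical fused op: squares its input element-wise.
class Square : public ngraph::op::util::FusedOp
{
public:
    Square(const std::shared_ptr<ngraph::Node>& data)
        : FusedOp("Square", {data})
    {
        constructor_validate_and_infer_types();
    }

    // Decompose into core ngraph ops: x * x.
    ngraph::NodeVector decompose_op() const override
    {
        auto data = get_argument(0);
        return {data * data};
    }

    std::shared_ptr<ngraph::Node>
        copy_with_new_args(const ngraph::NodeVector& new_args) const override
    {
        return std::make_shared<Square>(new_args.at(0));
    }
};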
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Node> node)
{
bool modified = false;
if (auto fused_op = std::dynamic_pointer_cast<ngraph::op::util::FusedOp>(node))
{
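// Splice the fused op's decomposition into the graph in place of the fused
// op itself, after checking that the subgraph yields one node per output.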
auto subgraph = fused_op->decompose_op();
if (subgraph.size() != fused_op->get_output_size())
{
throw ngraph_error("While replacing " + node->get_name() +
", mismatch between op output count and outputs of the decomposed "
"subgraph. Expected: " +
to_string(fused_op->get_output_size()) + " Got: " +
to_string(subgraph.size()));
}
if (fused_op->get_output_size() == 1)
{
ngraph::replace_node(fused_op, subgraph[0]);
}
else
{
// TODO (jbobba): Handle multi-output ops. Need to find the GOE for each output and replace it with the corresponding subgraph output node
}
modified = true;
}
return modified;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class FusedOpDecomposition : public NodePass
{
public:
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
};
}
}
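Registration mirrors the backend changes below; a minimal sketch, assuming an existing shared_ptr<Function> named function:

ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.run_passes(function);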
......@@ -125,6 +125,7 @@
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
......@@ -1117,6 +1118,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
auto pass_map = pass_config.get_enables();
REGISTER_KNOBBED_PASS(LikeReplacement, true, ngraph::pass);
REGISTER_KNOBBED_PASS(FusedOpDecomposition, true, ngraph::pass);
REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass);
REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass);
REGISTER_KNOBBED_PASS(LSTMFusion, true, runtime::cpu::pass);
......
......@@ -33,6 +33,7 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
......@@ -170,6 +171,7 @@ void runtime::gpu::GPUCompiledFunction::compile()
#endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
......
......@@ -43,6 +43,7 @@
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
......@@ -429,6 +430,7 @@ shared_ptr<runtime::Executable>
{
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
......
......@@ -21,6 +21,7 @@
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
......@@ -39,6 +40,7 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
m_is_compiled = true;
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>();
pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.run_passes(function);
......
......@@ -64,6 +64,7 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......@@ -124,6 +125,7 @@ using const_data_callback_t = shared_ptr<Node>(const string&, const element::Typ
#define NGRAPH_OP(a, b) a,
enum class OP_TYPEID
{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
UnknownOp
};
......@@ -137,6 +139,7 @@ static OP_TYPEID get_typeid(const string& s)
// ...
#define NGRAPH_OP(a, b) {#a, OP_TYPEID::a},
static const unordered_map<string, OP_TYPEID> typeid_map{
#include "ngraph/op/fused_op_tbl.hpp"
#include "ngraph/op/op_tbl.hpp"
};
#undef NGRAPH_OP
......@@ -1024,6 +1027,11 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Power>(args[0], args[1]);
break;
}
case OP_TYPEID::PRelu:
{
node = make_shared<op::PRelu>(args[0], args[1]);
break;
}
case OP_TYPEID::Product:
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
......@@ -1674,6 +1682,8 @@ static json write(const Node& n, bool binary_constant_data)
node["output_shapes"] = std::move(outputs_js);
break;
}
case OP_TYPEID::PRelu: { break;
}
case OP_TYPEID::Product:
{
auto tmp = dynamic_cast<const op::Product*>(&n);
......
......@@ -143,6 +143,7 @@ set(MULTI_TEST_SRC
backend_comparison.in.cpp
backend_dot.in.cpp
backend_embedding_lookup.in.cpp
backend_fusedop.in.cpp
backend_one_hot.in.cpp
backend_pool.in.cpp
backend_reshape.in.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, prelu)
{
Shape shape{3, 2};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto prelu = make_shared<op::PRelu>(A, B);
auto f0 = make_shared<Function>(NodeVector{prelu}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-2, 3, -2, 1, -1, 0});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b, vector<float>{0, 0.5, 1});
auto result0 = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f0);
handle->call_with_validate({result0}, {a, b});
vector<float> expected{0, 3, -1, 1, -1, 0};
EXPECT_EQ(expected, read_vector<float>(result0));
}
NGRAPH_TEST(${BACKEND_NAME}, prelu_shared_slope)
{
Shape shape{3, 2};
Shape rshape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto prelu = make_shared<op::PRelu>(A, B);
auto f0 = make_shared<Function>(NodeVector{prelu}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-2, 3, -2, 1, -1, 0});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b, vector<float>{0.5});
auto result0 = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f0);
handle->call_with_validate({result0}, {a, b});
vector<float> expected{-1, 3, -1, 1, -0.5, 0};
EXPECT_EQ(expected, read_vector<float>(result0));
}
NGRAPH_TEST(${BACKEND_NAME}, prelu_negative_slope)
{
Shape shape{3, 2};
Shape rshape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto prelu = make_shared<op::PRelu>(A, B);
auto f0 = make_shared<Function>(NodeVector{prelu}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-2, 3, -2, 1, -1, 0});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b, vector<float>{-0.5});
auto result0 = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f0);
handle->call_with_validate({result0}, {a, b});
vector<float> expected{1, 3, 1, 1, 0.5, 0};
EXPECT_EQ(expected, read_vector<float>(result0));
}
......@@ -13074,3 +13074,13 @@ TEST(type_prop, dynslice_params_et_wrong)
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
}
}
TEST(type_prop, prelu)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto slope = make_shared<op::Parameter>(element::f32, Shape{2});
Shape prelu_shape{2, 4};
auto prelu = make_shared<op::PRelu>(param, slope);
ASSERT_EQ(prelu->get_element_type(), element::f32);
ASSERT_EQ(prelu->get_shape(), prelu_shape);
}