Commit 5259bf21 authored by Tomasz Socha's avatar Tomasz Socha Committed by Scott Cyphers

[FUSED] Add reciprocal op (#3851)

* [FUSED] Add reciprocal op

* Review Fix #1

* Move operator op::v1 -> op

* Fix serializer

* Review Fix I
parent f6a404eb
...@@ -373,6 +373,8 @@ set (SRC ...@@ -373,6 +373,8 @@ set (SRC
op/fused/partial_slice.hpp op/fused/partial_slice.hpp
op/fused/prelu.cpp op/fused/prelu.cpp
op/fused/prelu.hpp op/fused/prelu.hpp
op/fused/reciprocal.cpp
op/fused/reciprocal.hpp
op/fused/rnn_cell.cpp op/fused/rnn_cell.cpp
op/fused/rnn_cell.hpp op/fused/rnn_cell.hpp
op/fused/scale_shift.cpp op/fused/scale_shift.cpp
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/op/constant.hpp" #include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
...@@ -36,11 +35,7 @@ namespace ngraph ...@@ -36,11 +35,7 @@ namespace ngraph
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>( return {std::make_shared<ngraph::op::Reciprocal>(data)};
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = ngraph::op::make_broadcast_node(one_node, data->get_shape());
return {one_node / data};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -151,6 +151,7 @@ namespace ngraph ...@@ -151,6 +151,7 @@ namespace ngraph
#include "ngraph/op/fused/normalize_l2.hpp" #include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp" #include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp" #include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp" #include "ngraph/op/fused/selu.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Reciprocal::type_info;
op::Reciprocal::Reciprocal(const Output<Node>& data)
: FusedOp({data})
{
constructor_validate_and_infer_types();
}
NodeVector op::Reciprocal::decompose_op() const
{
auto data = input_value(0);
auto one_node = op::Constant::create(data.get_element_type(), data.get_shape(), {1});
return {make_shared<op::v1::Divide>(one_node, data)};
}
shared_ptr<Node> op::Reciprocal::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Reciprocal>(new_args.at(0));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Reciprocal operation
/// f(x) = 1 / x
class Reciprocal : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Reciprocal", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Reciprocal() = default;
/// \brief Constructs a Reciprocal operation.
///
/// \param data Input tensor
Reciprocal(const Output<Node>& data);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace op
} // namespace ngraph
...@@ -48,6 +48,7 @@ NGRAPH_OP(NormalizeL2, ngraph::op) ...@@ -48,6 +48,7 @@ NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(PartialSlice, ngraph::op) NGRAPH_OP(PartialSlice, ngraph::op)
NGRAPH_OP(PartialSliceBackprop, ngraph::op) NGRAPH_OP(PartialSliceBackprop, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op) NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(Reciprocal, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op) NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op) NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(Selu, ngraph::op) NGRAPH_OP(Selu, ngraph::op)
......
...@@ -91,6 +91,7 @@ ...@@ -91,6 +91,7 @@
#include "ngraph/op/fused/normalize_l2.hpp" #include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/partial_slice.hpp" #include "ngraph/op/fused/partial_slice.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/reciprocal.hpp"
#include "ngraph/op/fused/rnn_cell.hpp" #include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp" #include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/selu.hpp" #include "ngraph/op/fused/selu.hpp"
...@@ -2379,6 +2380,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2379,6 +2380,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Range>(args[0], args[1], args[2]); node = make_shared<op::Range>(args[0], args[1], args[2]);
break; break;
} }
case OP_TYPEID::Reciprocal:
{
node = make_shared<op::Reciprocal>(args[0]);
break;
}
case OP_TYPEID::Relu: case OP_TYPEID::Relu:
{ {
node = make_shared<op::Relu>(args[0]); node = make_shared<op::Relu>(args[0]);
...@@ -3945,6 +3951,8 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3945,6 +3951,8 @@ json JSONSerializer::serialize_node(const Node& n)
} }
case OP_TYPEID::Range: { break; case OP_TYPEID::Range: { break;
} }
case OP_TYPEID::Reciprocal: { break;
}
case OP_TYPEID::Relu: { break; case OP_TYPEID::Relu: { break;
} }
case OP_TYPEID::ReluBackprop: { break; case OP_TYPEID::ReluBackprop: { break;
......
...@@ -95,6 +95,20 @@ NGRAPH_TEST(${BACKEND_NAME}, prelu) ...@@ -95,6 +95,20 @@ NGRAPH_TEST(${BACKEND_NAME}, prelu)
EXPECT_EQ(expected, read_vector<float>(result0)); EXPECT_EQ(expected, read_vector<float>(result0));
} }
NGRAPH_TEST(${BACKEND_NAME}, reciprocal)
{
Shape shape{3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto reciprocal = make_shared<op::Reciprocal>(A);
auto f0 = make_shared<Function>(NodeVector{reciprocal}, ParameterVector{A});
auto test_case = test::NgraphTestCase(f0, "${BACKEND_NAME}");
test_case.add_input(vector<float>{1, 2, 3, 4, 5, 6});
test_case.add_expected_output(
Shape{3, 2}, vector<float>{1.0f, 1 / 2.0f, 1 / 3.0f, 1 / 4.0f, 1 / 5.0f, 1 / 6.0f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, hardsigmoid) NGRAPH_TEST(${BACKEND_NAME}, hardsigmoid)
{ {
Shape shape{2, 7}; Shape shape{2, 7};
......
...@@ -40,3 +40,11 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types) ...@@ -40,3 +40,11 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
TEST(type_prop, reciprocal)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto pad = make_shared<op::Reciprocal>(param);
EXPECT_EQ(pad->get_element_type(), element::f32);
EXPECT_EQ(pad->get_shape(), (Shape{2, 3, 4}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment