Commit 6f961de3 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[Fused Op] Add ScaleShift operator (#2892)

* Add ScaleShift operator

* Add ScaleShift to serializer

* Add UT for ScaleShift

* Add type_prop tests for ScaleShift

* Style-fix

* Skip tests on Intel GPU

* Review fix 1

* Style fix
parent c5ed55bd
......@@ -300,6 +300,8 @@ set (SRC
op/fused/normalize.hpp
op/fused/prelu.cpp
op/fused/prelu.hpp
op/fused/scale_shift.cpp
op/fused/scale_shift.hpp
op/fused/space_to_depth.cpp
op/fused/space_to_depth.hpp
op/util/arithmetic_reduction.cpp
......
......@@ -105,6 +105,7 @@
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "scale_shift.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std;
using namespace ngraph;
op::ScaleShift::ScaleShift(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& scale,
const std::shared_ptr<ngraph::Node>& shift)
: FusedOp("ScaleShift", {data, scale, shift})
{
constructor_validate_and_infer_types();
}
NodeVector op::ScaleShift::decompose_op() const
{
auto data = get_argument(0);
auto scale = get_argument(1);
auto shift = get_argument(2);
// broadcast all data
auto broadcasted_nodes = numpy_style_broadcast({data, scale, shift});
data = broadcasted_nodes[0];
scale = broadcasted_nodes[1];
shift = broadcasted_nodes[2];
return {scale * data + shift};
}
shared_ptr<Node> op::ScaleShift::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ScaleShift>(new_args.at(0), new_args.at(1), new_args.at(2));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Operator performing Scale Shift transformation.
///
/// Y = Scale * Data + Shift
///
class ScaleShift : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs an ScaleShift operation.
///
/// \param data Input tensor
/// \param scale Input tensor that scale input data
/// \param shift Input tensor that shift input data
ScaleShift(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& scale,
const std::shared_ptr<ngraph::Node>& shift);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
......@@ -29,4 +29,5 @@ NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
......@@ -85,6 +85,7 @@
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
......@@ -2054,6 +2055,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::StopGradient:
......
......@@ -77,4 +77,6 @@ mvn_mean_normalization
mvn_mean_normalization_split_channels
mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
scale_shift_no_broadcast
scale_shift
zero_sized_erf
......@@ -76,6 +76,7 @@
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
......@@ -1365,6 +1366,11 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::ScalarConstantLike>(args[0], value);
break;
}
case OP_TYPEID::ScaleShift:
{
node = make_shared<op::ScaleShift>(args[0], args[1], args[2]);
break;
}
case OP_TYPEID::Select:
{
node = make_shared<op::Select>(args[0], args[1], args[2]);
......@@ -2118,6 +2124,8 @@ static json write(const Node& n, bool binary_constant_data)
node["element_type"] = write_element_type(constant->get_element_type());
break;
}
case OP_TYPEID::ScaleShift: { break;
}
case OP_TYPEID::Select: { break;
}
case OP_TYPEID::ShapeOf: { break;
......
......@@ -715,3 +715,45 @@ NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, scale_shift_no_broadcast)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto scale = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto shift = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto scale_shift_func = make_shared<op::ScaleShift>(data, scale, shift);
auto function =
make_shared<Function>(NodeVector{scale_shift_func}, ParameterVector{data, scale, shift});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// Data
test_case.add_input<double>(vector<double>(18, 2));
// Scale
test_case.add_input<double>(vector<double>(18, 2));
// Shift
test_case.add_input<double>(vector<double>(18, 2));
//output
test_case.add_expected_output<double>(Shape{3, 6}, vector<double>(18, 6));
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, scale_shift)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto scale = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto shift = make_shared<op::Parameter>(element::f64, Shape{});
auto scale_shift_func = make_shared<op::ScaleShift>(data, scale, shift);
auto function =
make_shared<Function>(NodeVector{scale_shift_func}, ParameterVector{data, scale, shift});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
// Data
test_case.add_input<double>(vector<double>(18, 2));
// Scale
test_case.add_input<double>(vector<double>(18, 2));
// Shift
test_case.add_input<double>(vector<double>{2});
//output
test_case.add_expected_output<double>(Shape{3, 6}, vector<double>(18, 6));
test_case.run();
}
......@@ -14163,3 +14163,23 @@ TEST(type_prop, fused_clamp)
EXPECT_EQ(clamp->get_element_type(), element::f64);
EXPECT_EQ(clamp->get_shape(), (Shape{2, 2}));
}
TEST(type_prop, scale_shift_no_broadcast)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto scale = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto shift = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto scale_shift_func = make_shared<op::ScaleShift>(data, scale, shift);
EXPECT_EQ(scale_shift_func->get_element_type(), element::f64);
EXPECT_EQ(scale_shift_func->get_shape(), (Shape{3, 6}));
}
TEST(type_prop, scale_shift)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto scale = make_shared<op::Parameter>(element::f64, Shape{3, 6});
auto shift = make_shared<op::Parameter>(element::f64, Shape{});
auto scale_shift_func = make_shared<op::ScaleShift>(data, scale, shift);
EXPECT_EQ(scale_shift_func->get_element_type(), element::f64);
EXPECT_EQ(scale_shift_func->get_shape(), (Shape{3, 6}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment