Commit d99ac8ce authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Robert Kimball

[FusedOps] SquaredDifference (#2918)

* SquaredDifference implementation

* Broadcast input before using it

* Simple test of SquaredDifference

* SquaredDifference validation tests

* Formatting adjustments

* Docs correction

* Exclude the unit test on iGPU

* Keep the includes in a single group

* Update intelgpu_backend.cpp

* Update unit_test.manifest

* UT for the broadcasting path
parent 9d509515
......@@ -310,6 +310,8 @@ set (SRC
op/fused/scale_shift.hpp
op/fused/space_to_depth.cpp
op/fused/space_to_depth.hpp
op/fused/squared_difference.cpp
op/fused/squared_difference.hpp
op/fused/squeeze.cpp
op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
......
......@@ -108,6 +108,7 @@
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
// Constructs a SquaredDifference fused op over two input tensors.
// Shape/element-type inference (including the broadcast-compatibility check
// exercised by decompose_op) is triggered here via
// constructor_validate_and_infer_types().
op::SquaredDifference::SquaredDifference(const shared_ptr<Node>& x1, const shared_ptr<Node>& x2)
    : FusedOp("SquaredDifference", {x1, x2})
{
    constructor_validate_and_infer_types();
}
// Lowers SquaredDifference into primitive ops: y = (x1 - x2)^2, computed
// element-wise after numpy-style broadcasting of the two inputs.
NodeVector op::SquaredDifference::decompose_op() const
{
    // Align the operand shapes with numpy-style broadcasting before subtracting.
    const auto args = numpy_style_broadcast({get_argument(0), get_argument(1)});

    // Element-wise difference of the (now shape-compatible) operands.
    const auto delta = args.at(0) - args.at(1);

    // Square by self-multiplication.
    return {delta * delta};
}
// Creates a copy of this node wired to the supplied replacement inputs.
// \param new_args must hold exactly the two operands (x1, x2); anything else
//        fails the NODE_VALIDATION_CHECK below.
// \return a freshly constructed SquaredDifference over new_args.
shared_ptr<Node> op::SquaredDifference::copy_with_new_args(const NodeVector& new_args) const
{
    NODE_VALIDATION_CHECK(this,
                          new_args.size() == 2,
                          "Expected 2 elements in new_args for the SquaredDifference op but got ",
                          new_args.size());

    return make_shared<SquaredDifference>(new_args.at(0), new_args.at(1));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
    namespace op
    {
        /// \brief Calculates an element-wise squared difference between two tensors
        ///
        ///        y[i] = (x1[i] - x2[i])^2
        ///
        /// The inputs are numpy-style broadcast against each other during
        /// decomposition, so their shapes only need to be broadcast-compatible.
        class SquaredDifference : public ngraph::op::util::FusedOp
        {
        public:
            /// \brief Constructs the squared difference operation.
            ///
            /// \param x1 First input tensor
            /// \param x2 Second input tensor
            SquaredDifference(const std::shared_ptr<ngraph::Node>& x1,
                              const std::shared_ptr<ngraph::Node>& x2);

            /// \brief Decomposes this fused op into subtract and multiply primitives.
            virtual NodeVector decompose_op() const override;

            /// \brief Creates a copy of this node bound to new input nodes.
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
        };
    }
}
......@@ -32,5 +32,6 @@ NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
......@@ -2076,6 +2076,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::ScatterNDAdd:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::Squeeze:
case OP_TYPEID::StopGradient:
case OP_TYPEID::Tile:
......@@ -2173,6 +2174,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::PRelu:
case OP_TYPEID::ScaleShift:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::Squeeze:
case OP_TYPEID::Unsqueeze: { return false;
}
......
......@@ -79,6 +79,7 @@
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
......@@ -1452,6 +1453,11 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Sqrt>(args[0]);
break;
}
case OP_TYPEID::SquaredDifference:
{
node = make_shared<op::SquaredDifference>(args[0], args[1]);
break;
}
case OP_TYPEID::Squeeze:
{
node = make_shared<op::Squeeze>(args[0], args[1]);
......@@ -2198,6 +2204,8 @@ static json write(const Node& n, bool binary_constant_data)
}
case OP_TYPEID::Sqrt: { break;
}
case OP_TYPEID::SquaredDifference: { break;
}
case OP_TYPEID::Squeeze: { break;
}
case OP_TYPEID::StopGradient: { break;
......
......@@ -868,3 +868,35 @@ NGRAPH_TEST(${BACKEND_NAME}, squeeze_dynamic)
const auto axes_param = make_shared<op::Parameter>(element::i64, Shape{2});
EXPECT_THROW(make_shared<op::Squeeze>(data_param, axes_param), CheckFailure);
}
// (x1 - x2)^2 on two identically shaped f64 tensors — no broadcasting involved.
NGRAPH_TEST(${BACKEND_NAME}, squared_difference)
{
    const auto lhs = make_shared<op::Parameter>(element::f64, Shape{2, 2});
    const auto rhs = make_shared<op::Parameter>(element::f64, Shape{2, 2});

    const auto sq_diff = make_shared<op::SquaredDifference>(lhs, rhs);
    const auto f = make_shared<Function>(sq_diff, ParameterVector{lhs, rhs});

    auto test_case = ngraph::test::NgraphTestCase(f, "${BACKEND_NAME}");
    test_case.add_input<double>({1.0, 16.0, 0.0, 1.234567});
    test_case.add_input<double>({1.0, 8.0, -3.0, 3.456789});

    // Expected: element-wise (lhs - rhs)^2.
    test_case.add_expected_output<double>(Shape{2, 2}, {0.0, 64.0, 9.0, 4.938270617284});
    test_case.run();
}
// Exercises the broadcasting path: a scalar second input is broadcast up to
// the {2, 2} shape of the first input before the squared difference is taken.
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_broadcast)
{
    const auto full_input = make_shared<op::Parameter>(element::i32, Shape{2, 2});
    const auto scalar_input = make_shared<op::Parameter>(element::i32, Shape{});

    const auto sq_diff = make_shared<op::SquaredDifference>(full_input, scalar_input);
    const auto f = make_shared<Function>(sq_diff, ParameterVector{full_input, scalar_input});

    auto test_case = ngraph::test::NgraphTestCase(f, "${BACKEND_NAME}");
    test_case.add_input<int32_t>({1, 1, 1, 1});
    test_case.add_input<int32_t>({1});

    // (1 - 1)^2 == 0 at every position.
    test_case.add_expected_output<int32_t>(Shape{2, 2}, {0, 0, 0, 0});
    test_case.run();
}
......@@ -14525,3 +14525,24 @@ TEST(type_prop, scale_shift)
EXPECT_EQ(scale_shift_func->get_element_type(), element::f64);
EXPECT_EQ(scale_shift_func->get_shape(), (Shape{3, 6}));
}
// Type-propagation checks for SquaredDifference: incompatible shapes must be
// rejected at construction; broadcast-compatible shapes must infer the
// broadcast output shape and preserve the element type.
TEST(type_prop, squared_difference)
{
    const auto x1 = make_shared<op::Parameter>(element::f64, Shape{2, 2});
    const auto x2 = make_shared<op::Parameter>(element::f64, Shape{3, 2});
    const auto x3 = make_shared<op::Parameter>(element::f64, Shape{1, 2});

    // Shapes {2, 2} and {3, 2} cannot be numpy-broadcast together, so node
    // creation must fail validation.
    try
    {
        const auto squared_diff = make_shared<op::SquaredDifference>(x1, x2);
        FAIL() << "SquaredDifference node was created with incorrect data.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("axes are incompatible"));
    }

    // {2, 2} and {1, 2} broadcast to {2, 2}; element type carries through.
    // (Renamed the misleading local 'clamp' — a copy-paste remnant from the
    // Clamp type_prop test — to 'squared_diff_func'.)
    const auto squared_diff_func = make_shared<op::SquaredDifference>(x1, x3);
    EXPECT_EQ(squared_diff_func->get_element_type(), element::f64);
    EXPECT_EQ(squared_diff_func->get_shape(), (Shape{2, 2}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment