Commit 52313f9e authored by Ashok Emani's avatar Ashok Emani Committed by Scott Cyphers

Add StopGradient op to ngraph (#1067)

* add StopGradient op

* add StopGradient op src

* remove adjoints and add interpreter

* fix compile issue

* use nop_elimination and add unit-test

* update cmake

* update unit-tests
parent 91ecac9d
......@@ -93,6 +93,7 @@ set (SRC
op/slice.cpp
op/softmax.cpp
op/sqrt.cpp
op/stop_gradient.cpp
op/subtract.cpp
op/sum.cpp
op/tan.cpp
......
......@@ -119,6 +119,7 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
// Constructs a StopGradient node over `arg`. All validation (element type,
// shape propagation) is delegated to the UnaryElementwiseArithmetic base,
// which receives the op name "StopGradient" and the single input node.
op::StopGradient::StopGradient(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("StopGradient", arg)
{
}
// Creates a copy of this op rewired to the inputs in `new_args`.
//
// \param new_args Replacement input nodes; exactly one is required,
//                 matching the unary arity of StopGradient.
// \return A new StopGradient node over new_args[0].
// \throws ngraph_error if the argument count is not exactly 1.
shared_ptr<Node> op::StopGradient::copy_with_new_args(const NodeVector& new_args) const
{
    if (new_args.size() == 1)
    {
        return make_shared<StopGradient>(new_args.at(0));
    }
    throw ngraph_error("Incorrect number of new arguments");
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
/// \brief Identity-like op that blocks gradient flow.
///
/// Forwards its single input tensor unchanged; the NopElimination pass
/// replaces a StopGradient node with its argument, so it has no runtime
/// cost. No generate_adjoints override is visible here — presumably
/// gradients stop propagating at this node; confirm against the autodiff
/// machinery.
class StopGradient : public util::UnaryElementwiseArithmetic
{
public:
/// \brief Constructs a StopGradient operation.
///
/// \param arg Node that produces the input tensor.
StopGradient(const std::shared_ptr<Node>& arg);
/// \brief Creates a copy of this op wired to the single node in
/// `new_args`; throws ngraph_error unless exactly one argument is given.
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
......@@ -24,6 +24,7 @@
#include "ngraph/op/convert.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
......@@ -87,12 +88,19 @@ HANDLER_DECL(eliminate_broadcast)
return false;
}
// StopGradient is an identity at runtime, so it can always be removed:
// rewire every user of `node` to the op's single input, then return true
// to report that the graph was modified.
HANDLER_DECL(eliminate_stop_gradient)
{
ngraph::replace_node(node, node->get_argument(0));
return true;
}
// Dispatch table for NopElimination: maps each eliminable op's type_index
// (computed by the TI macro) to the handler that removes it. Each handler
// returns true iff it changed the graph.
static const std::unordered_map<std::type_index,
std::function<bool(const std::shared_ptr<ngraph::Node>&)>>
dispatcher{{TI(ngraph::op::Pad), &eliminate_pad},
{TI(ngraph::op::Sum), &eliminate_sum},
{TI(ngraph::op::Convert), &eliminate_convert},
{TI(ngraph::op::Slice), &eliminate_slice},
{TI(ngraph::op::StopGradient), &eliminate_stop_gradient},
{TI(ngraph::op::Broadcast), &eliminate_broadcast}};
bool ngraph::pass::NopElimination::run_on_function(std::shared_ptr<ngraph::Function> function)
......
......@@ -100,3 +100,16 @@ TEST(nop_elimination, eliminate_broadcast)
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
}
// Verifies that NopElimination removes StopGradient nodes: build the scalar
// graph f(A) = StopGradient(A), run the pass, and assert no StopGradient
// ops remain in the function.
TEST(nop_elimination, eliminate_stop_gradient)
{
Shape shape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::StopGradient>(A), op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::StopGradient>(f), 0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment