Commit edac72f6 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Normalize op implementations (#728)

* Normalize op implementations

* update custom backend ops
parent 3c8ab010
...@@ -32,11 +32,15 @@ set (SRC ...@@ -32,11 +32,15 @@ set (SRC
log.cpp log.cpp
node.cpp node.cpp
op/abs.cpp op/abs.cpp
op/acos.cpp
op/add.cpp op/add.cpp
op/allreduce.cpp op/allreduce.cpp
op/asin.cpp
op/atan.cpp
op/avg_pool.cpp op/avg_pool.cpp
op/batch_norm.cpp op/batch_norm.cpp
op/broadcast.cpp op/broadcast.cpp
op/ceiling.cpp
op/concat.cpp op/concat.cpp
op/constant.cpp op/constant.cpp
op/convert.cpp op/convert.cpp
...@@ -45,30 +49,42 @@ set (SRC ...@@ -45,30 +49,42 @@ set (SRC
op/cosh.cpp op/cosh.cpp
op/divide.cpp op/divide.cpp
op/dot.cpp op/dot.cpp
op/equal.cpp
op/exp.cpp op/exp.cpp
op/floor.cpp
op/function_call.cpp op/function_call.cpp
op/get_output_element.cpp op/get_output_element.cpp
op/greater.cpp
op/greater_eq.cpp
op/less.cpp
op/less_eq.cpp
op/log.cpp op/log.cpp
op/max_pool.cpp op/max.cpp
op/maximum.cpp op/maximum.cpp
op/max_pool.cpp
op/min.cpp
op/minimum.cpp op/minimum.cpp
op/multiply.cpp op/multiply.cpp
op/negative.cpp op/negative.cpp
op/not.cpp op/not.cpp
op/not_equal.cpp
op/one_hot.cpp op/one_hot.cpp
op/op.cpp op/op.cpp
op/pad.cpp op/pad.cpp
op/parameter.cpp op/parameter.cpp
op/power.cpp op/power.cpp
op/product.cpp
op/reduce.cpp op/reduce.cpp
op/reduce_window.cpp op/reduce_window.cpp
op/relu.cpp op/relu.cpp
op/remainder.cpp
op/replace_slice.cpp op/replace_slice.cpp
op/reshape.cpp op/reshape.cpp
op/reverse.cpp
op/result.cpp op/result.cpp
op/select.cpp op/reverse.cpp
op/select_and_scatter.cpp op/select_and_scatter.cpp
op/select.cpp
op/sign.cpp
op/sin.cpp op/sin.cpp
op/sinh.cpp op/sinh.cpp
op/slice.cpp op/slice.cpp
......
...@@ -18,10 +18,26 @@ ...@@ -18,10 +18,26 @@
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
void ngraph::op::Abs::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Abs::Abs(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Abs", arg)
{
}
shared_ptr<Node> op::Abs::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Abs>(new_args.at(0));
}
void op::Abs::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
adjoints.add_delta(x, delta * std::make_shared<op::Sign>(x)); adjoints.add_delta(x, delta * make_shared<op::Sign>(x));
} }
...@@ -36,20 +36,11 @@ namespace ngraph ...@@ -36,20 +36,11 @@ namespace ngraph
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Abs(const std::shared_ptr<Node>& arg) Abs(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Abs", arg) Abs(const op::Abs& other, const NodeVector& new_args);
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Abs>(new_args.at(0));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/acos.hpp"
using namespace std;
using namespace ngraph;
op::Acos::Acos(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Acos", arg)
{
}
shared_ptr<Node> op::Acos::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Acos>(new_args.at(0));
}
...@@ -36,20 +36,10 @@ namespace ngraph ...@@ -36,20 +36,10 @@ namespace ngraph
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Acos(const std::shared_ptr<Node>& arg) Acos(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Acos", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Acos>(new_args.at(0));
}
}; };
} }
} }
...@@ -16,8 +16,24 @@ ...@@ -16,8 +16,24 @@
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Add::Add(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic("Add", arg0, arg1)
{
}
shared_ptr<Node> op::Add::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Add>(new_args.at(0), new_args.at(1));
}
void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
...@@ -25,3 +41,8 @@ void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -25,3 +41,8 @@ void ngraph::op::Add::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta); adjoints.add_delta(x, delta);
adjoints.add_delta(y, delta); adjoints.add_delta(y, delta);
} }
shared_ptr<Node> ngraph::operator+(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
{
return make_shared<op::Add>(arg0, arg1);
}
...@@ -38,20 +38,10 @@ namespace ngraph ...@@ -38,20 +38,10 @@ namespace ngraph
/// ///
/// Output `[d0, ...]` /// Output `[d0, ...]`
/// ///
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Add", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Add>(new_args.at(0), new_args.at(1));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -60,9 +50,6 @@ namespace ngraph ...@@ -60,9 +50,6 @@ namespace ngraph
}; };
} }
inline std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node> arg0, std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1) const std::shared_ptr<ngraph::Node> arg1);
{
return std::make_shared<ngraph::op::Add>(arg0, arg1);
}
} }
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::AllReduce::AllReduce(const std::shared_ptr<Node>& arg) op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
: RequiresTensorViewArgs("AllReduce", {arg}) : RequiresTensorViewArgs("AllReduce", {arg})
{ {
auto& input = m_inputs.at(0); auto& input = m_inputs.at(0);
...@@ -31,3 +31,12 @@ op::AllReduce::AllReduce(const std::shared_ptr<Node>& arg) ...@@ -31,3 +31,12 @@ op::AllReduce::AllReduce(const std::shared_ptr<Node>& arg)
throw ngraph_error("Unsupported data type for AllReduce"); throw ngraph_error("Unsupported data type for AllReduce");
} }
} }
shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<AllReduce>(new_args.at(0));
}
...@@ -29,14 +29,7 @@ namespace ngraph ...@@ -29,14 +29,7 @@ namespace ngraph
AllReduce(const std::shared_ptr<Node>& arg); AllReduce(const std::shared_ptr<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<AllReduce>(new_args.at(0));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/asin.hpp"
using namespace std;
using namespace ngraph;
op::Asin::Asin(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Asin", arg)
{
}
shared_ptr<Node> op::Asin::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Asin>(new_args.at(0));
}
...@@ -36,20 +36,10 @@ namespace ngraph ...@@ -36,20 +36,10 @@ namespace ngraph
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Asin(const std::shared_ptr<Node>& arg) Asin(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Asin", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Asin>(new_args.at(0));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/atan.hpp"
using namespace std;
using namespace ngraph;
op::Atan::Atan(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Atan", arg)
{
}
shared_ptr<Node> op::Atan::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Atan>(new_args.at(0));
}
...@@ -36,20 +36,10 @@ namespace ngraph ...@@ -36,20 +36,10 @@ namespace ngraph
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Atan(const std::shared_ptr<Node>& arg) Atan(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Atan", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Atan>(new_args.at(0));
}
}; };
} }
} }
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -191,12 +191,12 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, ...@@ -191,12 +191,12 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg,
Shape result_shape(1 + 1 + spatial_dimension_count); Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = batch_size; result_shape[0] = batch_size;
result_shape[1] = channel_count; result_shape[1] = channel_count;
std::copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2); copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
set_value_type_checked(get_input_element_type(0), result_shape); set_value_type_checked(get_input_element_type(0), result_shape);
} }
static Shape default_padding(const std::shared_ptr<Node>& arg) static Shape default_padding(const shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
...@@ -214,7 +214,7 @@ static Shape default_padding(const std::shared_ptr<Node>& arg) ...@@ -214,7 +214,7 @@ static Shape default_padding(const std::shared_ptr<Node>& arg)
return Shape(arg_shape.size() - 2, 0); return Shape(arg_shape.size() - 2, 0);
} }
op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: AvgPool(arg, : AvgPool(arg,
...@@ -226,7 +226,7 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, ...@@ -226,7 +226,7 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg,
{ {
} }
static Strides default_strides(const std::shared_ptr<Node>& arg) static Strides default_strides(const shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
...@@ -244,7 +244,7 @@ static Strides default_strides(const std::shared_ptr<Node>& arg) ...@@ -244,7 +244,7 @@ static Strides default_strides(const std::shared_ptr<Node>& arg)
return Strides(arg_shape.size() - 2, 1); return Strides(arg_shape.size() - 2, 1);
} }
op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape) op::AvgPool::AvgPool(const shared_ptr<Node>& arg, const Shape& window_shape)
: AvgPool(arg, : AvgPool(arg,
window_shape, window_shape,
default_strides(arg), default_strides(arg),
...@@ -254,8 +254,23 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape ...@@ -254,8 +254,23 @@ op::AvgPool::AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape
{ {
} }
shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<AvgPool>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above,
m_include_padding_in_avg_computation);
}
op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const std::shared_ptr<Node>& delta, const shared_ptr<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -435,7 +450,7 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -435,7 +450,7 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
Shape forward_result_shape(1 + 1 + spatial_dimension_count); Shape forward_result_shape(1 + 1 + spatial_dimension_count);
forward_result_shape[0] = batch_size; forward_result_shape[0] = batch_size;
forward_result_shape[1] = channel_count; forward_result_shape[1] = channel_count;
std::copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2); copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2);
if (forward_result_shape != delta_shape) if (forward_result_shape != delta_shape)
{ {
...@@ -446,17 +461,33 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -446,17 +461,33 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
set_value_type_checked(get_input_element_type(0), forward_arg_shape); set_value_type_checked(get_input_element_type(0), forward_arg_shape);
} }
void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
AvgPoolBackprop* avpn = new AvgPoolBackprop(m_forward_arg_shape,
new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above,
m_include_padding_in_avg_computation);
return shared_ptr<op::AvgPoolBackprop>(avpn);
}
void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto operand = get_input_op(0); auto operand = get_input_op(0);
auto& operand_shape = get_input_shape(0); auto& operand_shape = get_input_shape(0);
auto backprop = std::make_shared<op::AvgPoolBackprop>(operand_shape, auto backprop = make_shared<op::AvgPoolBackprop>(operand_shape,
delta, delta,
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above, m_padding_above,
m_include_padding_in_avg_computation); m_include_padding_in_avg_computation);
adjoints.add_delta(operand, backprop); adjoints.add_delta(operand, backprop);
} }
...@@ -70,20 +70,7 @@ namespace ngraph ...@@ -70,20 +70,7 @@ namespace ngraph
AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<AvgPool>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above,
m_include_padding_in_avg_computation);
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
...@@ -121,22 +108,7 @@ namespace ngraph ...@@ -121,22 +108,7 @@ namespace ngraph
bool include_padding_in_avg_computation); bool include_padding_in_avg_computation);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
AvgPoolBackprop* avpn = new AvgPoolBackprop(m_forward_arg_shape,
new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above,
m_include_padding_in_avg_computation);
return std::shared_ptr<op::AvgPoolBackprop>(avpn);
}
const Shape& get_forward_arg_shape() const { return m_forward_arg_shape; } const Shape& get_forward_arg_shape() const { return m_forward_arg_shape; }
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg, op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: RequiresTensorViewArgs("Broadcast", {arg}) : RequiresTensorViewArgs("Broadcast", {arg})
...@@ -40,8 +40,16 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg, ...@@ -40,8 +40,16 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
set_value_type_checked(make_shared<TensorViewType>(input.get_element_type(), m_shape)); set_value_type_checked(make_shared<TensorViewType>(input.get_element_type(), m_shape));
} }
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::Broadcast::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
}
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
......
...@@ -38,14 +38,7 @@ namespace ngraph ...@@ -38,14 +38,7 @@ namespace ngraph
const AxisSet& broadcast_axes); const AxisSet& broadcast_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
}
/// \return A set containing the indices of the broadcast axes (0-based). /// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; } const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/ceiling.hpp"
using namespace std;
using namespace ngraph;
op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Ceiling", arg)
{
}
shared_ptr<Node> op::Ceiling::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Ceiling>(new_args.at(0));
}
...@@ -29,20 +29,10 @@ namespace ngraph ...@@ -29,20 +29,10 @@ namespace ngraph
/// \brief Constructs a ceiling operation. /// \brief Constructs a ceiling operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Ceiling(const std::shared_ptr<Node>& arg) Ceiling(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Ceiling", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Ceiling>(new_args.at(0));
}
}; };
} }
} }
...@@ -75,7 +75,12 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis) ...@@ -75,7 +75,12 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
set_value_type_checked(make_shared<TensorViewType>(input_0_element_type, concatenated_shape)); set_value_type_checked(make_shared<TensorViewType>(input_0_element_type, concatenated_shape));
} }
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<Concat>(new_args, m_concatenation_axis);
}
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto concat_result_shape = get_outputs().at(0).get_shape(); auto concat_result_shape = get_outputs().at(0).get_shape();
......
...@@ -35,10 +35,7 @@ namespace ngraph ...@@ -35,10 +35,7 @@ namespace ngraph
Concat(const NodeVector& args, size_t concatenation_axis); Concat(const NodeVector& args, size_t concatenation_axis);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
return std::make_shared<Concat>(new_args, m_concatenation_axis);
}
/// \return The concatenation axis. /// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; } size_t get_concatenation_axis() const { return m_concatenation_axis; }
......
...@@ -25,14 +25,14 @@ using namespace ngraph; ...@@ -25,14 +25,14 @@ using namespace ngraph;
using namespace std; using namespace std;
template <typename T> template <typename T>
std::string to_cpp_string(T value) string to_cpp_string(T value)
{ {
string rc; string rc;
if (std::isnan(value)) if (isnan(value))
{ {
rc = "NAN"; rc = "NAN";
} }
else if (std::isinf(value)) else if (isinf(value))
{ {
if (value > 0) if (value > 0)
{ {
...@@ -56,11 +56,11 @@ op::Constant::~Constant() ...@@ -56,11 +56,11 @@ op::Constant::~Constant()
{ {
if (m_data) if (m_data)
{ {
ngraph::aligned_free(m_data); aligned_free(m_data);
} }
} }
std::vector<std::string> op::Constant::get_value_strings() const vector<string> op::Constant::get_value_strings() const
{ {
vector<string> rc; vector<string> rc;
...@@ -143,12 +143,21 @@ std::vector<std::string> op::Constant::get_value_strings() const ...@@ -143,12 +143,21 @@ std::vector<std::string> op::Constant::get_value_strings() const
} }
else else
{ {
throw std::runtime_error("unsupported type"); throw runtime_error("unsupported type");
} }
return rc; return rc;
} }
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 0)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Constant>(m_element_type, m_shape, m_data);
}
// //
// We have to open up namespace blocks here to work around a problem with gcc: // We have to open up namespace blocks here to work around a problem with gcc:
// //
...@@ -159,11 +168,11 @@ namespace ngraph ...@@ -159,11 +168,11 @@ namespace ngraph
namespace op namespace op
{ {
template <> template <>
void Constant::write_to_buffer<std::string>(const element::Type& target_type, void Constant::write_to_buffer<string>(const element::Type& target_type,
const Shape& target_shape, const Shape& target_shape,
const std::vector<std::string>& source, const vector<string>& source,
void* target, void* target,
size_t target_element_count) size_t target_element_count)
{ {
} }
} }
......
...@@ -131,14 +131,7 @@ namespace ngraph ...@@ -131,14 +131,7 @@ namespace ngraph
} }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 0)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Constant>(m_element_type, m_shape, m_data);
}
/// \return The initialization literals for the tensor constant. /// \return The initialization literals for the tensor constant.
std::vector<std::string> get_value_strings() const; std::vector<std::string> get_value_strings() const;
......
...@@ -21,16 +21,24 @@ ...@@ -21,16 +21,24 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type) op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_type)
: UnaryElementwise("Convert", element_type, arg) : UnaryElementwise("Convert", element_type, arg)
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
void ngraph::op::Convert::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::Convert::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Convert>(new_args.at(0), m_element_type);
}
void op::Convert::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
adjoints.add_delta(x, std::make_shared<op::Convert>(delta, x->get_element_type())); adjoints.add_delta(x, make_shared<op::Convert>(delta, x->get_element_type()));
} }
...@@ -34,14 +34,7 @@ namespace ngraph ...@@ -34,14 +34,7 @@ namespace ngraph
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type); Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Convert>(new_args.at(0), m_element_type);
}
const element::Type& get_convert_element_type() const { return m_element_type; } const element::Type& get_convert_element_type() const { return m_element_type; }
protected: protected:
......
This diff is collapsed.
...@@ -19,10 +19,26 @@ ...@@ -19,10 +19,26 @@
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
#include "ngraph/op/sin.hpp" #include "ngraph/op/sin.hpp"
void ngraph::op::Cos::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Cos::Cos(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Cos", arg)
{
}
shared_ptr<Node> op::Cos::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Cos>(new_args.at(0));
}
void op::Cos::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
adjoints.add_delta(x, -delta * (std::make_shared<op::Sin>(x))); adjoints.add_delta(x, -delta * (make_shared<op::Sin>(x)));
} }
...@@ -29,20 +29,10 @@ namespace ngraph ...@@ -29,20 +29,10 @@ namespace ngraph
/// \brief Constructs a cosine operation. /// \brief Constructs a cosine operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Cos(const std::shared_ptr<Node>& arg) Cos(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Cos", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Cos>(new_args.at(0));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -18,10 +18,26 @@ ...@@ -18,10 +18,26 @@
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/sinh.hpp" #include "ngraph/op/sinh.hpp"
void ngraph::op::Cosh::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Cosh::Cosh(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Cosh", arg)
{
}
shared_ptr<Node> op::Cosh::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Cosh>(new_args.at(0));
}
void op::Cosh::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
adjoints.add_delta(x, delta * (std::make_shared<op::Sinh>(x))); adjoints.add_delta(x, delta * (make_shared<op::Sinh>(x)));
} }
...@@ -29,20 +29,10 @@ namespace ngraph ...@@ -29,20 +29,10 @@ namespace ngraph
/// \brief Constructs a hyperbolic cosine operation. /// \brief Constructs a hyperbolic cosine operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Cosh(const std::shared_ptr<Node>& arg) Cosh(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Cosh", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Cosh>(new_args.at(0));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -18,8 +18,24 @@ ...@@ -18,8 +18,24 @@
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
void ngraph::op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Divide::Divide(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic("Divide", arg0, arg1)
{
}
shared_ptr<Node> op::Divide::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Divide>(new_args.at(0), new_args.at(1));
}
void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
...@@ -27,3 +43,8 @@ void ngraph::op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -27,3 +43,8 @@ void ngraph::op::Divide::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta / y); adjoints.add_delta(x, delta / y);
adjoints.add_delta(y, -delta * shared_from_this() / y); adjoints.add_delta(y, -delta * shared_from_this() / y);
} }
shared_ptr<Node> ngraph::operator/(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
{
return make_shared<op::Divide>(arg0, arg1);
}
...@@ -30,28 +30,16 @@ namespace ngraph ...@@ -30,28 +30,16 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Divide", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Divide>(new_args.at(0), new_args.at(1));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
}; };
} }
inline std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1) std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
{ const std::shared_ptr<ngraph::Node> arg1);
return std::make_shared<ngraph::op::Divide>(arg0, arg1);
}
} }
...@@ -33,8 +33,7 @@ using namespace ngraph; ...@@ -33,8 +33,7 @@ using namespace ngraph;
// Helper function to compute the number of dot axes according to default behavior when // Helper function to compute the number of dot axes according to default behavior when
// they are not specified. // they are not specified.
// //
size_t default_reduction_axes_count(const std::shared_ptr<Node>& arg0, size_t default_reduction_axes_count(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
const std::shared_ptr<Node>& arg1)
{ {
if (arg0->get_shape().size() == 0 || arg1->get_shape().size() == 0) if (arg0->get_shape().size() == 0 || arg1->get_shape().size() == 0)
{ {
...@@ -46,13 +45,13 @@ size_t default_reduction_axes_count(const std::shared_ptr<Node>& arg0, ...@@ -46,13 +45,13 @@ size_t default_reduction_axes_count(const std::shared_ptr<Node>& arg0,
} }
} }
op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) op::Dot::Dot(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: Dot(arg0, arg1, default_reduction_axes_count(arg0, arg1)) : Dot(arg0, arg1, default_reduction_axes_count(arg0, arg1))
{ {
} }
op::Dot::Dot(const std::shared_ptr<Node>& arg0, op::Dot::Dot(const shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
size_t reduction_axes_count) size_t reduction_axes_count)
: RequiresTensorViewArgs("Dot", {arg0, arg1}) : RequiresTensorViewArgs("Dot", {arg0, arg1})
, m_reduction_axes_count(reduction_axes_count) , m_reduction_axes_count(reduction_axes_count)
...@@ -88,19 +87,18 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0, ...@@ -88,19 +87,18 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0,
Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * reduction_axes_count); Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * reduction_axes_count);
std::copy( copy(input_0_shape.begin(), input_0_shape.end() - reduction_axes_count, result_shape.begin());
input_0_shape.begin(), input_0_shape.end() - reduction_axes_count, result_shape.begin()); copy(input_1_shape.begin() + reduction_axes_count,
std::copy(input_1_shape.begin() + reduction_axes_count, input_1_shape.end(),
input_1_shape.end(), result_shape.begin() + (input_0_shape.size() - reduction_axes_count));
result_shape.begin() + (input_0_shape.size() - reduction_axes_count));
auto result_type = make_shared<TensorViewType>(input_0.get_element_type(), result_shape); auto result_type = make_shared<TensorViewType>(input_0.get_element_type(), result_shape);
set_value_type_checked(result_type); set_value_type_checked(result_type);
} }
std::shared_ptr<op::Reshape> make_reshape_axes_to_front(const std::shared_ptr<Node>& n, shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n,
const Shape& front_shape, const Shape& front_shape,
const Shape& back_shape) const Shape& back_shape)
{ {
AxisVector input_order; AxisVector input_order;
Shape output_shape; Shape output_shape;
...@@ -120,7 +118,7 @@ std::shared_ptr<op::Reshape> make_reshape_axes_to_front(const std::shared_ptr<No ...@@ -120,7 +118,7 @@ std::shared_ptr<op::Reshape> make_reshape_axes_to_front(const std::shared_ptr<No
return make_shared<op::Reshape>(n, input_order, output_shape); return make_shared<op::Reshape>(n, input_order, output_shape);
} }
void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto y = get_inputs().at(1).get_output().get_node(); auto y = get_inputs().at(1).get_output().get_node();
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/equal.hpp"
using namespace std;
using namespace ngraph;
op::Equal::Equal(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseComparison("Equal", arg0, arg1)
{
}
shared_ptr<Node> op::Equal::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Equal>(new_args.at(0), new_args.at(1));
}
...@@ -43,20 +43,10 @@ namespace ngraph ...@@ -43,20 +43,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseComparison("Equal", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Equal>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
...@@ -17,8 +17,24 @@ ...@@ -17,8 +17,24 @@
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
void ngraph::op::Exp::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Exp::Exp(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Exp", arg)
{
}
shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Exp>(new_args.at(0));
}
void op::Exp::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
......
...@@ -29,20 +29,10 @@ namespace ngraph ...@@ -29,20 +29,10 @@ namespace ngraph
/// \brief Constructs an exponential operation. /// \brief Constructs an exponential operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Exp(const std::shared_ptr<Node>& arg) Exp(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Exp", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Exp>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/floor.hpp"
using namespace std;
using namespace ngraph;
op::Floor::Floor(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Floor", arg)
{
}
shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Floor>(new_args.at(0));
}
...@@ -29,20 +29,10 @@ namespace ngraph ...@@ -29,20 +29,10 @@ namespace ngraph
/// \brief Constructs a floor operation. /// \brief Constructs a floor operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Floor(const std::shared_ptr<Node>& arg) Floor(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Floor", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Floor>(new_args.at(0));
}
}; };
} }
} }
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::FunctionCall::FunctionCall(std::shared_ptr<Function> function, const NodeVector& args) op::FunctionCall::FunctionCall(shared_ptr<Function> function, const NodeVector& args)
: Node("FunctionCall", args) : Node("FunctionCall", args)
, m_function(function) , m_function(function)
{ {
...@@ -48,3 +48,14 @@ op::FunctionCall::FunctionCall(std::shared_ptr<Function> function, const NodeVec ...@@ -48,3 +48,14 @@ op::FunctionCall::FunctionCall(std::shared_ptr<Function> function, const NodeVec
add_output(function->get_output_element_type(i), function->get_output_shape(i)); add_output(function->get_output_element_type(i), function->get_output_shape(i));
} }
} }
shared_ptr<Node> op::FunctionCall::copy_with_new_args(const NodeVector& new_args) const
{
return make_shared<FunctionCall>(m_function, new_args);
}
/// \return A singleton vector containing the function to be called.
vector<shared_ptr<Function>> op::FunctionCall::get_functions() const
{
return vector<shared_ptr<Function>>{m_function};
}
...@@ -33,16 +33,10 @@ namespace ngraph ...@@ -33,16 +33,10 @@ namespace ngraph
FunctionCall(std::shared_ptr<Function> function, const NodeVector& args); FunctionCall(std::shared_ptr<Function> function, const NodeVector& args);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
return std::make_shared<FunctionCall>(m_function, new_args);
}
/// \return A singleton vector containing the function to be called. /// \return A singleton vector containing the function to be called.
std::vector<std::shared_ptr<Function>> get_functions() const override std::vector<std::shared_ptr<Function>> get_functions() const override;
{
return std::vector<std::shared_ptr<Function>>{m_function};
}
protected: protected:
std::shared_ptr<Function> m_function; std::shared_ptr<Function> m_function;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t n) op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
: Node("GetOutputElement", {arg}) : Node("GetOutputElement", {arg})
, m_n{n} , m_n{n}
{ {
...@@ -32,3 +32,29 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t ...@@ -32,3 +32,29 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n)); set_value_type_checked(arg->get_output_element_type(n), arg->get_output_shape(n));
} }
shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<GetOutputElement>(new_args.at(0), m_n);
}
NodeVector op::GetOutputElement::get_input_ops()
{
return NodeVector{get_inputs().at(0).get_output().get_node()};
}
void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints,
const shared_ptr<Node>& delta)
{
//Filter out updates(deltas) from mean and variance (for batchnorm)
//as dinput is the only update required.
//This logic needs to be generalized as new multi-output ops are introduced
if (get_n() == 0)
{
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta);
}
}
...@@ -51,34 +51,15 @@ namespace ngraph ...@@ -51,34 +51,15 @@ namespace ngraph
GetOutputElement(const std::shared_ptr<Node>& arg, size_t n); GetOutputElement(const std::shared_ptr<Node>& arg, size_t n);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<GetOutputElement>(new_args.at(0), m_n);
}
/// \return The index of the tuple element to get. /// \return The index of the tuple element to get.
size_t get_n() const { return m_n; } size_t get_n() const { return m_n; }
virtual NodeVector get_input_ops() override virtual NodeVector get_input_ops() override;
{
return NodeVector{get_inputs().at(0).get_output().get_node()};
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override const std::shared_ptr<Node>& delta) override;
{
//Filter out updates(deltas) from mean and variance (for batchnorm)
//as dinput is the only update required.
//This logic needs to be generalized as new multi-output ops are introduced
if (get_n() == 0)
{
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta);
}
}
size_t m_n; size_t m_n;
}; };
} }
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/greater.hpp"
using namespace std;
using namespace ngraph;
op::Greater::Greater(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseComparison("Greater", arg0, arg1)
{
}
shared_ptr<Node> op::Greater::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Greater>(new_args.at(0), new_args.at(1));
}
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseComparison("Greater", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Greater>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/greater_eq.hpp"
using namespace std;
using namespace ngraph;
op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1)
{
}
shared_ptr<Node> op::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
}
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
GreaterEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) GreaterEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseComparison("GreaterEq", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/less.hpp"
using namespace std;
using namespace ngraph;
op::Less::Less(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseComparison("Less", arg0, arg1)
{
}
shared_ptr<Node> op::Less::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Less>(new_args.at(0), new_args.at(1));
}
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseComparison("Less", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Less>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/less_eq.hpp"
using namespace std;
using namespace ngraph;
op::LessEq::LessEq(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseComparison("LessEq", arg0, arg1)
{
}
shared_ptr<Node> op::LessEq::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<LessEq>(new_args.at(0), new_args.at(1));
}
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
LessEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) LessEq(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseComparison("LessEq", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<LessEq>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
...@@ -17,8 +17,24 @@ ...@@ -17,8 +17,24 @@
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
void ngraph::op::Log::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Log::Log(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Log", arg)
{
}
shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Log>(new_args.at(0));
}
void op::Log::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
......
...@@ -29,20 +29,10 @@ namespace ngraph ...@@ -29,20 +29,10 @@ namespace ngraph
/// \brief Constructs a natural log operation. /// \brief Constructs a natural log operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Log(const std::shared_ptr<Node>& arg) Log(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Log", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Log>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/max.hpp"
using namespace std;
using namespace ngraph;
op::Max::Max(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction("Max", arg, reduction_axes)
{
}
shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Max>(new_args.at(0), m_reduction_axes);
}
...@@ -52,20 +52,10 @@ namespace ngraph ...@@ -52,20 +52,10 @@ namespace ngraph
/// ///
/// \param arg The tensor view to be reduced. /// \param arg The tensor view to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Max(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) Max(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
: ArithmeticReduction("Max", arg, reduction_axes)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Max>(new_args.at(0), m_reduction_axes);
}
}; };
} }
} }
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -152,12 +152,12 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, ...@@ -152,12 +152,12 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg,
Shape result_shape(1 + 1 + spatial_dimension_count); Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = batch_size; result_shape[0] = batch_size;
result_shape[1] = channel_count; result_shape[1] = channel_count;
std::copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2); copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
set_value_type_checked(get_input_element_type(0), result_shape); set_value_type_checked(get_input_element_type(0), result_shape);
} }
static Shape default_padding(const std::shared_ptr<Node>& arg) static Shape default_padding(const shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
...@@ -175,7 +175,7 @@ static Shape default_padding(const std::shared_ptr<Node>& arg) ...@@ -175,7 +175,7 @@ static Shape default_padding(const std::shared_ptr<Node>& arg)
return Shape(arg_shape.size() - 2, 0); return Shape(arg_shape.size() - 2, 0);
} }
op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: MaxPool( : MaxPool(
...@@ -183,7 +183,7 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, ...@@ -183,7 +183,7 @@ op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg,
{ {
} }
static Strides default_strides(const std::shared_ptr<Node>& arg) static Strides default_strides(const shared_ptr<Node>& arg)
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
{ {
...@@ -201,18 +201,31 @@ static Strides default_strides(const std::shared_ptr<Node>& arg) ...@@ -201,18 +201,31 @@ static Strides default_strides(const std::shared_ptr<Node>& arg)
return Strides(arg_shape.size() - 2, 1); return Strides(arg_shape.size() - 2, 1);
} }
op::MaxPool::MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape) op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
: MaxPool(arg, window_shape, default_strides(arg), default_padding(arg), default_padding(arg)) : MaxPool(arg, window_shape, default_strides(arg), default_padding(arg), default_padding(arg))
{ {
} }
op::MaxPoolBackprop::MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta, {
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<MaxPool>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
}
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
const shared_ptr<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const std::shared_ptr<op::MaxPool>& forward_op) const shared_ptr<op::MaxPool>& forward_op)
: RequiresTensorViewArgs("MaxPoolBackprop", {arg_forward, delta}) : RequiresTensorViewArgs("MaxPoolBackprop", {arg_forward, delta})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
...@@ -350,7 +363,7 @@ op::MaxPoolBackprop::MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, ...@@ -350,7 +363,7 @@ op::MaxPoolBackprop::MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
Shape forward_result_shape(1 + 1 + spatial_dimension_count); Shape forward_result_shape(1 + 1 + spatial_dimension_count);
forward_result_shape[0] = batch_size; forward_result_shape[0] = batch_size;
forward_result_shape[1] = channel_count; forward_result_shape[1] = channel_count;
std::copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2); copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2);
if (forward_result_shape != delta_shape) if (forward_result_shape != delta_shape)
{ {
...@@ -360,23 +373,38 @@ op::MaxPoolBackprop::MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, ...@@ -360,23 +373,38 @@ op::MaxPoolBackprop::MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
set_value_type_checked(get_input_element_type(0), arg_forward_shape); set_value_type_checked(get_input_element_type(0), arg_forward_shape);
} }
std::shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
{ {
return m_forward_op.lock(); return m_forward_op.lock();
} }
void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::MaxPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
MaxPoolBackprop* mpbp = new MaxPoolBackprop(new_args.at(0),
new_args.at(1),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
return shared_ptr<op::MaxPoolBackprop>(mpbp);
}
void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto operand = get_input_op(0); auto operand = get_input_op(0);
auto backprop = auto backprop =
std::make_shared<op::MaxPoolBackprop>(operand, make_shared<op::MaxPoolBackprop>(operand,
delta, delta,
m_window_shape, m_window_shape,
m_window_movement_strides, m_window_movement_strides,
m_padding_below, m_padding_below,
m_padding_above, m_padding_above,
static_pointer_cast<op::MaxPool>(shared_from_this())); static_pointer_cast<op::MaxPool>(shared_from_this()));
adjoints.add_delta(operand, backprop); adjoints.add_delta(operand, backprop);
} }
...@@ -73,18 +73,7 @@ namespace ngraph ...@@ -73,18 +73,7 @@ namespace ngraph
MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<MaxPool>(new_args.at(0),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
}
/// \return The window shape. /// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
...@@ -116,21 +105,7 @@ namespace ngraph ...@@ -116,21 +105,7 @@ namespace ngraph
const std::shared_ptr<op::MaxPool>& forward_op = nullptr); const std::shared_ptr<op::MaxPool>& forward_op = nullptr);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
MaxPoolBackprop* mpbp = new MaxPoolBackprop(new_args.at(0),
new_args.at(1),
m_window_shape,
m_window_movement_strides,
m_padding_below,
m_padding_above);
return std::shared_ptr<op::MaxPoolBackprop>(mpbp);
}
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......
...@@ -25,8 +25,21 @@ ...@@ -25,8 +25,21 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void ngraph::op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, op::Maximum::Maximum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
const std::shared_ptr<Node>& delta) : BinaryElementwiseArithmetic("Maximum", arg0, arg1)
{
}
shared_ptr<Node> op::Maximum::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Maximum>(new_args.at(0), new_args.at(1));
}
void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
......
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Maximum", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Maximum>(new_args.at(0), new_args.at(1));
}
virtual bool is_commutative() override { return true; } virtual bool is_commutative() override { return true; }
protected: protected:
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/min.hpp"
using namespace std;
using namespace ngraph;
op::Min::Min(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction("Min", arg, reduction_axes)
{
}
shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Min>(new_args.at(0), m_reduction_axes);
}
...@@ -52,20 +52,10 @@ namespace ngraph ...@@ -52,20 +52,10 @@ namespace ngraph
/// ///
/// \param arg The tensor view to be reduced. /// \param arg The tensor view to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Min(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) Min(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
: ArithmeticReduction("Min", arg, reduction_axes)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Min>(new_args.at(0), m_reduction_axes);
}
}; };
} }
} }
...@@ -25,8 +25,21 @@ ...@@ -25,8 +25,21 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void ngraph::op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, op::Minimum::Minimum(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
const std::shared_ptr<Node>& delta) : BinaryElementwiseArithmetic("Minimum", arg0, arg1)
{
}
shared_ptr<Node> op::Minimum::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Minimum>(new_args.at(0), new_args.at(1));
}
void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
......
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Minimum", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Minimum>(new_args.at(0), new_args.at(1));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -16,8 +16,24 @@ ...@@ -16,8 +16,24 @@
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Multiply::Multiply(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic("Multiply", arg0, arg1)
{
}
shared_ptr<Node> op::Multiply::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Multiply>(new_args.at(0), new_args.at(1));
}
void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
...@@ -25,3 +41,8 @@ void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -25,3 +41,8 @@ void ngraph::op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints,
adjoints.add_delta(x, delta * y); adjoints.add_delta(x, delta * y);
adjoints.add_delta(y, x * delta); adjoints.add_delta(y, x * delta);
} }
shared_ptr<Node> ngraph::operator*(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
{
return make_shared<op::Multiply>(arg0, arg1);
}
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Multiply", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Multiply>(new_args.at(0), new_args.at(1));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -52,9 +42,6 @@ namespace ngraph ...@@ -52,9 +42,6 @@ namespace ngraph
}; };
}; };
inline std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0, std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1) const std::shared_ptr<ngraph::Node> arg1);
{
return std::make_shared<ngraph::op::Multiply>(arg0, arg1);
}
} }
...@@ -16,10 +16,31 @@ ...@@ -16,10 +16,31 @@
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
void ngraph::op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Negative::Negative(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Negative", arg)
{
}
shared_ptr<Node> op::Negative::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Negative>(new_args.at(0));
}
void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
adjoints.add_delta(x, -delta); adjoints.add_delta(x, -delta);
} }
shared_ptr<Node> ngraph::operator-(const shared_ptr<Node> arg0)
{
return make_shared<op::Negative>(arg0);
}
...@@ -29,27 +29,14 @@ namespace ngraph ...@@ -29,27 +29,14 @@ namespace ngraph
/// \brief Constructs a negative operation. /// \brief Constructs a negative operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Negative(const std::shared_ptr<Node>& arg) Negative(const std::shared_ptr<Node>& arg);
: UnaryElementwiseArithmetic("Negative", arg)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Negative>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
}; };
} }
inline std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0) std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0);
{
return std::make_shared<ngraph::op::Negative>(arg0);
}
} }
...@@ -24,3 +24,12 @@ op::Not::Not(const shared_ptr<Node>& arg) ...@@ -24,3 +24,12 @@ op::Not::Not(const shared_ptr<Node>& arg)
: UnaryElementwise("Not", arg->get_element_type(), arg) : UnaryElementwise("Not", arg->get_element_type(), arg)
{ {
} }
shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Not>(new_args.at(0));
}
...@@ -32,14 +32,7 @@ namespace ngraph ...@@ -32,14 +32,7 @@ namespace ngraph
Not(const std::shared_ptr<Node>& arg); Not(const std::shared_ptr<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Not>(new_args.at(0));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/not_equal.hpp"
using namespace std;
using namespace ngraph;
op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseComparison("NotEqual", arg0, arg1)
{
}
shared_ptr<Node> op::NotEqual::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<NotEqual>(new_args.at(0), new_args.at(1));
}
...@@ -30,20 +30,10 @@ namespace ngraph ...@@ -30,20 +30,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
NotEqual(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) NotEqual(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseComparison("NotEqual", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis) op::OneHot::OneHot(const shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis)
: RequiresTensorViewArgs("OneHot", {arg}) : RequiresTensorViewArgs("OneHot", {arg})
, m_shape(shape) , m_shape(shape)
, m_one_hot_axis(one_hot_axis) , m_one_hot_axis(one_hot_axis)
...@@ -43,3 +43,12 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t ...@@ -43,3 +43,12 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t
set_value_type_checked(make_shared<TensorViewType>(input_element_type, shape)); set_value_type_checked(make_shared<TensorViewType>(input_element_type, shape));
} }
shared_ptr<Node> op::OneHot::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
}
...@@ -53,14 +53,7 @@ namespace ngraph ...@@ -53,14 +53,7 @@ namespace ngraph
OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis); OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t one_hot_axis);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
}
/// \return The index of the one-hot axis. /// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; } size_t get_one_hot_axis() const { return m_one_hot_axis; }
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Pad::Pad(const std::shared_ptr<Node>& arg, op::Pad::Pad(const shared_ptr<Node>& arg,
const std::shared_ptr<Node>& arg_pad_value, const shared_ptr<Node>& arg_pad_value,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const Shape& padding_interior) const Shape& padding_interior)
...@@ -70,13 +70,13 @@ op::Pad::Pad(const std::shared_ptr<Node>& arg, ...@@ -70,13 +70,13 @@ op::Pad::Pad(const std::shared_ptr<Node>& arg,
set_value_type_checked(get_input_element_type(0), result_shape); set_value_type_checked(get_input_element_type(0), result_shape);
} }
std::shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{ {
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
return std::make_shared<Pad>( return make_shared<Pad>(
new_args.at(0), new_args.at(1), m_padding_below, m_padding_above, m_padding_interior); new_args.at(0), new_args.at(1), m_padding_below, m_padding_above, m_padding_interior);
} }
...@@ -118,7 +118,7 @@ std::shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) co ...@@ -118,7 +118,7 @@ std::shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) co
and push that back. and push that back.
*/ */
void op::Pad::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Pad::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
throw std::invalid_argument("Autodiff is not yet implemented for Pad"); throw invalid_argument("Autodiff is not yet implemented for Pad");
} }
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& shape) op::Parameter::Parameter(const element::Type& element_type, const Shape& shape)
: Op("Parameter", {}) : Op("Parameter", {})
{ {
add_output(element_type, shape); add_output(element_type, shape);
...@@ -37,7 +37,6 @@ shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) c ...@@ -37,7 +37,6 @@ shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) c
return make_shared<Parameter>(output.get_element_type(), output.get_shape()); return make_shared<Parameter>(output.get_element_type(), output.get_shape());
} }
void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
const std::shared_ptr<Node>& delta)
{ {
} }
...@@ -19,13 +19,29 @@ ...@@ -19,13 +19,29 @@
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
void ngraph::op::Power::generate_adjoints(autodiff::Adjoints& adjoints, using namespace std;
const std::shared_ptr<Node>& delta) using namespace ngraph;
op::Power::Power(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic("Power", arg0, arg1)
{
}
shared_ptr<Node> op::Power::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Power>(new_args.at(0), new_args.at(1));
}
void op::Power::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
auto y = get_input_op(1); auto y = get_input_op(1);
auto log_x = std::make_shared<op::Log>(x); auto log_x = make_shared<op::Log>(x);
adjoints.add_delta(x, delta * y * shared_from_this() / x); adjoints.add_delta(x, delta * y * shared_from_this() / x);
adjoints.add_delta(y, delta * shared_from_this() * log_x); adjoints.add_delta(y, delta * shared_from_this() * log_x);
......
...@@ -43,20 +43,10 @@ namespace ngraph ...@@ -43,20 +43,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Power", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Power>(new_args.at(0), new_args.at(1));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/product.hpp"
using namespace std;
using namespace ngraph;
op::Product::Product(const shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction("Product", arg, reduction_axes)
{
}
shared_ptr<Node> op::Product::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Product>(new_args.at(0), m_reduction_axes);
}
...@@ -82,20 +82,10 @@ namespace ngraph ...@@ -82,20 +82,10 @@ namespace ngraph
/// ///
/// \param arg The tensor view to be reduced. /// \param arg The tensor view to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Product(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) Product(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
: ArithmeticReduction("Product", arg, reduction_axes)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Product>(new_args.at(0), m_reduction_axes);
}
}; };
} }
} }
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, op::Reduce::Reduce(const shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init, const shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function, const shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
: RequiresTensorViewArgs("Reduce", {arg_reductee, arg_init}) : RequiresTensorViewArgs("Reduce", {arg_reductee, arg_init})
, m_reduction_function(reduction_function) , m_reduction_function(reduction_function)
...@@ -92,3 +92,13 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -92,3 +92,13 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
add_output(input_reductee.get_element_type(), result_shape); add_output(input_reductee.get_element_type(), result_shape);
} }
shared_ptr<Node> op::Reduce::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Reduce>(
new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
}
...@@ -98,15 +98,7 @@ namespace ngraph ...@@ -98,15 +98,7 @@ namespace ngraph
const AxisSet& reduction_axes); const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Reduce>(
new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
}
/// \return A one-element vector containing the function to use for reduction. /// \return A one-element vector containing the function to use for reduction.
std::vector<std::shared_ptr<Function>> get_functions() const override std::vector<std::shared_ptr<Function>> get_functions() const override
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee, op::ReduceWindow::ReduceWindow(const shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init, const shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function, const shared_ptr<Function>& reduction_function,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: RequiresTensorViewArgs("ReduceWindow", {arg_reductee, arg_init}) : RequiresTensorViewArgs("ReduceWindow", {arg_reductee, arg_init})
...@@ -129,3 +129,16 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee, ...@@ -129,3 +129,16 @@ op::ReduceWindow::ReduceWindow(const std::shared_ptr<Node>& arg_reductee,
set_value_type_checked(input_reductee.get_element_type(), result_shape); set_value_type_checked(input_reductee.get_element_type(), result_shape);
} }
shared_ptr<Node> op::ReduceWindow::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ReduceWindow>(new_args.at(0),
new_args.at(1),
m_reduction_function,
m_window_shape,
m_window_movement_strides);
}
...@@ -67,18 +67,7 @@ namespace ngraph ...@@ -67,18 +67,7 @@ namespace ngraph
const Strides& window_movement_strides); const Strides& window_movement_strides);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<ReduceWindow>(new_args.at(0),
new_args.at(1),
m_reduction_function,
m_window_shape,
m_window_movement_strides);
}
/// \return A singleton vector containing the function to use for reduction. /// \return A singleton vector containing the function to use for reduction.
std::vector<std::shared_ptr<Function>> get_functions() const override std::vector<std::shared_ptr<Function>> get_functions() const override
......
...@@ -26,6 +26,15 @@ op::Relu::Relu(shared_ptr<Node> arg) ...@@ -26,6 +26,15 @@ op::Relu::Relu(shared_ptr<Node> arg)
set_value_type_checked(arg->get_element_type(), arg->get_shape()); set_value_type_checked(arg->get_element_type(), arg->get_shape());
} }
shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Relu>(new_args.at(0));
}
op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: RequiresTensorViewArgs("ReluBackprop", {arg, delta}) : RequiresTensorViewArgs("ReluBackprop", {arg, delta})
{ {
...@@ -40,8 +49,17 @@ op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) ...@@ -40,8 +49,17 @@ op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
set_value_type_checked(delta->get_element_type(), delta->get_shape()); set_value_type_checked(delta->get_element_type(), delta->get_shape());
} }
void op::Relu::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) shared_ptr<Node> op::ReluBackprop::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ReluBackprop>(new_args.at(0), new_args.at(1));
}
void op::Relu::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto backprop = std::make_shared<op::ReluBackprop>(get_input_op(0), delta); auto backprop = make_shared<op::ReluBackprop>(get_input_op(0), delta);
adjoints.add_delta(get_input_op(0), backprop); adjoints.add_delta(get_input_op(0), backprop);
} }
...@@ -39,14 +39,7 @@ namespace ngraph ...@@ -39,14 +39,7 @@ namespace ngraph
Relu(std::shared_ptr<ngraph::Node> arg); Relu(std::shared_ptr<ngraph::Node> arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Relu>(new_args.at(0));
}
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override; const std::shared_ptr<Node>& delta) override;
...@@ -63,14 +56,7 @@ namespace ngraph ...@@ -63,14 +56,7 @@ namespace ngraph
ReluBackprop(std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> delta); ReluBackprop(std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> delta);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<ReluBackprop>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/remainder.hpp"
using namespace std;
using namespace ngraph;
op::Remainder::Remainder(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic("Remainder", arg0, arg1)
{
}
shared_ptr<Node> op::Remainder::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Remainder>(new_args.at(0), new_args.at(1));
}
...@@ -45,20 +45,10 @@ namespace ngraph ...@@ -45,20 +45,10 @@ namespace ngraph
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
: BinaryElementwiseArithmetic("Remainder", arg0, arg1)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Remainder>(new_args.at(0), new_args.at(1));
}
}; };
} }
} }
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0, op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
const Coordinate& lower_bounds, const Coordinate& lower_bounds,
const Coordinate& upper_bounds, const Coordinate& upper_bounds,
const Strides& strides) const Strides& strides)
...@@ -34,8 +34,8 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0, ...@@ -34,8 +34,8 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0,
check_args(); check_args();
} }
op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0, op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
const Coordinate& lower_bounds, const Coordinate& lower_bounds,
const Coordinate& upper_bounds) const Coordinate& upper_bounds)
: RequiresTensorViewArgs("ReplaceSlice", {arg0, arg1}) : RequiresTensorViewArgs("ReplaceSlice", {arg0, arg1})
...@@ -117,8 +117,18 @@ void op::ReplaceSlice::check_args() ...@@ -117,8 +117,18 @@ void op::ReplaceSlice::check_args()
set_value_type_checked(input_0_element_type, input_0_shape); set_value_type_checked(input_0_element_type, input_0_shape);
} }
shared_ptr<Node> op::ReplaceSlice::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<ReplaceSlice>(
new_args.at(0), new_args.at(1), m_lower_bounds, m_upper_bounds, m_strides);
}
void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) const shared_ptr<Node>& delta)
{ {
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto& y_input = get_inputs().at(1); auto& y_input = get_inputs().at(1);
...@@ -129,8 +139,7 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -129,8 +139,7 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
auto zeros_shaped_like_y = op::Constant::create(y_element_type, y_shape, {0.0}); auto zeros_shaped_like_y = op::Constant::create(y_element_type, y_shape, {0.0});
adjoints.add_delta(x, adjoints.add_delta(x,
std::make_shared<op::ReplaceSlice>( make_shared<op::ReplaceSlice>(
delta, zeros_shaped_like_y, m_lower_bounds, m_upper_bounds, m_strides)); delta, zeros_shaped_like_y, m_lower_bounds, m_upper_bounds, m_strides));
adjoints.add_delta( adjoints.add_delta(y, make_shared<op::Slice>(delta, m_lower_bounds, m_upper_bounds, m_strides));
y, std::make_shared<op::Slice>(delta, m_lower_bounds, m_upper_bounds, m_strides));
} }
...@@ -77,15 +77,7 @@ namespace ngraph ...@@ -77,15 +77,7 @@ namespace ngraph
const Coordinate& upper_bounds); const Coordinate& upper_bounds);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<ReplaceSlice>(
new_args.at(0), new_args.at(1), m_lower_bounds, m_upper_bounds, m_strides);
}
/// \return The inclusive lower-bound coordinates. /// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; } const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Reshape::Reshape(const std::shared_ptr<Node>& arg, op::Reshape::Reshape(const shared_ptr<Node>& arg,
const AxisVector& input_order, const AxisVector& input_order,
const Shape& output_shape) const Shape& output_shape)
: RequiresTensorViewArgs("Reshape", {arg}) : RequiresTensorViewArgs("Reshape", {arg})
...@@ -40,8 +40,8 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -40,8 +40,8 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
for (size_t i = 0; i < input_rank; i++) for (size_t i = 0; i < input_rank; i++)
{ {
auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i); auto it = find(begin(m_input_order), end(m_input_order), i);
if (std::end(m_input_order) == it) if (end(m_input_order) == it)
{ {
throw ngraph_error( throw ngraph_error(
"Input axis order for reshape is not a permutation of argument's axes"); "Input axis order for reshape is not a permutation of argument's axes");
...@@ -70,8 +70,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -70,8 +70,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
set_value_type_checked(input.get_element_type(), m_output_shape); set_value_type_checked(input.get_element_type(), m_output_shape);
} }
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::Reshape::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape);
}
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x_shape = get_inputs().at(0).get_shape(); auto x_shape = get_inputs().at(0).get_shape();
auto x_rank = x_shape.size(); auto x_rank = x_shape.size();
......
...@@ -72,14 +72,7 @@ namespace ngraph ...@@ -72,14 +72,7 @@ namespace ngraph
const Shape& output_shape); const Shape& output_shape);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape);
}
/// \return The order in which to iterate over input axes. /// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; } const AxisVector& get_input_order() const { return m_input_order; }
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Result::Result(const std::shared_ptr<Node>& arg) op::Result::Result(const shared_ptr<Node>& arg)
: RequiresTensorViewArgs("Result", {arg}) : RequiresTensorViewArgs("Result", {arg})
{ {
if (arg->get_outputs().size() != 1) if (arg->get_outputs().size() != 1)
...@@ -37,7 +37,7 @@ op::Result::Result(const std::shared_ptr<Node>& arg) ...@@ -37,7 +37,7 @@ op::Result::Result(const std::shared_ptr<Node>& arg)
set_value_type_checked(arg->get_element_type(), arg->get_shape()); set_value_type_checked(arg->get_element_type(), arg->get_shape());
} }
std::shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{ {
...@@ -49,7 +49,12 @@ std::shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) ...@@ -49,7 +49,12 @@ std::shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args)
throw ngraph_error("Expected a single-output argument"); throw ngraph_error("Expected a single-output argument");
} }
auto res = std::make_shared<Result>(new_args.at(0)); auto res = make_shared<Result>(new_args.at(0));
res->set_needs_copy(res->needs_copy()); res->set_needs_copy(res->needs_copy());
return res; return res;
} }
void op::Result::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{
adjoints.add_delta(get_input_op(0), delta);
}
...@@ -40,10 +40,7 @@ namespace ngraph ...@@ -40,10 +40,7 @@ namespace ngraph
bool needs_copy() const { return m_needs_copy; } bool needs_copy() const { return m_needs_copy; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override const std::shared_ptr<Node>& delta) override;
{
adjoints.add_delta(get_input_op(0), delta);
}
private: private:
bool m_needs_copy{true}; bool m_needs_copy{true};
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Reverse::Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_axes) op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
: RequiresTensorViewArgs("Reverse", {arg}) : RequiresTensorViewArgs("Reverse", {arg})
, m_reversed_axes(reversed_axes) , m_reversed_axes(reversed_axes)
{ {
...@@ -36,7 +36,7 @@ op::Reverse::Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_a ...@@ -36,7 +36,7 @@ op::Reverse::Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_a
{ {
if (axis >= input_rank) if (axis >= input_rank)
{ {
std::stringstream ss; stringstream ss;
ss << "Reverse axis " << axis << " is out of bounds (input rank is " << input_rank ss << "Reverse axis " << axis << " is out of bounds (input rank is " << input_rank
<< ")."; << ").";
throw ngraph_error(ss.str()); throw ngraph_error(ss.str());
...@@ -46,8 +46,16 @@ op::Reverse::Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_a ...@@ -46,8 +46,16 @@ op::Reverse::Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_a
set_value_type_checked(input.get_element_type(), input_shape); set_value_type_checked(input.get_element_type(), input_shape);
} }
void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::Reverse::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Reverse>(new_args.at(0), m_reversed_axes);
}
void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto x = get_input_op(0); auto x = get_input_op(0);
......
...@@ -53,14 +53,7 @@ namespace ngraph ...@@ -53,14 +53,7 @@ namespace ngraph
Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_axes); Reverse(const std::shared_ptr<Node>& arg, const AxisSet& reversed_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Reverse>(new_args.at(0), m_reversed_axes);
}
/// \return The set of axes to reverse. /// \return The set of axes to reverse.
const AxisSet& get_reversed_axes() const { return m_reversed_axes; } const AxisSet& get_reversed_axes() const { return m_reversed_axes; }
......
...@@ -26,9 +26,9 @@ ...@@ -26,9 +26,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Select::Select(const std::shared_ptr<Node>& arg0, op::Select::Select(const shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2) const shared_ptr<Node>& arg2)
: RequiresTensorViewArgs("Select", NodeVector{arg0, arg1, arg2}) : RequiresTensorViewArgs("Select", NodeVector{arg0, arg1, arg2})
{ {
auto& input_0 = get_inputs().at(0); auto& input_0 = get_inputs().at(0);
...@@ -51,16 +51,23 @@ op::Select::Select(const std::shared_ptr<Node>& arg0, ...@@ -51,16 +51,23 @@ op::Select::Select(const std::shared_ptr<Node>& arg0,
set_value_type_checked(input_1.get_element_type(), input_1.get_shape()); set_value_type_checked(input_1.get_element_type(), input_1.get_shape());
} }
void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints, shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const
const std::shared_ptr<Node>& delta) {
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
}
void op::Select::generate_adjoints(autodiff::Adjoints& adjoints, const shared_ptr<Node>& delta)
{ {
auto p = get_inputs().at(0).get_output().get_node(); auto p = get_inputs().at(0).get_output().get_node();
auto x = get_inputs().at(1).get_output().get_node(); auto x = get_inputs().at(1).get_output().get_node();
auto y = get_inputs().at(2).get_output().get_node(); auto y = get_inputs().at(2).get_output().get_node();
auto p_as_x_type = std::make_shared<op::Convert>(p, x->get_element_type()); auto p_as_x_type = make_shared<op::Convert>(p, x->get_element_type());
auto not_p_as_y_type = auto not_p_as_y_type = make_shared<op::Convert>(make_shared<op::Not>(p), y->get_element_type());
std::make_shared<op::Convert>(std::make_shared<op::Not>(p), y->get_element_type());
adjoints.add_delta(x, delta * p_as_x_type); adjoints.add_delta(x, delta * p_as_x_type);
adjoints.add_delta(y, delta * not_p_as_y_type); adjoints.add_delta(y, delta * not_p_as_y_type);
......
...@@ -50,14 +50,7 @@ namespace ngraph ...@@ -50,14 +50,7 @@ namespace ngraph
const std::shared_ptr<Node>& arg2); const std::shared_ptr<Node>& arg2);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
}
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -22,11 +22,11 @@ ...@@ -22,11 +22,11 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee, op::SelectAndScatter::SelectAndScatter(const shared_ptr<Node>& arg_selectee,
const std::shared_ptr<Node>& arg_source, const shared_ptr<Node>& arg_source,
const std::shared_ptr<Node>& arg_init, const shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& selection_function, const shared_ptr<Function>& selection_function,
const std::shared_ptr<Function>& scatter_function, const shared_ptr<Function>& scatter_function,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: RequiresTensorViewArgs("SelectAndScatter", {arg_selectee, arg_source, arg_init}) : RequiresTensorViewArgs("SelectAndScatter", {arg_selectee, arg_source, arg_init})
...@@ -216,3 +216,18 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee ...@@ -216,3 +216,18 @@ op::SelectAndScatter::SelectAndScatter(const std::shared_ptr<Node>& arg_selectee
// //
set_value_type_checked(input_selectee_element_type, input_selectee_shape); set_value_type_checked(input_selectee_element_type, input_selectee_shape);
} }
shared_ptr<Node> op::SelectAndScatter::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<SelectAndScatter>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_selection_function,
m_scatter_function,
m_window_shape,
m_window_movement_strides);
}
...@@ -88,20 +88,7 @@ namespace ngraph ...@@ -88,20 +88,7 @@ namespace ngraph
const Strides& window_movement_strides); const Strides& window_movement_strides);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override;
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<SelectAndScatter>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_selection_function,
m_scatter_function,
m_window_shape,
m_window_movement_strides);
}
/// \return A vector of length 2 containing the selection function as element 0, and the scatter function as element 1. /// \return A vector of length 2 containing the selection function as element 0, and the scatter function as element 1.
std::vector<std::shared_ptr<Function>> get_functions() const override std::vector<std::shared_ptr<Function>> get_functions() const override
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/sign.hpp"
using namespace std;
using namespace ngraph;
op::Sign::Sign(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sign", arg)
{
}
shared_ptr<Node> op::Sign::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Sign>(new_args.at(0));
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment