Unverified Commit 3da93e51 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Eradicate get_argument in op implementations (#3441)

parent 17f6e08c
...@@ -389,6 +389,9 @@ namespace ngraph ...@@ -389,6 +389,9 @@ namespace ngraph
/// \return A vector containing a handle for each of this node's inputs, in order. /// \return A vector containing a handle for each of this node's inputs, in order.
std::vector<Input<const Node>> inputs() const; std::vector<Input<const Node>> inputs() const;
/// \return A vector containing the values for each input
std::vector<Output<Node>> input_values() const;
/// \return A vector containing a handle for each of this node's outputs, in order. /// \return A vector containing a handle for each of this node's outputs, in order.
// TODO: Rename to get_outputs()? // TODO: Rename to get_outputs()?
std::vector<Output<Node>> outputs(); std::vector<Output<Node>> outputs();
...@@ -404,6 +407,8 @@ namespace ngraph ...@@ -404,6 +407,8 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs. /// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input<const Node> input(size_t input_index) const; Input<const Node> input(size_t input_index) const;
Output<Node> input_value(size_t input_index) const;
/// \return A handle to the `output_index`th output of this node. /// \return A handle to the `output_index`th output of this node.
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs. /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
Output<Node> output(size_t output_index); Output<Node> output(size_t output_index);
...@@ -629,6 +634,11 @@ namespace ngraph ...@@ -629,6 +634,11 @@ namespace ngraph
return Input<Node>(this, input_index); return Input<Node>(this, input_index);
} }
inline Output<Node> Node::input_value(size_t input_index) const
{
return input(input_index).get_source_output();
}
inline Input<const Node> Node::input(size_t input_index) const inline Input<const Node> Node::input(size_t input_index) const
{ {
if (input_index >= m_inputs.size()) if (input_index >= m_inputs.size())
...@@ -705,6 +715,18 @@ namespace ngraph ...@@ -705,6 +715,18 @@ namespace ngraph
return result; return result;
} }
inline std::vector<Output<Node>> Node::input_values() const
{
std::vector<Output<Node>> result;
for (size_t i = 0; i < get_input_size(); i++)
{
result.emplace_back(input(i).get_source_output());
}
return result;
}
inline std::vector<Input<const Node>> Node::inputs() const inline std::vector<Input<const Node>> Node::inputs() const
{ {
std::vector<Input<const Node>> result; std::vector<Input<const Node>> result;
......
...@@ -39,7 +39,7 @@ void op::Abs::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,7 +39,7 @@ void op::Abs::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, delta * make_shared<op::Sign>(x)); adjoints.add_delta(x, delta * make_shared<op::Sign>(x));
} }
...@@ -48,7 +48,7 @@ void op::Acos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -48,7 +48,7 @@ void op::Acos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto one = make_shared<op::ScalarConstantLike>(x, 1.0); auto one = make_shared<op::ScalarConstantLike>(x, 1.0);
auto ones = make_shared<op::BroadcastLike>(one, x, AxisSet()); auto ones = make_shared<op::BroadcastLike>(one, x, AxisSet());
......
...@@ -42,8 +42,8 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -42,8 +42,8 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
adjoints.add_delta(x, delta); adjoints.add_delta(x, delta);
adjoints.add_delta(y, delta); adjoints.add_delta(y, delta);
......
...@@ -49,14 +49,14 @@ void op::Asin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -49,14 +49,14 @@ void op::Asin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto one = make_shared<op::Constant>(x->get_element_type(), Shape{}, vector<string>{"1"}); auto one = make_shared<op::Constant>(x.get_element_type(), Shape{}, vector<string>{"1"});
AxisSet axes; AxisSet axes;
for (size_t i = 0; i < x->get_shape().size(); i++) for (size_t i = 0; i < x.get_shape().size(); i++)
axes.insert(i); axes.insert(i);
auto ones = make_shared<op::Broadcast>(one, x->get_shape(), axes); auto ones = make_shared<op::Broadcast>(one, x.get_shape(), axes);
adjoints.add_delta(x, delta / make_shared<op::Sqrt>(ones - x * x)); adjoints.add_delta(x, delta / make_shared<op::Sqrt>(ones - x * x));
} }
...@@ -48,14 +48,14 @@ void op::Atan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -48,14 +48,14 @@ void op::Atan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto one = make_shared<op::Constant>(x->get_element_type(), Shape{}, vector<string>{"1"}); auto one = make_shared<op::Constant>(x.get_element_type(), Shape{}, vector<string>{"1"});
AxisSet axes; AxisSet axes;
for (size_t i = 0; i < x->get_shape().size(); i++) for (size_t i = 0; i < x.get_shape().size(); i++)
axes.insert(i); axes.insert(i);
auto ones = make_shared<op::Broadcast>(one, x->get_shape(), axes); auto ones = make_shared<op::Broadcast>(one, x.get_shape(), axes);
adjoints.add_delta(x, delta / (ones + x * x)); adjoints.add_delta(x, delta / (ones + x * x));
} }
...@@ -361,7 +361,7 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -361,7 +361,7 @@ void op::AvgPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto operand = get_argument(0); auto operand = input_value(0);
auto& operand_shape = get_input_shape(0); auto& operand_shape = get_input_shape(0);
auto backprop = make_shared<op::AvgPoolBackprop>(operand_shape, auto backprop = make_shared<op::AvgPoolBackprop>(operand_shape,
delta, delta,
......
...@@ -77,9 +77,9 @@ std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector ...@@ -77,9 +77,9 @@ std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector
void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints, void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) const NodeVector& deltas)
{ {
auto gamma = input(0).get_source_output(); auto gamma = input_value(0);
auto beta = input(1).get_source_output(); auto beta = input_value(1);
auto data = input(2).get_source_output(); auto data = input_value(2);
// Extract mean and variance outputs from BatchNormBase // Extract mean and variance outputs from BatchNormBase
// as these are used by BatchNormTrainingBackprop. // as these are used by BatchNormTrainingBackprop.
......
...@@ -90,7 +90,7 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe ...@@ -90,7 +90,7 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes)); adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
} }
......
...@@ -111,9 +111,9 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -111,9 +111,9 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
size_t pos = 0; size_t pos = 0;
for (auto input : inputs()) for (auto value : input_values())
{ {
auto arg_shape = input.get_shape(); auto arg_shape = value.get_shape();
auto slice_width = arg_shape[m_concatenation_axis]; auto slice_width = arg_shape[m_concatenation_axis];
...@@ -123,7 +123,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -123,7 +123,7 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
arg_delta_slice_upper[m_concatenation_axis] = next_pos; arg_delta_slice_upper[m_concatenation_axis] = next_pos;
adjoints.add_delta( adjoints.add_delta(
input.get_source_output(), value,
make_shared<op::Slice>( make_shared<op::Slice>(
delta, arg_delta_slice_lower, arg_delta_slice_upper, arg_delta_slice_strides)); delta, arg_delta_slice_lower, arg_delta_slice_upper, arg_delta_slice_strides));
......
...@@ -45,7 +45,7 @@ void op::Convert::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -45,7 +45,7 @@ void op::Convert::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, make_shared<op::Convert>(delta, x->get_element_type())); adjoints.add_delta(x, make_shared<op::Convert>(delta, x.get_element_type()));
} }
...@@ -184,11 +184,11 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node ...@@ -184,11 +184,11 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
const auto x_shape = x->get_shape(); const auto x_shape = x.get_shape();
auto f = get_argument(1); auto f = input_value(1);
const auto f_shape = f->get_shape(); const auto f_shape = f.get_shape();
adjoints.add_delta(x, adjoints.add_delta(x,
make_shared<op::ConvolutionBackpropData>(x_shape, make_shared<op::ConvolutionBackpropData>(x_shape,
...@@ -300,11 +300,11 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints ...@@ -300,11 +300,11 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(1); auto x = input_value(1);
const auto x_shape = x->get_shape(); const auto x_shape = x.get_shape();
auto f = get_argument(0); auto f = input_value(0);
const auto f_shape = f->get_shape(); const auto f_shape = f.get_shape();
auto data_conv = make_shared<op::Convolution>(delta, auto data_conv = make_shared<op::Convolution>(delta,
f, f,
......
...@@ -40,7 +40,7 @@ void op::Cos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -40,7 +40,7 @@ void op::Cos::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, -delta * (make_shared<op::Sin>(x))); adjoints.add_delta(x, -delta * (make_shared<op::Sin>(x)));
} }
...@@ -39,7 +39,7 @@ void op::Cosh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,7 +39,7 @@ void op::Cosh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, delta * (make_shared<op::Sinh>(x))); adjoints.add_delta(x, delta * (make_shared<op::Sinh>(x)));
} }
...@@ -57,8 +57,8 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -57,8 +57,8 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
adjoints.add_delta(x, delta / y); adjoints.add_delta(x, delta / y);
adjoints.add_delta(y, -delta * shared_from_this() / y); adjoints.add_delta(y, -delta * shared_from_this() / y);
......
...@@ -179,11 +179,11 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -179,11 +179,11 @@ void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
auto x_shape = x->get_shape(); // shape IJ auto x_shape = x.get_shape(); // shape IJ
auto y_shape = y->get_shape(); // shape JK auto y_shape = y.get_shape(); // shape JK
auto delta_shape = delta->get_shape(); // shape IK auto delta_shape = delta->get_shape(); // shape IK
Shape I_shape; Shape I_shape;
......
...@@ -38,7 +38,7 @@ void op::Exp::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -38,7 +38,7 @@ void op::Exp::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, delta * shared_from_this()); adjoints.add_delta(x, delta * shared_from_this());
} }
...@@ -81,15 +81,15 @@ void op::BatchMatMul::generate_adjoints(autodiff::Adjoints& adjoints, const Node ...@@ -81,15 +81,15 @@ void op::BatchMatMul::generate_adjoints(autodiff::Adjoints& adjoints, const Node
{ {
auto delta = deltas.at(0); // NxIxK auto delta = deltas.at(0); // NxIxK
auto arg0 = get_argument(0); // NxIxJ auto arg0 = input_value(0); // NxIxJ
auto arg1 = get_argument(1); // NxJxK auto arg1 = input_value(1); // NxJxK
auto delta_dot_arg1 = auto delta_dot_arg1 = make_shared<op::BatchMatMul>(
make_shared<op::BatchMatMul>(delta, util::batch_mat_transpose(arg1)); // IK.KJ->IJ delta, util::batch_mat_transpose(arg1.get_node_shared_ptr())); // IK.KJ->IJ
adjoints.add_delta(arg0, delta_dot_arg1); adjoints.add_delta(arg0, delta_dot_arg1);
auto arg0_dot_delta = auto arg0_dot_delta = make_shared<BatchMatMul>(
make_shared<BatchMatMul>(util::batch_mat_transpose(arg0), delta); // JI.IK->JK util::batch_mat_transpose(arg0.get_node_shared_ptr()), delta); // JI.IK->JK
adjoints.add_delta(arg1, arg0_dot_delta); adjoints.add_delta(arg1, arg0_dot_delta);
} }
......
...@@ -26,7 +26,7 @@ const string op::CompiledKernel::type_name{"CompiledKernel"}; ...@@ -26,7 +26,7 @@ const string op::CompiledKernel::type_name{"CompiledKernel"};
shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector& new_args) const
{ {
auto args = inputs(); auto args = input_values();
if (new_args.size() != args.size()) if (new_args.size() != args.size())
{ {
throw ngraph_error("number of arguments don't match"); throw ngraph_error("number of arguments don't match");
...@@ -36,17 +36,16 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector ...@@ -36,17 +36,16 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
NodeMap nm; NodeMap nm;
for (size_t i = 0; i < args.size(); i++) for (size_t i = 0; i < args.size(); i++)
{ {
nm[args.at(i).get_source_output().get_node()] = new_args.at(i); nm[args.at(i).get_node()] = new_args.at(i);
} }
NodeVector new_node_list; NodeVector new_node_list;
for (auto n : m_node_list) for (auto n : m_node_list)
{ {
OutputVector cur_args; OutputVector cur_args;
for (auto a : n->inputs()) for (auto a : n->input_values())
{ {
auto o = a.get_source_output(); cur_args.push_back(a.for_node(nm.at(a.get_node())));
cur_args.push_back(o.for_node(nm.at(o.get_node())));
} }
auto new_n = n->copy_with_new_inputs(cur_args); auto new_n = n->copy_with_new_inputs(cur_args);
nm[n.get()] = new_n; nm[n.get()] = new_n;
......
...@@ -63,20 +63,18 @@ void op::DynBroadcast::validate_and_infer_types() ...@@ -63,20 +63,18 @@ void op::DynBroadcast::validate_and_infer_types()
axes_shape_rank); axes_shape_rank);
PartialShape result_shape{PartialShape::dynamic()}; PartialShape result_shape{PartialShape::dynamic()};
if (input(1).get_source_output().get_node_shared_ptr()->is_constant()) if (input_value(1).get_node_shared_ptr()->is_constant())
{ {
result_shape = result_shape = static_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr())
static_pointer_cast<op::Constant>(input(1).get_source_output().get_node_shared_ptr())
->get_shape_val(); ->get_shape_val();
} }
bool axes_known = false; bool axes_known = false;
AxisSet broadcast_axes; AxisSet broadcast_axes;
if (input(2).get_source_output().get_node_shared_ptr()->is_constant()) if (input_value(2).get_node_shared_ptr()->is_constant())
{ {
axes_known = true; axes_known = true;
broadcast_axes = broadcast_axes = static_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr())
static_pointer_cast<op::Constant>(input(2).get_source_output().get_node_shared_ptr())
->get_axis_set_val(); ->get_axis_set_val();
} }
......
...@@ -107,9 +107,9 @@ void op::DynReplaceSlice::validate_and_infer_types() ...@@ -107,9 +107,9 @@ void op::DynReplaceSlice::validate_and_infer_types()
set_input_is_relevant_to_shape(3); set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4); set_input_is_relevant_to_shape(4);
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2)); auto lower_bounds = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(3)); auto upper_bounds = dynamic_pointer_cast<op::Constant>(input_value(3).get_node_shared_ptr());
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(4)); auto strides = dynamic_pointer_cast<op::Constant>(input_value(4).get_node_shared_ptr());
// TODO(amprocte): We can get a bit more information here about the ranks of arg and // TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes. // replacement by inspecting the attributes.
......
...@@ -50,7 +50,7 @@ void op::DynReshape::validate_and_infer_types() ...@@ -50,7 +50,7 @@ void op::DynReshape::validate_and_infer_types()
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1))) if (auto const_shape = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
{ {
std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>(); std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
......
...@@ -86,9 +86,9 @@ void op::DynSlice::validate_and_infer_types() ...@@ -86,9 +86,9 @@ void op::DynSlice::validate_and_infer_types()
set_input_is_relevant_to_shape(2); set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3); set_input_is_relevant_to_shape(3);
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(1)); auto lower_bounds = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr());
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2)); auto upper_bounds = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(3)); auto strides = dynamic_pointer_cast<op::Constant>(input_value(3).get_node_shared_ptr());
if (lower_bounds && upper_bounds && strides) if (lower_bounds && upper_bounds && strides)
{ {
......
...@@ -40,7 +40,7 @@ op::QuantizedAvgPool::QuantizedAvgPool(const Output<Node>& arg, ...@@ -40,7 +40,7 @@ op::QuantizedAvgPool::QuantizedAvgPool(const Output<Node>& arg,
void op::QuantizedAvgPool::validate_and_infer_types() void op::QuantizedAvgPool::validate_and_infer_types()
{ {
auto arg(input(0).get_source_output()); auto arg(input_value(0));
if (arg.get_element_type() != element::u8 && arg.get_element_type() != element::i8) if (arg.get_element_type() != element::u8 && arg.get_element_type() != element::i8)
{ {
throw ngraph_error("QuantizedAvgPool supported only for i8/u8!"); throw ngraph_error("QuantizedAvgPool supported only for i8/u8!");
......
...@@ -47,9 +47,9 @@ namespace ngraph ...@@ -47,9 +47,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_bias() { return input(2).get_source_output(); } Output<Node> get_bias() { return input_value(2); }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -87,9 +87,9 @@ namespace ngraph ...@@ -87,9 +87,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_bias() { return input(2).get_source_output(); } Output<Node> get_bias() { return input_value(2); }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -127,9 +127,9 @@ namespace ngraph ...@@ -127,9 +127,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_bias() { return input(2).get_source_output(); } Output<Node> get_bias() { return input_value(2); }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -45,8 +45,8 @@ namespace ngraph ...@@ -45,8 +45,8 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
bool with_relu() const { return true; } bool with_relu() const { return true; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -123,9 +123,12 @@ static ...@@ -123,9 +123,12 @@ static
template <typename T> template <typename T>
static PartialShape infer_output_shape(const op::Range* node, const element::Type& et) static PartialShape infer_output_shape(const op::Range* node, const element::Type& et)
{ {
auto const_start = dynamic_pointer_cast<op::Constant>(node->get_argument(0)); auto const_start =
auto const_stop = dynamic_pointer_cast<op::Constant>(node->get_argument(1)); dynamic_pointer_cast<op::Constant>(node->input_value(0).get_node_shared_ptr());
auto const_step = dynamic_pointer_cast<op::Constant>(node->get_argument(2)); auto const_stop =
dynamic_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr());
auto const_step =
dynamic_pointer_cast<op::Constant>(node->input_value(2).get_node_shared_ptr());
T start = static_cast<T>(0); T start = static_cast<T>(0);
T stop = static_cast<T>(0); T stop = static_cast<T>(0);
......
...@@ -60,7 +60,8 @@ void op::Tile::validate_and_infer_types() ...@@ -60,7 +60,8 @@ void op::Tile::validate_and_infer_types()
auto out_shape = PartialShape::dynamic(output_rank); auto out_shape = PartialShape::dynamic(output_rank);
if (auto const_repeats = dynamic_pointer_cast<op::Constant>(get_argument(1))) if (auto const_repeats =
dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
{ {
if (arg_shape.is_static()) if (arg_shape.is_static())
{ {
......
...@@ -47,7 +47,8 @@ void op::Transpose::validate_and_infer_types() ...@@ -47,7 +47,8 @@ void op::Transpose::validate_and_infer_types()
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
if (auto input_const = std::dynamic_pointer_cast<op::Constant>(get_argument(1))) if (auto input_const =
std::dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
{ {
auto permutation = input_const->get_axis_vector_val(); auto permutation = input_const->get_axis_vector_val();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
......
...@@ -40,7 +40,7 @@ void op::Clamp::pre_validate_and_infer_types() ...@@ -40,7 +40,7 @@ void op::Clamp::pre_validate_and_infer_types()
NodeVector op::Clamp::decompose_op() const NodeVector op::Clamp::decompose_op() const
{ {
const auto data = input(0).get_source_output(); const auto data = input_value(0);
const auto data_shape = data.get_shape(); const auto data_shape = data.get_shape();
const auto clamp_min = builder::make_constant(data.get_element_type(), data_shape, m_min); const auto clamp_min = builder::make_constant(data.get_element_type(), data_shape, m_min);
......
...@@ -91,8 +91,8 @@ op::ConvolutionBias::ConvolutionBias(const Output<Node>& data_batch, ...@@ -91,8 +91,8 @@ op::ConvolutionBias::ConvolutionBias(const Output<Node>& data_batch,
op::ConvolutionBias::ConvolutionBias(const shared_ptr<op::Convolution>& conv, op::ConvolutionBias::ConvolutionBias(const shared_ptr<op::Convolution>& conv,
const Output<Node>& bias, const Output<Node>& bias,
const bool with_relu) const bool with_relu)
: ConvolutionBias(conv->input(0).get_source_output(), : ConvolutionBias(conv->input_value(0),
conv->input(1).get_source_output(), conv->input_value(1),
bias, bias,
conv->get_window_movement_strides(), conv->get_window_movement_strides(),
conv->get_window_dilation_strides(), conv->get_window_dilation_strides(),
...@@ -201,8 +201,8 @@ shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_a ...@@ -201,8 +201,8 @@ shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_a
NodeVector op::ConvolutionBias::decompose_op() const NodeVector op::ConvolutionBias::decompose_op() const
{ {
auto conv = make_shared<op::Convolution>(input(0).get_source_output(), auto conv = make_shared<op::Convolution>(input_value(0),
input(1).get_source_output(), input_value(1),
m_window_movement_strides, m_window_movement_strides,
m_window_dilation_strides, m_window_dilation_strides,
m_padding_below, m_padding_below,
...@@ -216,8 +216,7 @@ NodeVector op::ConvolutionBias::decompose_op() const ...@@ -216,8 +216,7 @@ NodeVector op::ConvolutionBias::decompose_op() const
} }
auto conv_bias = make_shared<op::Add>( auto conv_bias = make_shared<op::Add>(
conv, conv, make_shared<op::Broadcast>(input_value(2), conv->get_shape(), bcast_axes));
make_shared<op::Broadcast>(input(2).get_source_output(), conv->get_shape(), bcast_axes));
if (m_with_relu) if (m_with_relu)
{ {
return {make_shared<op::Relu>(conv_bias)}; return {make_shared<op::Relu>(conv_bias)};
...@@ -236,13 +235,13 @@ void op::ConvolutionBias::generate_adjoints(autodiff::Adjoints& adjoints, const ...@@ -236,13 +235,13 @@ void op::ConvolutionBias::generate_adjoints(autodiff::Adjoints& adjoints, const
delta = make_shared<op::ReluBackprop>(shared_from_this(), delta); delta = make_shared<op::ReluBackprop>(shared_from_this(), delta);
} }
auto data = input(0).get_source_output(); auto data = input_value(0);
const auto data_shape = data.get_shape(); const auto data_shape = data.get_shape();
auto filter = input(1).get_source_output(); auto filter = input_value(1);
const auto filter_shape = filter.get_shape(); const auto filter_shape = filter.get_shape();
auto bias = input(2).get_source_output(); auto bias = input_value(2);
const auto bias_shape = bias.get_shape(); const auto bias_shape = bias.get_shape();
// using regular convolution backprop for data // using regular convolution backprop for data
...@@ -339,9 +338,9 @@ shared_ptr<Node> ...@@ -339,9 +338,9 @@ shared_ptr<Node>
NodeVector op::ConvolutionBiasBackpropFiltersBias::decompose_op() const NodeVector op::ConvolutionBiasBackpropFiltersBias::decompose_op() const
{ {
auto conv_bprop = make_shared<op::ConvolutionBackpropFilters>(input(0).get_source_output(), auto conv_bprop = make_shared<op::ConvolutionBackpropFilters>(input_value(0),
m_filters_shape, m_filters_shape,
input(1).get_source_output(), input_value(1),
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward, m_window_dilation_strides_forward,
m_padding_below_forward, m_padding_below_forward,
...@@ -355,7 +354,7 @@ NodeVector op::ConvolutionBiasBackpropFiltersBias::decompose_op() const ...@@ -355,7 +354,7 @@ NodeVector op::ConvolutionBiasBackpropFiltersBias::decompose_op() const
reduce_axes.insert(i); reduce_axes.insert(i);
} }
auto bias_bprop = make_shared<op::Sum>(input(1).get_source_output(), reduce_axes); auto bias_bprop = make_shared<op::Sum>(input_value(1), reduce_axes);
return {conv_bprop, bias_bprop}; return {conv_bprop, bias_bprop};
} }
...@@ -384,9 +383,9 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const Output<Node>& data_batch, ...@@ -384,9 +383,9 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const Output<Node>& data_batch,
op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<op::ConvolutionBias>& conv, op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<op::ConvolutionBias>& conv,
const Output<Node>& add_input, const Output<Node>& add_input,
bool with_relu) bool with_relu)
: ConvolutionBiasAdd(conv->input(0).get_source_output(), : ConvolutionBiasAdd(conv->input_value(0),
conv->input(1).get_source_output(), conv->input_value(1),
conv->input(2).get_source_output(), conv->input_value(2),
add_input, add_input,
conv->get_window_movement_strides(), conv->get_window_movement_strides(),
conv->get_window_dilation_strides(), conv->get_window_dilation_strides(),
...@@ -457,8 +456,8 @@ std::shared_ptr<Node> op::ConvolutionBiasAdd::copy_with_new_args(const NodeVecto ...@@ -457,8 +456,8 @@ std::shared_ptr<Node> op::ConvolutionBiasAdd::copy_with_new_args(const NodeVecto
NodeVector op::ConvolutionBiasAdd::decompose_op() const NodeVector op::ConvolutionBiasAdd::decompose_op() const
{ {
auto conv = make_shared<op::Convolution>(input(0).get_source_output(), auto conv = make_shared<op::Convolution>(input_value(0),
input(1).get_source_output(), input_value(1),
m_window_movement_strides, m_window_movement_strides,
m_window_dilation_strides, m_window_dilation_strides,
m_padding_below, m_padding_below,
...@@ -472,14 +471,13 @@ NodeVector op::ConvolutionBiasAdd::decompose_op() const ...@@ -472,14 +471,13 @@ NodeVector op::ConvolutionBiasAdd::decompose_op() const
} }
auto conv_bias = make_shared<op::Add>( auto conv_bias = make_shared<op::Add>(
conv, conv, make_shared<op::Broadcast>(input_value(2), conv->get_shape(), bcast_axes));
make_shared<op::Broadcast>(input(2).get_source_output(), conv->get_shape(), bcast_axes));
if (m_with_relu) if (m_with_relu)
{ {
return {make_shared<op::Relu>(conv_bias + input(3).get_source_output())}; return {make_shared<op::Relu>(conv_bias + input_value(3))};
} }
else else
{ {
return {conv_bias + input(3).get_source_output()}; return {conv_bias + input_value(3)};
} }
} }
...@@ -55,9 +55,9 @@ namespace ngraph ...@@ -55,9 +55,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_bias() { return input(2).get_source_output(); } Output<Node> get_bias() { return input_value(2); }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -201,8 +201,8 @@ namespace ngraph ...@@ -201,8 +201,8 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -35,7 +35,7 @@ op::DepthToSpace::DepthToSpace(const Output<Node>& data, const size_t block_size ...@@ -35,7 +35,7 @@ op::DepthToSpace::DepthToSpace(const Output<Node>& data, const size_t block_size
NodeVector op::DepthToSpace::decompose_op() const NodeVector op::DepthToSpace::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
const Shape& data_shape = data.get_shape(); const Shape& data_shape = data.get_shape();
// Set default values to each dimension to be able to work with both 3D or 4D data. // Set default values to each dimension to be able to work with both 3D or 4D data.
......
...@@ -37,8 +37,8 @@ op::Elu::Elu(const Output<Node>& data, const Output<Node>& alpha) ...@@ -37,8 +37,8 @@ op::Elu::Elu(const Output<Node>& data, const Output<Node>& alpha)
NodeVector op::Elu::decompose_op() const NodeVector op::Elu::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto alpha_node = input(1).get_source_output(); auto alpha_node = input_value(1);
alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data.get_shape()); alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data.get_shape());
......
...@@ -100,11 +100,11 @@ void op::FakeQuantize::pre_validate_and_infer_types() ...@@ -100,11 +100,11 @@ void op::FakeQuantize::pre_validate_and_infer_types()
NodeVector op::FakeQuantize::decompose_op() const NodeVector op::FakeQuantize::decompose_op() const
{ {
Output<Node> data{input(0).get_source_output()}; Output<Node> data{input_value(0)};
Output<Node> input_low{input(1).get_source_output()}; Output<Node> input_low{input_value(1)};
Output<Node> input_high{input(2).get_source_output()}; Output<Node> input_high{input_value(2)};
Output<Node> output_low{input(3).get_source_output()}; Output<Node> output_low{input_value(3)};
Output<Node> output_high{input(4).get_source_output()}; Output<Node> output_high{input_value(4)};
if (input_low.get_shape().size() == 0) if (input_low.get_shape().size() == 0)
{ {
......
...@@ -36,7 +36,7 @@ op::Gelu::Gelu(const Output<Node>& data) ...@@ -36,7 +36,7 @@ op::Gelu::Gelu(const Output<Node>& data)
// f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) ) // f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) )
NodeVector op::Gelu::decompose_op() const NodeVector op::Gelu::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
shared_ptr<ngraph::Node> half = shared_ptr<ngraph::Node> half =
builder::make_constant(data.get_element_type(), data.get_shape(), 0.5); builder::make_constant(data.get_element_type(), data.get_shape(), 0.5);
......
...@@ -45,9 +45,9 @@ op::Gemm::Gemm(const Output<Node>& A, ...@@ -45,9 +45,9 @@ op::Gemm::Gemm(const Output<Node>& A,
NodeVector op::Gemm::decompose_op() const NodeVector op::Gemm::decompose_op() const
{ {
auto A = input(0).get_source_output(); auto A = input_value(0);
auto B = input(1).get_source_output(); auto B = input_value(1);
auto C = input(2).get_source_output(); auto C = input_value(2);
if (m_transA) if (m_transA)
{ {
......
...@@ -56,7 +56,7 @@ void op::GRN::pre_validate_and_infer_types() ...@@ -56,7 +56,7 @@ void op::GRN::pre_validate_and_infer_types()
NodeVector op::GRN::decompose_op() const NodeVector op::GRN::decompose_op() const
{ {
Output<Node> data{input(0).get_source_output()}; Output<Node> data{input_value(0)};
const Shape& input_shape{data.get_shape()}; const Shape& input_shape{data.get_shape()};
// Reshape to 4D tensor. // Reshape to 4D tensor.
......
...@@ -129,8 +129,8 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_ ...@@ -129,8 +129,8 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_
NodeVector op::GroupConvolution::decompose_op() const NodeVector op::GroupConvolution::decompose_op() const
{ {
auto data = input(0); auto data = input_value(0);
auto filters = input(1); auto filters = input_value(1);
// Split one convolution op to N ops where N is the number of groups // Split one convolution op to N ops where N is the number of groups
// and concat results after computation. // and concat results after computation.
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856 // reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
...@@ -151,13 +151,13 @@ NodeVector op::GroupConvolution::decompose_op() const ...@@ -151,13 +151,13 @@ NodeVector op::GroupConvolution::decompose_op() const
// slice data // slice data
data_lower_bounds[1] = group * data_group_size; data_lower_bounds[1] = group * data_group_size;
data_upper_bounds[1] = (group + 1) * data_group_size; data_upper_bounds[1] = (group + 1) * data_group_size;
auto sliced_data = std::make_shared<ngraph::op::Slice>( auto sliced_data =
data.get_source_output(), data_lower_bounds, data_upper_bounds); std::make_shared<ngraph::op::Slice>(data, data_lower_bounds, data_upper_bounds);
// slice filters // slice filters
filters_lower_bounds[0] = group * filters_group_size; filters_lower_bounds[0] = group * filters_group_size;
filters_upper_bounds[0] = (group + 1) * filters_group_size; filters_upper_bounds[0] = (group + 1) * filters_group_size;
auto sliced_filters = std::make_shared<ngraph::op::Slice>( auto sliced_filters = std::make_shared<ngraph::op::Slice>(
filters.get_source_output(), filters_lower_bounds, filters_upper_bounds); filters, filters_lower_bounds, filters_upper_bounds);
convolution_nodes.push_back( convolution_nodes.push_back(
std::make_shared<ngraph::op::Convolution>(sliced_data, std::make_shared<ngraph::op::Convolution>(sliced_data,
......
...@@ -49,8 +49,8 @@ namespace ngraph ...@@ -49,8 +49,8 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
Output<Node> get_data_batch() { return input(0).get_source_output(); } Output<Node> get_data_batch() { return input_value(0); }
size_t get_groups() const { return m_groups; } size_t get_groups() const { return m_groups; }
const PadType& get_pad_type() const { return m_pad_type; } const PadType& get_pad_type() const { return m_pad_type; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -267,8 +267,8 @@ Shape op::GroupConvolutionTranspose::get_data_batch_shape() const ...@@ -267,8 +267,8 @@ Shape op::GroupConvolutionTranspose::get_data_batch_shape() const
NodeVector op::GroupConvolutionTranspose::decompose_op() const NodeVector op::GroupConvolutionTranspose::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto filters = input(1).get_source_output(); auto filters = input_value(1);
const Shape data_batch_shape = get_data_batch_shape(); const Shape data_batch_shape = get_data_batch_shape();
const size_t num_spatial_dims = data.get_shape().size() - 2; const size_t num_spatial_dims = data.get_shape().size() - 2;
......
...@@ -114,8 +114,8 @@ namespace ngraph ...@@ -114,8 +114,8 @@ namespace ngraph
const Shape& output_shape, const Shape& output_shape,
const std::size_t groups = 1UL); const std::size_t groups = 1UL);
Output<Node> get_data() { return input(0).get_source_output(); } Output<Node> get_data() { return input_value(0); }
Output<Node> get_filters() { return input(1).get_source_output(); } Output<Node> get_filters() { return input_value(1); }
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const { return m_strides; }
const Strides& get_dilations() const { return m_dilations; } const Strides& get_dilations() const { return m_dilations; }
const CoordinateDiff& get_padding_begin() const { return m_padding_begin; } const CoordinateDiff& get_padding_begin() const { return m_padding_begin; }
......
...@@ -186,11 +186,11 @@ NodeVector op::GRUCell::decompose_op() const ...@@ -186,11 +186,11 @@ NodeVector op::GRUCell::decompose_op() const
// Ht = (1 - zt) (.) ht + zt (.) Ht-1 // Ht = (1 - zt) (.) ht + zt (.) Ht-1
// ------------------- // -------------------
Output<Node> X = input(0).get_source_output(); Output<Node> X = input_value(0);
Output<Node> W = input(1).get_source_output(); Output<Node> W = input_value(1);
Output<Node> R = input(2).get_source_output(); Output<Node> R = input_value(2);
Output<Node> H_t = input(3).get_source_output(); Output<Node> H_t = input_value(3);
Output<Node> B = input(4).get_source_output(); Output<Node> B = input_value(4);
// Get W and R biases separately. // Get W and R biases separately.
NodeVector b_W_R = builder::split(B, 2); NodeVector b_W_R = builder::split(B, 2);
......
...@@ -39,7 +39,7 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data, float alpha, float beta) ...@@ -39,7 +39,7 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data, float alpha, float beta)
NodeVector op::HardSigmoid::decompose_op() const NodeVector op::HardSigmoid::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
size_t elem_count = shape_size(data_shape); size_t elem_count = shape_size(data_shape);
......
...@@ -224,11 +224,11 @@ NodeVector op::LSTMCell::decompose_op() const ...@@ -224,11 +224,11 @@ NodeVector op::LSTMCell::decompose_op() const
// Ht = ot (.) h(Ct) // Ht = ot (.) h(Ct)
// -------------------- // --------------------
Output<Node> X = input(0).get_source_output(); Output<Node> X = input_value(0);
Output<Node> W = input(1).get_source_output(); Output<Node> W = input_value(1);
Output<Node> R = input(2).get_source_output(); Output<Node> R = input_value(2);
Output<Node> H_t = input(3).get_source_output(); Output<Node> H_t = input_value(3);
Output<Node> C_t = input(4).get_source_output(); Output<Node> C_t = input_value(4);
Output<Node> bias = get_bias(); Output<Node> bias = get_bias();
NodeVector p_iof = get_peephole_weights(); NodeVector p_iof = get_peephole_weights();
...@@ -278,7 +278,7 @@ Output<Node> op::LSTMCell::get_bias() const ...@@ -278,7 +278,7 @@ Output<Node> op::LSTMCell::get_bias() const
{ {
Output<Node> bias; Output<Node> bias;
// Split B onto Wb an Rb and add them. // Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(input(5).get_source_output(), 2); NodeVector b_W_R = builder::split(input_value(5), 2);
bias = b_W_R.at(0) + b_W_R.at(1); bias = b_W_R.at(0) + b_W_R.at(1);
return bias; return bias;
} }
...@@ -286,7 +286,7 @@ Output<Node> op::LSTMCell::get_bias() const ...@@ -286,7 +286,7 @@ Output<Node> op::LSTMCell::get_bias() const
NodeVector op::LSTMCell::get_peephole_weights() const NodeVector op::LSTMCell::get_peephole_weights() const
{ {
Output<Node> P; Output<Node> P;
P = input(6).get_source_output(); P = input_value(6);
return builder::split(P, s_peepholes_count); return builder::split(P, s_peepholes_count);
} }
......
...@@ -59,7 +59,7 @@ op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_va ...@@ -59,7 +59,7 @@ op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_va
NodeVector op::MVN::decompose_op() const NodeVector op::MVN::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto data_shape = data.get_shape(); // assume that data has n and c channels. auto data_shape = data.get_shape(); // assume that data has n and c channels.
// calculate mean normalization // calculate mean normalization
......
...@@ -88,7 +88,7 @@ void op::Normalize::pre_validate_and_infer_types() ...@@ -88,7 +88,7 @@ void op::Normalize::pre_validate_and_infer_types()
NodeVector op::Normalize::decompose_op() const NodeVector op::Normalize::decompose_op() const
{ {
Output<Node> data{input(0).get_source_output()}; Output<Node> data{input_value(0)};
const Shape input_shape{data.get_shape()}; const Shape input_shape{data.get_shape()};
// Reshape to 4D tensor. // Reshape to 4D tensor.
...@@ -111,7 +111,7 @@ NodeVector op::Normalize::decompose_op() const ...@@ -111,7 +111,7 @@ NodeVector op::Normalize::decompose_op() const
Output<Node> norm = builder::l2_norm(data, reduction_axes, m_eps); Output<Node> norm = builder::l2_norm(data, reduction_axes, m_eps);
norm = make_broadcast_node(norm, data.get_shape(), 0); norm = make_broadcast_node(norm, data.get_shape(), 0);
Output<Node> scale_node{input(1).get_source_output()}; Output<Node> scale_node{input_value(1)};
// Broadcast scale to data tensor shape. // Broadcast scale to data tensor shape.
if (m_channel_shared) if (m_channel_shared)
......
...@@ -37,9 +37,9 @@ op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope) ...@@ -37,9 +37,9 @@ op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
NodeVector op::PRelu::decompose_op() const NodeVector op::PRelu::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
auto slope = input(1).get_source_output(); auto slope = input_value(1);
auto slope_shape = slope.get_shape(); auto slope_shape = slope.get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1)) if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
......
...@@ -167,10 +167,10 @@ NodeVector op::RNNCell::decompose_op() const ...@@ -167,10 +167,10 @@ NodeVector op::RNNCell::decompose_op() const
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// -------------------- // --------------------
Output<Node> X = input(0).get_source_output(); Output<Node> X = input_value(0);
Output<Node> W = input(1).get_source_output(); Output<Node> W = input_value(1);
Output<Node> R = input(2).get_source_output(); Output<Node> R = input_value(2);
Output<Node> H_t = input(3).get_source_output(); Output<Node> H_t = input_value(3);
Output<Node> bias = get_bias(); Output<Node> bias = get_bias();
// Xt*(W^T) // Xt*(W^T)
...@@ -190,7 +190,7 @@ Output<Node> op::RNNCell::get_bias() const ...@@ -190,7 +190,7 @@ Output<Node> op::RNNCell::get_bias() const
{ {
Output<Node> bias; Output<Node> bias;
// Split B onto Wb an Rb and add them. // Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(input(4).get_source_output(), 2); NodeVector b_W_R = builder::split(input_value(4), 2);
bias = b_W_R.at(0) + b_W_R.at(1); bias = b_W_R.at(0) + b_W_R.at(1);
return bias; return bias;
} }
......
...@@ -33,9 +33,9 @@ op::ScaleShift::ScaleShift(const Output<Node>& data, ...@@ -33,9 +33,9 @@ op::ScaleShift::ScaleShift(const Output<Node>& data,
NodeVector op::ScaleShift::decompose_op() const NodeVector op::ScaleShift::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto scale = input(1).get_source_output(); auto scale = input_value(1);
auto shift = input(2).get_source_output(); auto shift = input_value(2);
// broadcast all data // broadcast all data
auto broadcasted_nodes = numpy_style_broadcast_values({data, scale, shift}); auto broadcasted_nodes = numpy_style_broadcast_values({data, scale, shift});
......
...@@ -74,7 +74,7 @@ void op::ShuffleChannels::pre_validate_and_infer_types() ...@@ -74,7 +74,7 @@ void op::ShuffleChannels::pre_validate_and_infer_types()
NodeVector op::ShuffleChannels::decompose_op() const NodeVector op::ShuffleChannels::decompose_op() const
{ {
const auto data = input(0).get_source_output(); const auto data = input_value(0);
const auto& data_shape = data.get_shape(); const auto& data_shape = data.get_shape();
const auto reshaped = builder::reshape(data, get_pre_shuffle_shape(data_shape)); const auto reshaped = builder::reshape(data, get_pre_shuffle_shape(data_shape));
......
...@@ -34,7 +34,7 @@ op::SpaceToDepth::SpaceToDepth(const Output<Node>& data, const size_t block_size ...@@ -34,7 +34,7 @@ op::SpaceToDepth::SpaceToDepth(const Output<Node>& data, const size_t block_size
NodeVector op::SpaceToDepth::decompose_op() const NodeVector op::SpaceToDepth::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
const Shape& data_shape = data.get_shape(); const Shape& data_shape = data.get_shape();
// Set default values to each dimension to be able to work with both 3D or 4D data. // Set default values to each dimension to be able to work with both 3D or 4D data.
......
...@@ -84,7 +84,7 @@ void op::Split::pre_validate_and_infer_types() ...@@ -84,7 +84,7 @@ void op::Split::pre_validate_and_infer_types()
NodeVector op::Split::decompose_op() const NodeVector op::Split::decompose_op() const
{ {
return builder::split(input(0).get_source_output(), m_splits, m_axis); return builder::split(input_value(0), m_splits, m_axis);
} }
shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -34,8 +34,8 @@ op::SquaredDifference::SquaredDifference(const Output<Node>& x1, const Output<No ...@@ -34,8 +34,8 @@ op::SquaredDifference::SquaredDifference(const Output<Node>& x1, const Output<No
NodeVector op::SquaredDifference::decompose_op() const NodeVector op::SquaredDifference::decompose_op() const
{ {
const auto x1 = input(0).get_source_output(); const auto x1 = input_value(0);
const auto x2 = input(1).get_source_output(); const auto x2 = input_value(1);
const auto broadcasted = numpy_style_broadcast_values({x1, x2}); const auto broadcasted = numpy_style_broadcast_values({x1, x2});
......
...@@ -34,8 +34,8 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -34,8 +34,8 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes)
NodeVector op::Squeeze::decompose_op() const NodeVector op::Squeeze::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto axes_node = input(1).get_source_output().get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes. // Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
......
...@@ -34,7 +34,7 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -34,7 +34,7 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
void op::Unsqueeze::pre_validate_and_infer_types() void op::Unsqueeze::pre_validate_and_infer_types()
{ {
auto axes_node = input(1).get_source_output().get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes. // Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -44,8 +44,8 @@ void op::Unsqueeze::pre_validate_and_infer_types() ...@@ -44,8 +44,8 @@ void op::Unsqueeze::pre_validate_and_infer_types()
NodeVector op::Unsqueeze::decompose_op() const NodeVector op::Unsqueeze::decompose_op() const
{ {
auto data = input(0).get_source_output(); auto data = input_value(0);
auto axes_node = input(1).get_source_output().get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node); auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
......
...@@ -33,7 +33,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n) ...@@ -33,7 +33,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
void op::GetOutputElement::validate_and_infer_types() void op::GetOutputElement::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
m_n < input(0).get_source_output().get_node()->get_output_size(), m_n < input_value(0).get_node()->get_output_size(),
"Output at index ", "Output at index ",
m_n, m_n,
" requested, but node has only ", " requested, but node has only ",
...@@ -51,19 +51,19 @@ shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_ ...@@ -51,19 +51,19 @@ shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_
Output<Node> op::GetOutputElement::get_as_output() const Output<Node> op::GetOutputElement::get_as_output() const
{ {
return input(0).get_source_output(); return input_value(0);
} }
NodeVector op::GetOutputElement::get_arguments() const NodeVector op::GetOutputElement::get_arguments() const
{ {
return NodeVector{input(0).get_source_output().get_node_shared_ptr()}; return NodeVector{input_value(0).get_node_shared_ptr()};
} }
void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::GetOutputElement::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
adjoints.add_delta(input(0).get_source_output().get_node_shared_ptr(), delta, get_n()); adjoints.add_delta(input_value(0).get_node_shared_ptr(), delta, get_n());
} }
NodeVector op::get_output_elements(const shared_ptr<Node>& mon) NodeVector op::get_output_elements(const shared_ptr<Node>& mon)
......
...@@ -38,7 +38,7 @@ void op::Log::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -38,7 +38,7 @@ void op::Log::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, delta / x); adjoints.add_delta(x, delta / x);
} }
...@@ -256,7 +256,7 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -256,7 +256,7 @@ void op::MaxPool::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto operand = input(0).get_source_output(); auto operand = input_value(0);
auto backprop = auto backprop =
make_shared<op::MaxPoolBackprop>(operand, make_shared<op::MaxPoolBackprop>(operand,
delta, delta,
......
...@@ -50,10 +50,10 @@ void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -50,10 +50,10 @@ void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
adjoints.add_delta( adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Greater>(x, y), x->get_element_type())); x, delta * make_shared<op::Convert>(make_shared<op::Greater>(x, y), x.get_element_type()));
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y->get_element_type())); y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y.get_element_type()));
} }
...@@ -50,11 +50,11 @@ void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -50,11 +50,11 @@ void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
adjoints.add_delta( adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Less>(x, y), x->get_element_type())); x, delta * make_shared<op::Convert>(make_shared<op::Less>(x, y), x.get_element_type()));
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y->get_element_type())); y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y.get_element_type()));
} }
...@@ -44,8 +44,8 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -44,8 +44,8 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
adjoints.add_delta(x, delta * y); adjoints.add_delta(x, delta * y);
adjoints.add_delta(y, x * delta); adjoints.add_delta(y, x * delta);
......
...@@ -37,7 +37,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -37,7 +37,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, -delta); adjoints.add_delta(x, -delta);
} }
......
...@@ -172,5 +172,5 @@ std::shared_ptr<Node> op::Pad::get_default_value() const ...@@ -172,5 +172,5 @@ std::shared_ptr<Node> op::Pad::get_default_value() const
{ {
axes.insert(i); axes.insert(i);
} }
return std::make_shared<op::Broadcast>(get_argument(1), get_shape(), axes); return std::make_shared<op::Broadcast>(input_value(1), get_shape(), axes);
} }
...@@ -45,8 +45,8 @@ void op::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector ...@@ -45,8 +45,8 @@ void op::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
auto log_x = make_shared<op::Log>(x); auto log_x = make_shared<op::Log>(x);
......
...@@ -52,5 +52,5 @@ void op::Relu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -52,5 +52,5 @@ void op::Relu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto backprop = make_shared<op::ReluBackprop>(shared_from_this(), delta); auto backprop = make_shared<op::ReluBackprop>(shared_from_this(), delta);
adjoints.add_delta(get_argument(0), backprop); adjoints.add_delta(input_value(0), backprop);
} }
...@@ -178,10 +178,10 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, const Nod ...@@ -178,10 +178,10 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, const Nod
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto y = get_argument(1); auto y = input_value(1);
auto& y_element_type = input(1).get_element_type(); auto& y_element_type = y.get_element_type();
auto y_shape = input(1).get_shape(); auto y_shape = y.get_shape();
auto zeros_shaped_like_y = op::Constant::create(y_element_type, y_shape, {0.0}); auto zeros_shaped_like_y = op::Constant::create(y_element_type, y_shape, {0.0});
......
...@@ -143,5 +143,5 @@ void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -143,5 +143,5 @@ void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
reshape = make_shared<op::Reshape>(reshape, x_input_order, x_shape); reshape = make_shared<op::Reshape>(reshape, x_input_order, x_shape);
} }
adjoints.add_delta(get_argument(0), reshape); adjoints.add_delta(input_value(0), reshape);
} }
...@@ -32,7 +32,7 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout) ...@@ -32,7 +32,7 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
// always borrow the placement conf even the default one // always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index()); set_placement_index(input_value(0).get_node()->get_placement_index());
} }
void op::Result::validate_and_infer_types() void op::Result::validate_and_infer_types()
...@@ -55,5 +55,5 @@ void op::Result::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -55,5 +55,5 @@ void op::Result::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
adjoints.add_delta(get_argument(0), delta); adjoints.add_delta(input_value(0), delta);
} }
...@@ -65,7 +65,7 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -65,7 +65,7 @@ void op::Reverse::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes)); adjoints.add_delta(x, make_shared<op::Reverse>(delta, m_reversed_axes));
} }
...@@ -102,8 +102,8 @@ shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_a ...@@ -102,8 +102,8 @@ shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_a
void op::ReverseSequence::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::ReverseSequence::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
auto x = input(0).get_source_output(); auto x = input_value(0);
auto rs_delta = make_shared<ReverseSequence>( auto rs_delta =
deltas.at(0), input(1).get_source_output(), m_batch_axis, m_seq_axis); make_shared<ReverseSequence>(deltas.at(0), input_value(1), m_batch_axis, m_seq_axis);
adjoints.add_delta(x, rs_delta); adjoints.add_delta(x, rs_delta);
} }
...@@ -71,9 +71,9 @@ void op::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -71,9 +71,9 @@ void op::Select::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto p = input(0).get_source_output(); auto p = input_value(0);
auto x = input(1).get_source_output(); auto x = input_value(1);
auto y = input(2).get_source_output(); auto y = input_value(2);
auto p_as_x_type = make_shared<op::Convert>(p, x.get_element_type()); auto p_as_x_type = make_shared<op::Convert>(p, x.get_element_type());
auto not_p_as_y_type = make_shared<op::Convert>(make_shared<op::Not>(p), y.get_element_type()); auto not_p_as_y_type = make_shared<op::Convert>(make_shared<op::Not>(p), y.get_element_type());
......
...@@ -52,6 +52,6 @@ void op::Sigmoid::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -52,6 +52,6 @@ void op::Sigmoid::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto backprop = make_shared<op::SigmoidBackprop>(input(0).get_source_output(), delta); auto backprop = make_shared<op::SigmoidBackprop>(input_value(0), delta);
adjoints.add_delta(input(0).get_source_output(), backprop); adjoints.add_delta(input_value(0), backprop);
} }
...@@ -39,7 +39,7 @@ void op::Sin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,7 +39,7 @@ void op::Sin::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
adjoints.add_delta(x, delta * (make_shared<op::Cos>(x))); adjoints.add_delta(x, delta * (make_shared<op::Cos>(x)));
} }
...@@ -39,7 +39,7 @@ void op::Sinh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,7 +39,7 @@ void op::Sinh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
adjoints.add_delta(x, delta * (make_shared<op::Cosh>(x))); adjoints.add_delta(x, delta * (make_shared<op::Cosh>(x)));
} }
...@@ -135,7 +135,7 @@ void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector ...@@ -135,7 +135,7 @@ void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
adjoints.add_delta_to_slice(x, delta, m_lower_bounds, m_upper_bounds, m_strides); adjoints.add_delta_to_slice(x, delta, m_lower_bounds, m_upper_bounds, m_strides);
} }
...@@ -87,6 +87,6 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -87,6 +87,6 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto adjoint = z - builder::make_with_numpy_broadcast<op::Multiply>(output(0), zreshape); auto adjoint = z - builder::make_with_numpy_broadcast<op::Multiply>(output(0), zreshape);
auto x = input(0).get_source_output(); auto x = input_value(0);
adjoints.add_delta(x, adjoint); adjoints.add_delta(x, adjoint);
} }
...@@ -39,7 +39,7 @@ void op::Sqrt::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,7 +39,7 @@ void op::Sqrt::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
adjoints.add_delta(x, delta / (shared_from_this() + shared_from_this())); adjoints.add_delta(x, delta / (shared_from_this() + shared_from_this()));
} }
...@@ -45,8 +45,8 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -45,8 +45,8 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
auto y = input(1).get_source_output(); auto y = input_value(1);
adjoints.add_delta(x, delta); adjoints.add_delta(x, delta);
adjoints.add_delta(y, -delta); adjoints.add_delta(y, -delta);
......
...@@ -44,8 +44,8 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -44,8 +44,8 @@ void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = get_argument(0); auto x = input_value(0);
auto& x_shape = input(0).get_shape(); auto& x_shape = x.get_shape();
adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, get_reduction_axes())); adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, get_reduction_axes()));
} }
...@@ -40,7 +40,7 @@ void op::Tan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -40,7 +40,7 @@ void op::Tan::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
auto c = make_shared<op::Cos>(x); auto c = make_shared<op::Cos>(x);
......
...@@ -39,7 +39,7 @@ void op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -39,7 +39,7 @@ void op::Tanh::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
{ {
auto delta = deltas.at(0); auto delta = deltas.at(0);
auto x = input(0).get_source_output(); auto x = input_value(0);
adjoints.add_delta(x, delta - (delta * (shared_from_this() * shared_from_this()))); adjoints.add_delta(x, delta - (delta * (shared_from_this() * shared_from_this())));
} }
...@@ -59,8 +59,7 @@ op::TopK::TopK(const Output<Node>& arg, ...@@ -59,8 +59,7 @@ op::TopK::TopK(const Output<Node>& arg,
size_t op::TopK::get_k() const size_t op::TopK::get_k() const
{ {
size_t k = 0; size_t k = 0;
if (auto const_op = if (auto const_op = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
dynamic_pointer_cast<op::Constant>(input(1).get_source_output().get_node_shared_ptr()))
{ {
k = const_op->get_vector<int64_t>()[0]; k = const_op->get_vector<int64_t>()[0];
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment